pax_global_header00006660000000000000000000000064151306607060014516gustar00rootroot0000000000000052 comment=7659dd8e0fa06b41290ad29af323d93d673c6b36 quic-go-0.59.0/000077500000000000000000000000001513066070600131555ustar00rootroot00000000000000quic-go-0.59.0/.clusterfuzzlite/000077500000000000000000000000001513066070600165115ustar00rootroot00000000000000quic-go-0.59.0/.clusterfuzzlite/Dockerfile000066400000000000000000000011131513066070600204770ustar00rootroot00000000000000FROM gcr.io/oss-fuzz-base/base-builder-go:v1 ARG TARGETPLATFORM RUN echo "TARGETPLATFORM: ${TARGETPLATFORM}" ENV GOVERSION=1.25.0 RUN platform=$(echo ${TARGETPLATFORM} | tr '/' '-') && \ filename="go${GOVERSION}.${platform}.tar.gz" && \ wget https://dl.google.com/go/${filename} && \ mkdir temp-go && \ rm -rf /root/.go/* && \ tar -C temp-go/ -xzf ${filename} && \ mv temp-go/go/* /root/.go/ && \ rm -r ${filename} temp-go RUN apt-get update && apt-get install -y make autoconf automake libtool COPY . $SRC/quic-go WORKDIR quic-go COPY .clusterfuzzlite/build.sh $SRC/ quic-go-0.59.0/.clusterfuzzlite/build.sh000077500000000000000000000007551513066070600201560ustar00rootroot00000000000000#!/bin/bash -eu export CXX="${CXX} -lresolv" # required by Go 1.20 compile_go_fuzzer github.com/quic-go/quic-go/fuzzing/frames Fuzz frame_fuzzer compile_go_fuzzer github.com/quic-go/quic-go/fuzzing/header Fuzz header_fuzzer compile_go_fuzzer github.com/quic-go/quic-go/fuzzing/transportparameters Fuzz transportparameter_fuzzer compile_go_fuzzer github.com/quic-go/quic-go/fuzzing/tokens Fuzz token_fuzzer compile_go_fuzzer github.com/quic-go/quic-go/fuzzing/handshake Fuzz handshake_fuzzer quic-go-0.59.0/.clusterfuzzlite/project.yaml000066400000000000000000000000151513066070600210370ustar00rootroot00000000000000language: go 
quic-go-0.59.0/.githooks/000077500000000000000000000000001513066070600150625ustar00rootroot00000000000000quic-go-0.59.0/.githooks/README.md000066400000000000000000000002311513066070600163350ustar00rootroot00000000000000# Git Hooks This directory contains useful Git hooks for working with quic-go. Install them by running ```bash git config core.hooksPath .githooks ``` quic-go-0.59.0/.githooks/pre-commit000077500000000000000000000016221513066070600170650ustar00rootroot00000000000000#!/bin/bash # Check that test files don't contain focussed test cases. errored=false for f in $(git diff --diff-filter=d --cached --name-only); do if [[ $f != *_test.go ]]; then continue; fi output=$(git show :"$f" | grep -n -e "FIt(" -e "FContext(" -e "FDescribe(") if [ $? -eq 0 ]; then echo "$f contains a focussed test:" echo "$output" echo "" errored=true fi done pushd ./integrationtests/gomodvendor > /dev/null go mod tidy if [[ -n $(git diff --diff-filter=d --name-only -- "go.mod" "go.sum") ]]; then echo "go.mod / go.sum in integrationtests/gomodvendor not tidied" errored=true fi popd > /dev/null # Check that all Go files are properly gofumpt-ed. output=$(gofumpt -d $(git diff --diff-filter=d --cached --name-only -- '*.go')) if [ -n "$output" ]; then echo "Found files that are not properly gofumpt-ed." 
echo "$output" errored=true fi if [ "$errored" = true ]; then exit 1 fi quic-go-0.59.0/.github/000077500000000000000000000000001513066070600145155ustar00rootroot00000000000000quic-go-0.59.0/.github/FUNDING.yml000066400000000000000000000014641513066070600163370ustar00rootroot00000000000000# These are supported funding model platforms github: [marten-seemann] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] patreon: # Replace with a single Patreon username open_collective: # Replace with a single Open Collective username ko_fi: # Replace with a single Ko-fi username tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry liberapay: # Replace with a single Liberapay username issuehunt: # Replace with a single IssueHunt username otechie: # Replace with a single Otechie username lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] quic-go-0.59.0/.github/dependabot.yml000066400000000000000000000001661513066070600173500ustar00rootroot00000000000000version: 2 updates: - package-ecosystem: "github-actions" directory: "/" schedule: interval: "weekly" quic-go-0.59.0/.github/workflows/000077500000000000000000000000001513066070600165525ustar00rootroot00000000000000quic-go-0.59.0/.github/workflows/build-interop-docker.yml000066400000000000000000000027731513066070600233300ustar00rootroot00000000000000name: Build interop Docker image on: push: branches: - master tags: - 'v*' pull_request: concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: ${{ github.event_name == 'push' }} jobs: interop: runs-on: ${{ fromJSON(vars['DOCKER_RUNNER_UBUNTU'] || '"ubuntu-latest"') }} timeout-minutes: 30 steps: - uses: actions/checkout@v6 - name: Set up QEMU uses: docker/setup-qemu-action@v3 - name: Set up 
Docker Buildx uses: docker/setup-buildx-action@v3 with: platforms: linux/amd64,linux/arm64 - name: Login to Docker Hub if: github.event_name == 'push' uses: docker/login-action@v3 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} - name: set tag name id: tag # Tagged releases won't be picked up by the interop runner automatically, # but they can be useful when debugging regressions. run: | if [[ $GITHUB_REF == refs/tags/* ]]; then echo "tag=${GITHUB_REF#refs/tags/}" | tee -a $GITHUB_OUTPUT; else echo 'tag=latest' | tee -a $GITHUB_OUTPUT; fi - uses: docker/build-push-action@v6 with: context: "." file: "interop/Dockerfile" platforms: linux/amd64,linux/arm64 push: ${{ github.event_name == 'push' }} tags: martenseemann/quic-go-interop:${{ steps.tag.outputs.tag }} quic-go-0.59.0/.github/workflows/clusterfuzz-lite-pr.yml000066400000000000000000000035011513066070600232460ustar00rootroot00000000000000name: ClusterFuzzLite PR fuzzing on: pull_request: paths: - '**' permissions: read-all jobs: PR: runs-on: ${{ fromJSON(vars['CLUSTERFUZZ_LITE_RUNNER_UBUNTU'] || '"ubuntu-latest"') }} concurrency: group: ${{ github.workflow }}-${{ matrix.sanitizer }}-${{ github.ref }} cancel-in-progress: true strategy: fail-fast: false matrix: sanitizer: - address steps: - name: Build Fuzzers (${{ matrix.sanitizer }}) id: build uses: google/clusterfuzzlite/actions/build_fuzzers@v1 with: language: go github-token: ${{ secrets.GITHUB_TOKEN }} sanitizer: ${{ matrix.sanitizer }} # Optional but recommended: used to only run fuzzers that are affected # by the PR. # See later section on "Git repo for storage". # storage-repo: https://${{ secrets.PERSONAL_ACCESS_TOKEN }}@github.com/OWNER/STORAGE-REPO-NAME.git # storage-repo-branch: main # Optional. Defaults to "main" # storage-repo-branch-coverage: gh-pages # Optional. Defaults to "gh-pages". 
- name: Run Fuzzers (${{ matrix.sanitizer }}) id: run uses: google/clusterfuzzlite/actions/run_fuzzers@v1 with: github-token: ${{ secrets.GITHUB_TOKEN }} fuzz-seconds: 480 mode: 'code-change' sanitizer: ${{ matrix.sanitizer }} output-sarif: true parallel-fuzzing: true # Optional but recommended: used to download the corpus produced by # batch fuzzing. # See later section on "Git repo for storage". # storage-repo: https://${{ secrets.PERSONAL_ACCESS_TOKEN }}@github.com/OWNER/STORAGE-REPO-NAME.git # storage-repo-branch: main # Optional. Defaults to "main" # storage-repo-branch-coverage: gh-pages # Optional. Defaults to "gh-pages". quic-go-0.59.0/.github/workflows/cross-compile.sh000077500000000000000000000015021513066070600216660ustar00rootroot00000000000000#!/bin/bash set -e dist="$1" goos=$(echo "$dist" | cut -d "/" -f1) goarch=$(echo "$dist" | cut -d "/" -f2) # cross-compiling for android is a pain... if [[ "$goos" == "android" ]]; then exit; fi # iOS builds require Cgo, see https://github.com/golang/go/issues/43343 # Cgo would then need a C cross compilation setup. Not worth the hassle. if [[ "$goos" == "ios" ]]; then exit; fi # Write all log output to a temporary file instead of to stdout. # That allows running this script in parallel, while preserving the correct order of the output. 
log_file=$(mktemp) error_handler() { cat "$log_file" >&2 rm "$log_file" exit 1 } trap 'error_handler' ERR echo "$dist" >> "$log_file" out="main-$goos-$goarch" GOOS=$goos GOARCH=$goarch go build -o $out example/main.go >> "$log_file" 2>&1 rm $out cat "$log_file" rm "$log_file" quic-go-0.59.0/.github/workflows/cross-compile.yml000066400000000000000000000033121513066070600220530ustar00rootroot00000000000000on: [push, pull_request] jobs: crosscompile: strategy: fail-fast: false matrix: go: [ "1.24.x", "1.25.x", "1.26.0-rc.1" ] runs-on: ${{ fromJSON(vars['CROSS_COMPILE_RUNNER_UBUNTU'] || '"ubuntu-latest"') }} name: "Cross Compilation (Go ${{matrix.go}})" timeout-minutes: 30 steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: go-version: ${{ matrix.go }} - name: Get Date id: get-date run: echo "date=$(/bin/date -u "+%Y%m%d")" >> $GITHUB_OUTPUT - name: Load Go build cache id: load-go-cache uses: actions/cache/restore@v5 with: path: ~/.cache/go-build key: go-${{ matrix.go }}-crosscompile-${{ steps.get-date.outputs.date }} restore-keys: go-${{ matrix.go }}-crosscompile- - name: Install build utils run: | sudo apt-get update sudo apt-get install -y gcc-multilib - name: Install dependencies run: go build example/main.go - name: Run cross compilation # run in parallel on as many cores as are available on the machine run: go tool dist list | xargs -I % -P "$(nproc)" .github/workflows/cross-compile.sh % - name: Save Go build cache # only store cache when on master if: github.event_name == 'push' && github.ref_name == 'master' uses: actions/cache/save@v5 with: path: ~/.cache/go-build # Caches are immutable, so we only update it once per day (at most). 
# See https://github.com/actions/cache/blob/main/tips-and-workarounds.md#update-a-cache key: go-${{ matrix.go }}-crosscompile-${{ steps.get-date.outputs.date }} quic-go-0.59.0/.github/workflows/go-generate.sh000077500000000000000000000011141513066070600213030ustar00rootroot00000000000000#!/usr/bin/env bash set -e # delete all go-generated files (that adhere to the comment convention) git ls-files -z | grep --include \*.go -lrIZ "^// Code generated .* DO NOT EDIT\.$" | tr '\0' '\n' | xargs rm -f # First regenerate sys_conn_buffers_write.go. # If it doesn't exist, the following mockgen calls will fail. go generate -run "sys_conn_buffers_write.go" # now generate everything go generate ./... # Check if any files were changed git diff --exit-code || ( echo "Generated files are not up to date. Please run 'go generate ./...' and commit the changes." exit 1 ) quic-go-0.59.0/.github/workflows/integration.yml000066400000000000000000000103121513066070600216150ustar00rootroot00000000000000on: [push, pull_request] jobs: integration: strategy: fail-fast: false matrix: os: [ "ubuntu" ] go: [ "1.24.x", "1.25.x", "1.26.0-rc.1" ] race: [ false ] include: - os: "ubuntu" go: "1.25.x" race: true - os: "windows" go: "1.25.x" race: false - os: "macos" go: "1.25.x" race: false runs-on: ${{ fromJSON(vars[format('INTEGRATION_RUNNER_{0}', matrix.os)] || format('"{0}-latest"', matrix.os)) }} timeout-minutes: 30 defaults: run: shell: bash # by default Windows uses PowerShell, which uses a different syntax for setting environment variables env: DEBUG: false # set this to true to export qlogs and save them as artifacts TIMESCALE_FACTOR: 3 GOEXPERIMENT: ${{ matrix.go == '1.24.x' && 'synctest' || '' }} name: "Integration (${{ matrix.os }}, Go ${{ matrix.go }}${{ matrix.race && ', race' || '' }})" steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: go-version: ${{ matrix.go }} - name: Install go-junit-report run: go install github.com/jstemmer/go-junit-report/v2@v2.1.0 - name: Set 
qlogger if: env.DEBUG == 'true' run: echo "QLOGFLAG= -qlog" >> $GITHUB_ENV - name: Enable race detector if: ${{ matrix.race }} run: echo "RACEFLAG= -race" >> $GITHUB_ENV - run: go version - name: Run tools tests run: go test ${{ env.RACEFLAG }} -v -timeout 30s -shuffle=on ./integrationtests/tools/... 2>&1 | go-junit-report -set-exit-code -iocopy -out report_tools.xml - name: Run version negotiation tests run: go test ${{ env.RACEFLAG }} -v -timeout 30s -shuffle=on ./integrationtests/versionnegotiation ${{ env.QLOGFLAG }} 2>&1 | go-junit-report -set-exit-code -iocopy -out report_versionnegotiation.xml - name: Run self tests, using QUIC v1 if: success() || failure() # run this step even if the previous one failed run: go test ${{ env.RACEFLAG }} -v -timeout 5m -shuffle=on ./integrationtests/self -version=1 ${{ env.QLOGFLAG }} 2>&1 | go-junit-report -set-exit-code -iocopy -out report_self.xml - name: Run self tests, using QUIC v2 if: ${{ !matrix.race && (success() || failure()) }} # run this step even if the previous one failed run: go test ${{ env.RACEFLAG }} -v -timeout 5m -shuffle=on ./integrationtests/self -version=2 ${{ env.QLOGFLAG }} 2>&1 | go-junit-report -set-exit-code -iocopy -out report_self_v2.xml - name: Run self tests, with GSO disabled if: ${{ matrix.os == 'ubuntu' && (success() || failure()) }} # run this step even if the previous one failed env: QUIC_GO_DISABLE_GSO: true run: go test ${{ env.RACEFLAG }} -v -timeout 5m -shuffle=on ./integrationtests/self -version=1 ${{ env.QLOGFLAG }} 2>&1 | go-junit-report -set-exit-code -iocopy -out report_self_nogso.xml - name: Run self tests, with ECN disabled if: ${{ !matrix.race && matrix.os == 'ubuntu' && (success() || failure()) }} # run this step even if the previous one failed env: QUIC_GO_DISABLE_ECN: true run: go test ${{ env.RACEFLAG }} -v -timeout 5m -shuffle=on ./integrationtests/self -version=1 ${{ env.QLOGFLAG }} 2>&1 | go-junit-report -set-exit-code -iocopy -out report_self_noecn.xml - name: Run 
benchmarks if: ${{ !matrix.race }} run: go test -v -run=^$ -timeout 5m -shuffle=on -bench=. ./integrationtests/self - name: save qlogs if: ${{ always() && env.DEBUG == 'true' }} uses: actions/upload-artifact@v6 with: name: qlogs-${{ matrix.os }}-go${{ matrix.go }}-race${{ matrix.race }} path: integrationtests/self/*.qlog retention-days: 7 - name: Upload report to Codecov if: ${{ !cancelled() && !matrix.race }} uses: codecov/codecov-action@v5 with: report_type: test_results name: Unit tests files: report_tools.xml,report_versionnegotiation.xml,report_self.xml,report_self_v2.xml,report_self_nogso.xml,report_self_noecn.xml env_vars: OS,GO token: ${{ secrets.CODECOV_TOKEN }} quic-go-0.59.0/.github/workflows/lint.yml000066400000000000000000000060511513066070600202450ustar00rootroot00000000000000on: [push, pull_request] jobs: check: runs-on: ubuntu-latest timeout-minutes: 15 env: GOEXPERIMENT: ${{ matrix.go == '1.24.x' && 'synctest' || '' }} steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: go-version: "1.25.x" - name: Check for //go:build ignore in .go files run: | IGNORED_FILES=$(grep -rl '//go:build ignore' . --include='*.go') || true if [ -n "$IGNORED_FILES" ]; then echo "::error::Found ignored Go files: $IGNORED_FILES" exit 1 fi - name: Check that go.mod is tidied if: success() || failure() # run this step even if the previous one failed run: go mod tidy -diff - name: Run code generators if: success() || failure() # run this step even if the previous one failed run: .github/workflows/go-generate.sh - name: Check that go mod vendor works if: success() || failure() # run this step even if the previous one failed run: | cd integrationtests/gomodvendor go mod vendor - name: run gcassert if: success() || failure() # run this step even if the previous one failed run: go tool gcassert ./... 
golangci-lint: runs-on: ubuntu-latest strategy: fail-fast: false matrix: go: [ "1.24.x", "1.25.x" ] env: GOLANGCI_LINT_VERSION: v2.6.0 GOEXPERIMENT: ${{ matrix.go == '1.24.x' && 'synctest' || '' }} name: golangci-lint (Go ${{ matrix.go }}) steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: go-version: ${{ matrix.go }} - name: golangci-lint (Linux) uses: golangci/golangci-lint-action@v9 with: args: --timeout=3m version: ${{ env.GOLANGCI_LINT_VERSION }} - name: golangci-lint (Windows) if: success() || failure() # run this step even if the previous one failed uses: golangci/golangci-lint-action@v9 env: GOOS: "windows" with: args: --timeout=3m version: ${{ env.GOLANGCI_LINT_VERSION }} - name: golangci-lint (OSX) if: success() || failure() # run this step even if the previous one failed uses: golangci/golangci-lint-action@v9 env: GOOS: "darwin" with: args: --timeout=3m version: ${{ env.GOLANGCI_LINT_VERSION }} - name: golangci-lint (FreeBSD) if: success() || failure() # run this step even if the previous one failed uses: golangci/golangci-lint-action@v9 env: GOOS: "freebsd" with: args: --timeout=3m version: ${{ env.GOLANGCI_LINT_VERSION }} - name: golangci-lint (others) if: success() || failure() # run this step even if the previous one failed uses: golangci/golangci-lint-action@v9 env: GOOS: "solaris" # some OS that we don't have any build tags for with: args: --timeout=3m version: ${{ env.GOLANGCI_LINT_VERSION }} quic-go-0.59.0/.github/workflows/unit.yml000066400000000000000000000052421513066070600202570ustar00rootroot00000000000000on: [push, pull_request] jobs: unit: strategy: fail-fast: false matrix: os: [ "ubuntu", "windows", "macos" ] go: [ "1.24.x", "1.25.x", "1.26.0-rc.1" ] runs-on: ${{ fromJSON(vars[format('UNIT_RUNNER_{0}', matrix.os)] || format('"{0}-latest"', matrix.os)) }} name: Unit tests (${{ matrix.os}}, Go ${{ matrix.go }}) timeout-minutes: 30 env: GOEXPERIMENT: ${{ matrix.go == '1.24.x' && 'synctest' || '' }} steps: - uses: 
actions/checkout@v6 - uses: actions/setup-go@v6 with: go-version: ${{ matrix.go }} - run: go version - name: Install go-junit-report run: go install github.com/jstemmer/go-junit-report/v2@v2.1.0 - name: Remove integrationtests shell: bash run: git rm -r --cached integrationtests && rm -rf integrationtests - name: Run tests env: TIMESCALE_FACTOR: 10 run: go test -v -shuffle on -cover -coverprofile coverage.txt ./... 2>&1 | go-junit-report -set-exit-code -iocopy -out report.xml - name: Run tests as root if: ${{ matrix.os == 'ubuntu' }} env: TIMESCALE_FACTOR: 10 FILE: sys_conn_helper_linux_test.go run: | test -f $FILE # make sure the file actually exists TEST_NAMES=$(grep '^func Test' "$FILE" | sed 's/^func \([A-Za-z0-9_]*\)(.*/\1/' | tr '\n' '|') go test -c -cover -tags root -o quic-go.test . sudo ./quic-go.test -test.v -test.run "${TEST_NAMES%|}" -test.coverprofile coverage-root.txt 2>&1 | go-junit-report -set-exit-code -iocopy -package-name github.com/quic-go/quic-go -out report_root.xml rm quic-go.test - name: Run tests with race detector if: ${{ matrix.os == 'ubuntu' }} # speed things up. Windows and OSX VMs are slow env: TIMESCALE_FACTOR: 20 run: go test -v -shuffle on ./... - name: Run benchmark tests run: go test -v -run=^$ -benchtime 0.5s -bench=. ./... 
- name: Upload coverage to Codecov if: ${{ !cancelled() }} uses: codecov/codecov-action@v5 env: OS: ${{ matrix.os }} GO: ${{ matrix.go }} with: files: coverage.txt,coverage-root.txt env_vars: OS,GO token: ${{ secrets.CODECOV_TOKEN }} - name: Upload test report to Codecov if: ${{ !cancelled() }} uses: codecov/codecov-action@v5 env: OS: ${{ matrix.os }} GO: ${{ matrix.go }} with: report_type: test_results name: Unit tests files: report.xml,report_root.xml env_vars: OS,GO token: ${{ secrets.CODECOV_TOKEN }} quic-go-0.59.0/.gitignore000066400000000000000000000003321513066070600151430ustar00rootroot00000000000000debug debug.test main mockgen_tmp.go *.qtr *.qlog *.sqlog *.txt race.[0-9]* fuzzing/*/*.zip fuzzing/*/coverprofile fuzzing/*/crashers fuzzing/*/sonarprofile fuzzing/*/suppressions fuzzing/*/corpus/ gomock_reflect_*/ quic-go-0.59.0/.golangci.yml000066400000000000000000000052521513066070600155450ustar00rootroot00000000000000version: "2" linters: default: none enable: - asciicheck - copyloopvar - depguard - exhaustive - govet - ineffassign - misspell - nolintlint - prealloc - staticcheck - unconvert - unparam - unused - usetesting settings: depguard: rules: random: deny: - pkg: "math/rand$" desc: use math/rand/v2 - pkg: "golang.org/x/exp/rand" desc: use math/rand/v2 quicvarint: list-mode: strict files: - '**/github.com/quic-go/quic-go/quicvarint/*' - '!$test' allow: - $gostd rsa: list-mode: original deny: - pkg: crypto/rsa desc: "use crypto/ed25519 instead" ginkgo: list-mode: original deny: - pkg: github.com/onsi/ginkgo desc: "use standard Go tests" - pkg: github.com/onsi/ginkgo/v2 desc: "use standard Go tests" - pkg: github.com/onsi/gomega desc: "use standard Go tests" http3-internal: list-mode: lax files: - '**/http3/**' deny: - pkg: 'github.com/quic-go/quic-go/internal' desc: 'no dependency on quic-go/internal' allow: - 'github.com/quic-go/quic-go/internal/synctest' misspell: ignore-rules: - ect # see https://github.com/ldez/usetesting/issues/10 usetesting: 
context-background: false context-todo: false exclusions: generated: lax presets: - comments - common-false-positives - legacy - std-error-handling rules: - linters: - depguard path: internal/qtls - linters: - exhaustive - prealloc - unparam path: _test\.go - linters: - staticcheck path: _test\.go text: 'SA1029:' # inappropriate key in call to context.WithValue # WebTransport still relies on the ConnectionTracingID and ConnectionTracingKey. # See https://github.com/quic-go/quic-go/issues/4405 for more details. - linters: - staticcheck paths: - http3/ - integrationtests/self/http_test.go text: 'SA1019:.+quic\.ConnectionTracing(ID|Key)' paths: - internal/handshake/cipher_suite.go - third_party$ - builtin$ - examples$ formatters: enable: - gofmt - gofumpt - goimports exclusions: generated: lax paths: - internal/handshake/cipher_suite.go - third_party$ - builtin$ - examples$ quic-go-0.59.0/LICENSE000066400000000000000000000021031513066070600141560ustar00rootroot00000000000000MIT License Copyright (c) 2016 the quic-go authors & Google, Inc. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. quic-go-0.59.0/README.md000066400000000000000000000217151513066070600144420ustar00rootroot00000000000000
# A QUIC implementation in pure Go [![Documentation](https://img.shields.io/badge/docs-quic--go.net-red?style=flat)](https://quic-go.net/docs/) [![PkgGoDev](https://pkg.go.dev/badge/github.com/quic-go/quic-go)](https://pkg.go.dev/github.com/quic-go/quic-go) [![Code Coverage](https://img.shields.io/codecov/c/github/quic-go/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/quic-go/quic-go/) [![Fuzzing Status](https://oss-fuzz-build-logs.storage.googleapis.com/badges/quic-go.svg)](https://issues.oss-fuzz.com/issues?q=quic-go) quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go. It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)) and HTTP Datagrams ([RFC 9297](https://datatracker.ietf.org/doc/html/rfc9297)). In addition to these base RFCs, it also implements the following RFCs: * Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) * Datagram Packetization Layer Path MTU Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899)) * QUIC Version 2 ([RFC 9369](https://datatracker.ietf.org/doc/html/rfc9369)) * QUIC Event Logging using qlog ([draft-ietf-quic-qlog-main-schema](https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-main-schema/) and [draft-ietf-quic-qlog-quic-events](https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-quic-events/)) * QUIC Stream Resets with Partial Delivery ([draft-ietf-quic-reliable-stream-reset](https://datatracker.ietf.org/doc/html/draft-ietf-quic-reliable-stream-reset-07)) Support for WebTransport over HTTP/3 ([draft-ietf-webtrans-http3](https://datatracker.ietf.org/doc/draft-ietf-webtrans-http3/)) is implemented in [webtransport-go](https://github.com/quic-go/webtransport-go). 
Detailed documentation can be found on [quic-go.net](https://quic-go.net/docs/). ## Projects using quic-go | Project | Description | Stars | | ---------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- | | [AdGuardHome](https://github.com/AdguardTeam/AdGuardHome) | Free and open source, powerful network-wide ads & trackers blocking DNS server. | ![GitHub Repo stars](https://img.shields.io/github/stars/AdguardTeam/AdGuardHome?style=flat-square) | | [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) | | [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) | | [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) | | [frp](https://github.com/fatedier/frp) | A fast reverse proxy to help you expose a local server behind a NAT or firewall to the internet | ![GitHub Repo stars](https://img.shields.io/github/stars/fatedier/frp?style=flat-square) | | [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) | | 
[gost](https://github.com/go-gost/gost) | A simple security tunnel written in Go | ![GitHub Repo stars](https://img.shields.io/github/stars/go-gost/gost?style=flat-square) | | [Hysteria](https://github.com/apernet/hysteria) | A powerful, lightning fast and censorship resistant proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/apernet/hysteria?style=flat-square) | | [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) | | [nodepass](https://github.com/yosebyte/nodepass) | A secure, efficient TCP/UDP tunneling solution that delivers fast, reliable access across network restrictions using pre-established TCP/QUIC connections | ![GitHub Repo stars](https://img.shields.io/github/stars/yosebyte/nodepass?style=flat-square) | | [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. 
| ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) | | [reverst](https://github.com/flipt-io/reverst) | Reverse Tunnels in Go over HTTP/3 and QUIC | ![GitHub Repo stars](https://img.shields.io/github/stars/flipt-io/reverst?style=flat-square) | | [RoadRunner](https://github.com/roadrunner-server/roadrunner) | High-performance PHP application server, process manager written in Go and powered with plugins | ![GitHub Repo stars](https://img.shields.io/github/stars/roadrunner-server/roadrunner?style=flat-square) | | [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) | | [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/traefik/traefik?style=flat-square) | | [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) | | [YoMo](https://github.com/yomorun/yomo) | Streaming Serverless Framework for Geo-distributed System | ![GitHub Repo stars](https://img.shields.io/github/stars/yomorun/yomo?style=flat-square) | If you'd like to see your project added to this list, please send us a PR. ## Release Policy quic-go always aims to support the latest two Go releases. ## Contributing We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [help wanted](https://github.com/quic-go/quic-go/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). If you have any questions, please feel free to reach out by opening an issue or leaving a comment. ## License The code is licensed under the MIT license. The logo and brand assets are excluded from the MIT license. 
See [assets/LICENSE.md](https://github.com/quic-go/quic-go/tree/master/assets/LICENSE.md) for the full usage policy and details. quic-go-0.59.0/SECURITY.md000066400000000000000000000014361513066070600147520ustar00rootroot00000000000000# Security Policy quic-go is an implementation of the QUIC protocol and related standards. No software is perfect, and we take reports of potential security issues very seriously. ## Reporting a Vulnerability If you discover a vulnerability that could affect production deployments (e.g., a remotely exploitable issue), please report it [**privately**](https://github.com/quic-go/quic-go/security/advisories/new). Please **DO NOT file a public issue** for exploitable vulnerabilities. If the issue is theoretical, non-exploitable, or related to an experimental feature, you may discuss it openly by filing a regular issue. ## Reporting a non-security bug For bugs, feature requests, or other non-security concerns, please open a GitHub [issue](https://github.com/quic-go/quic-go/issues/new). quic-go-0.59.0/assets/000077500000000000000000000000001513066070600144575ustar00rootroot00000000000000quic-go-0.59.0/assets/LICENSE.md000066400000000000000000000026051513066070600160660ustar00rootroot00000000000000# quic-go Logo and Trademark Usage Policy ## Exception to Main License The files in this directory (collectively, "Brand Assets") are **excluded** from the quic-go project's main MIT License. These assets are protected by copyright and trademark laws. ## Permitted Use You are granted a limited, non-exclusive license to use these Brand Assets solely for the following purposes: - **Editorial and Press:** You may use the Brand Assets in blog posts, news articles, video reviews, and public presentations that discuss, review, or reference the quic-go project. - **Reference:** You may use the Brand Assets to indicate your project's compatibility with or dependence on quic-go (e.g., "Powered by quic-go"). 
## Restricted Use You may NOT: - **Modification:** Modify the Brand Assets in any way (including changing colors, aspect ratio, or obscuring the image). Resizing the image while maintaining the original aspect ratio is permitted. - **No Branding:** Use the Brand Assets as the logo, icon, or mascot for your own project, product, or service. - **No Endorsement:** Use the Brand Assets in a way that suggests your project is officially sponsored by, endorsed by, or affiliated with the quic-go maintainers. ## Termination The quic-go project reserves the right to revoke this authorization at any time if usage is found to be confusing, misleading, or detrimental to the project's reputation. quic-go-0.59.0/assets/logo.svg000066400000000000000000001400661513066070600161470ustar00rootroot00000000000000 quic-go-0.59.0/assets/quic-go-logo.png000066400000000000000000003152431513066070600174770ustar00rootroot00000000000000‰PNG  IHDR óuwÍøV pHYs"÷"÷ÚäòDtEXtSoftwarewww.inkscape.org›î< IDATxœìÝw˜”Õùÿñû<3;[§m§ Hï½7Q±a¢ÆØ±%–Ä^£!h,±` öޱ"½÷¶€ô]`wÙÞfwgæ9¿?R~~efëì̼_×Å>÷yÎùëîrqs¥µ9F¤p¢£™€£™€£™€£™€£™€£™€£™€£™€£™€£™€£™€£™€£™€£™€£™€£™€£™€£™€£™€£™€£™€³F:KNNvŠˆS¤Ö)b¸Ä'NQÊ)b:M¥ìJ©x¥uÜ–»DDi-±J©„†ž¡E—)¿)Êgˆ.ÑZWj¥j ­*µ¡+µÖ"ºØb•ÚjV˜¦µRDŠKKK+´ÖÞÐÿÌàÄ£´Ö‘Î"ìßMcµ•ßÒίL»¡Åb*åPZy´2«-Ú(÷‰Ûl¾ì‚‚ª\­µéÌÑJ)¥RRâÛjoLÓÐÊÔí´adˆÖíDI¦hÝV”Ê-S”kE¤BDJD¤B”T(­+DŒR-ºT‰‰H±)R$¢‹D¤H)]$SlµZ‹Ž=ZÑôÐJÐÌÀ &%%Åašµ§)­Fj‘a"2Pþ=Õ­¡|"r@D6‰R[DôzËââââ’ ¥ìv{ZŒRݵEºk-Ý”HwéöŸñŽ×šxE¤X‹)õïÆ?)S)¥òµ’C¦©r-ó×+yåååG#œZÍ|œ”Rq.—ýR­åb%rªˆÄ†ø¿ˆlT"óýZ}ZVV¶FŸ ßd(¥¬nwb/­Á"z°˜j°(é+kDÃÕˆÈѲC)ÙiŠì4LÙáÙYVVVépÐT4óps8©†!w)‘kE$5ŒGP"úµzµ´´ôÇ0žÛ¢”RÊíNêcšj„ˆbˆ¬Eú “öZ-’«”¬S¯Qb¬õšæÚòòòüH瀆 ™€ãRÊærÙo-÷Kd'Äi%zž˜ê¹âòò¯¢qZ_ŠÝÞ[rŠVj¢ˆL‘´GBãP¢Všbþ`šê‡²²²]‘ÐÌÀqÆápt·ê=0ÒY~f2åᢲ²/#¤..—«³!þ3DŒ‰ZôDɈp¤ÿC‹T†Qn±Z*”%¦ZDİÚj ‹Å/"bKHª1l1¦ÕoŠ)R[Uf ¶Ÿ©MåóTÙDDü>oŒöû¬""~¯7A›:F´?Ak}ü\¬å°²@D~PÊ:¿¨¨è`¤#€Í|W\.ûT%êuIjæV^¥T¥Q¢µ3ÙþGËe1o)*ªØÒ}›Áår 
VJO­'‹È€pž­”*6bl…¶„¤²8—Û“˜œéKjÓAbÓ2ãâ\)ñqNw¼-É‘›ä´[bl±áÌ÷_¾O¥·¢¢ª¦ªÂS]V\Y[YZ[[Q^[[Qêó”úª+JLoy™öVUªšŠR£¦ªÒâõ”ÛüO¼¯¶&Qû}­µ[DT$ò×a“õ¥Öò¯²²²5Zk3Òœ˜hæà8áv;®-¯ˆˆ¥A/(UŸœº3³×àŠÌÃãR:÷JNÊìjKHJüyØß[[S‘8¿,÷@qñÞe¹[Öù‹ìLòžÜÄ©m>Ñò\B’ýþC‡U5áýfQJÅ8‰”R“EËy"ªc çµXcŹÜGm;W¥žÜ[¥vï—äîÜ#-1­MšÕׂg·:5å¥%ž’¢²êÒŠªâ§0¯ÆS˜ç«**0«K •§´ØR[Yf«õTÆû/./Ÿ«µ® Ó¹@3Ç·Ûy¥hý¦Ô3õL‹T§tîµ¶ßE¿Š;iôéý-11¶æœ«MÓÌß±i×î…sòö/ùÎ][QÖWDŒo d»Ræ¥ášÒçt:‡†\%Z_&"É¡Þß°ZÚÓÛJï=¨¦Mÿ‘‰)]û¤;ÚvlgX¬A¯¹EýjÊKK*Ž*(?œ]Vœ½§ªpçfqöÞ8OQ~†ßçm'"1-pl‘ˆ|ªÅ|¿´´r¡ÖÚßgÀÿÐÌ@”Kq8F™†ü "A¯_Õ"Õ‡O\9ææ‡ûħ¤¥µT–ªÂ¼¼M¾²ãǹŸœä÷y;5ä-Rmˆ¾£¨¤ü¥–È”–˜ØÆc™¦´\-Jú„j_óמÑ!7£ïo‡ÁcýFtŽM´‡öJbÔËôù¼¥9{³s·®ÍÏY»¸æè®mΚŠâ“Ek{¨ÎÐ"¹JéLÓx»´´tu¨ö€Ÿ¢™€(æv»]bú·‰’¶ÁÖØ’œ›'ýùM{r§îÃLk½ÙÜ Ë_|̨)/ØÀwž))«¸Kkm6÷ø_£›4Y)uµh9CDš=ϰÅîMéÜ+§ó¸ÓcºŒ=«kK6E¢y´iš%ûwíÏY»äÈÕ ýE{¶ulhsil¥^óûõÛeee…!Úhæ š%»/k‘ëƒÕSºö]rîŒwG+Ãb g®Ÿ:¸ú‡MKfÜc­­ªlÈT¼OJJ˯ÐZW7å¬ääd§iú®W¢nÑš²ÇkvJ×¾û:=ÝÚyÜY]RÒÒ›³"«òè‘ÜÝ ¿Ü³ñ7º${wÓïÚÛ@5JÔ¿–×ËÊÊæ…¢ À‰f>¢Ô®×]&"*P½ýో~ùðKãE©€õpÒ¦i®çï˶|4«Ö:¹žÕóJJ+ÎÕZ×4t—ËÕYiÿm¢Ô5"ÒÔëUuL|ÂŽŽÃNÉë=yzÛÔîýº7qD¢=Y{¶ùvö• Rj+ÊúˆˆÑŒív‹R/(ey£¨¨¨4TœXhæ J¹]ö¹"ê—jöŒö+§¼úíˆÖÐÈ÷SU…yy_Ü6å°§´hP µ|QRV>Ekí­kYŠÓ9ÒTú.¹@Dš2}°&>%mK·‰çVõ<ûòî‰im2›°¢\UñÑ‚_¸c÷‚91åù‡(‘¸&nU¡EͶøÍç ËË·‡4$€ãÍ|D¡‡cô¦òËbäM{gy¼-Ñîs¬Ѧi.}æ‹wÿðÅ8©£O‹~¯´´bšðÍŠÛ4PÄx\´œÝ” ¶„Äm=NŸRØoêõýb.wSöÀñ©¶²¢|Ç×ïlÊúò½øª¢£¥iM¢ZD}¯E=]RR27ÔŸhæ ¹]ŽDdJ Úð_ß½¼ÏùWs¤FÛ5÷“UËž{°¿ˆÄ[£Dß_TRþÄÿÛétvSJU¢/–Æ^‹ªŒ¼öƒÇìzåíÜ]zvirpœ0ªŠlùäõí»¿ÿ4¹¶¢¢oÓvQ•–§‹ËÊ>ÔZûB›Àñ„f>¢ŒÝnO³ZTŽˆØ~^³ÆÆïšþÑšn­ízÝ`mX¾eÞC×·ÓZ'Yb*S&«ßFí·>¨E~%"ÖFa&¥·[3äÊÛ,Ç9H–¦LY¤hOÖžÕ¯?“»eu¿:>^ë²_´ú[BRÒk‡ª y@Qf>¢ŒÛí¼]´ž¨6âº{Wô>oú¨pgj޼ív|}ïôv¢µ=È’ -bU"q ßUUfö²vÔMwvuèÜ1$Añ{kkv~óÁú-ŸÌ²U ’ÆNˆÉ×"µÙâþ‘ŸŸ_ÑD'šùˆ2n·c™h9æ]¥Œ‚+?Ýà4¬Ö˜HäjŽÃ–oýîÁk»ˆ¨„fmd1ò:üåŽÑ7?Ø?Öár‡(Pé¡ýÙk^}roÎú¥´Ö®F¾^ J½ ”efQQQi‹Uhæ Š8ÎdCé|9æºØvF-:ýñY"+$²¾|wåÊ—Ÿ!"¾"ذÅìzŹ½Î›6ܰXs /Ðl^OEÅúw^X¿ó«÷:ú}ÞN|½P‰üÍ—ðlnnneKähæ Š$;—i¥ß T;ó±Y[Û Õ7Ü™Biîý¿^thÓÊ7$k΀KoÚ?àâëF)ÃrLƒ#VZë½?|¹n嬧¬5åEõ®’#JëG‹K+fi­}-”@+F3QÄít¼%J®üùs¥TÑUŸorF{C›éóyß›>nGmEY¿:ZŒüÞç\±cèUwŽ´ÄÄØÂh°#[Vo[öìååy9ÃEÄhè{Jd§)ú¥¥Ÿj¾QN(4óE\.Çv%ÒëçÏÝ'u]vþósÆD"S¨•dï=ðÙÍ禋H| zzþKÎ|â!–ظ„0G­hÿ®}KŸ} 
§p÷Öá"ÛÐ÷´ÈJ¥ô=ÅÅå‹[0€V¤ÁSB@deff&*‘îjGŸf†;OKquèrÒÉÏY¬^Yt4–F>D‹äNÝ;Ÿ7óƒqSg}W˜Ösàb©iÈ{Jd¤hµÈívü+9ÙÕ×gh&ó%RœÎ‘¦Ò+Õ&=9;+³Ïàc&ö…’¿¶Ú“½zQVΆe奇ö)í3ÅÙ¾³Îì34¡ÝÐ1Üi©!;Ëë­}ûÒa9f­·K€²>ç©wv¥õØ#TçáR‘—sxÑßîÝ¿}Ãiø¤>¿ˆüÓ°ø*,¬ÊnÁx"ˆf>¢„Ëå¸N‰¼ ä½òÓ ¦%ÆÖà+<ÃS\P¸ðéßmÍݺv°hm²Ì—˜ÖfýÐ+n3ºœzîÐPœ»mÎ[+VÏzjT Zr§îK'?÷ÙØPœDBE^Îá…Oß³ûèΣD$¦¯y”Èß­¶¸Çóóó+Z2€ð£™€(áv;­üùs‹-vÏ•Ÿ¬?¹%ÎüqÞçk–=w7­µ«¡ï$¦µY=é‰7ÛÙÛ´oל³µiš³§Ý篭 ôsó\öîROœÝÜœ3€H+Þÿãþïÿtknù‘ì"¢ô’–ÃJÔï‹ËÊÞÓúf^)eKv8›†ŒÐÚ¡DzŠ(—ˆüç‡öˆ¨bÙ­Ee)-Ë «÷‡ÂªœHä¢é a”Öm=·%$•¶ÄyÛ¿˜½béßÿ8¤1|""•G ÿä†Iq¹[×noÎùÊ0ŒÁ—ÿ&7H9~ã{/miÎþ@kàîÔ­Ó”W¾yæc³¶Ù’œ›ô’’¶Zéw\NûB·;i@ Gü?œNçðd—ã%—Óžo*½B´~F‰ºLD ‘Î"â%¢D¤ˆLP¢o¥ÿiú­“]Ž•É.ûMéééIáÌ DšùˆZIÀf¾8—Û곊öïÚ·òÕ'I¿WÐÚLù澫O:¼aùÖæäèuÞC•REj{~ø<¥9{­I›£úN{oyÿ1·>ºÊ°ÅîmàkãEëÜnûs.—ËÝ’ùÜnÇd·Ë±ÙPz•¹ADœMØFi‘ZÔ‹ÞÚê·Óþ`ZZZ°«»€Í|D ­%`óZbr¦/ÔgÍø†%×¼]tâ¼GnL®)+)nê–[lFß!'ðÕVTô©ÈË9Üô|@ëÓýô‹FLÿpuÇþ_·D”‘×€W,¢Õo”˜;].ǵJ©~ïr¹»]ŽDËç"Ò/„[;E©G|ÞšíN§óŒî D-šùˆJÄè¹%6V‡òœ’ì}+ ó‡Ö±¤Æ–´-ÖîÚ¤”QP×^¦ßßö›?^Õ¬ëv_q{F’Ú9÷ã=ÍÙh ‹Õ:dúí㦴:éäñg/¥ÊðZšyÕå´¯r:#š›A)¥ÜNûJÌ5"2±¹ûÕ¡½¡ô7É.ÇK\½ €Í|D€_·• Øã×d?.øì€ˆ¨@5w§®Ë.w¹gÚ«ú\þî²W}¾)yÂÝ3ÖYcãwÛ¯xÿî‘%Ù{ö75OF¯=‚M(;°tžµ©û­56>qüxé[ kÓº÷[""f^j(½Üír¼æp8R›r®Ëår»\ö/D©G%<^PZäoMõÊ´´Ä¶a8h•ø p¢GÀ;e|ÜdÅ{vl2,–ÓŸýt¤2,ÿëT†atwæŽ#N©þôƳWW=2<À«–eÏ?|ðì¿ÌîÔ¤@J©äN=~,Ú—uÌ„¾òüœvMÚˆ"ñîÔ”sf¼?.wÛú¬ùÞä÷VUô­çCD®±ržÛí¼«¤¤l¶ÖºA<ÓÒÛ*±,-ÝÏgKØzrÏ|GÇn:Áj‰MrZ<…¹Þ²#9fñþ±eGöÔZ»ëÝIIŸ×²Äår^RRÂÔMœphæ Jho ¶=¯§*¤Ý|åGÄzž˜Þö 2,§fYm±q¼ðyïw/“cú}í^/ع¹ss2uûå–U¯dóÜ0”¿9ûÑ$³Ïà^W¼¿Rox÷Å¥›>x©§Öf}“÷REë·\NÇ•N§óÆÒÒÒÝu-v8©e™+ªþF¾˜„¤­}瘝ïùWŒ‰Oê)"=ƒ­Õ¦ßŸ»yõÖU¯=UZ¼×p ø9æ?º(e.q8ËÊÊ‚NüŽG\³ @”PJ =¯**é=»«5à/í7ëlŒ‰OJêsÞô½j¦ßסôðÁì¦fêqæÅC”a=üóç'>ã@S÷¢’Rjд[Æ^öÞò˜Ì>ƒ‹ˆ¯þ—ô/ ¥7'»PJl¤KIIqX ùV”ô©k'[BÒ¶3›µõŠVõtÙÍccâ“’êlX,mŽê{þsŸ™òê·y.«;®´1 ù<--Í^ßÞÀñ„f>¢…¸™¯¦¢ÄÊcb’N»«.-ª÷šÌ^ç^t¢WYöî¦f²ÄÄØ&Ïü :&>á¿ãù<íŒZ4îŽ?oêž@4‹M´;'=9{ü¹3Þß›äÜÔ€WâµÈ.§}}ŠÃ1êçEÓçý»ˆ ©ãýÊ—Þ¼tÚû+{·8ª¾k~ƒ²gvhþósÆŒ¾ùÁUJ?§‰ˆ(‘^>oÍëJ©ÐÞ#´b\³ @”PJ Ì«­( éôªäÎ=tþöõÇ<÷ÕxÚ™~ŸÏ°Xƒ~ÿ˜Ú&SDjEä˜ÃÊâüêæärwéÙåŠ×Hmey™%&&Æb‹›Ðœý€ãAj÷~Ý/w™ÞúéËÖ¼5³«h3£žWúš†,u¹œ/†åEEE¥n·c²(¹*Ø †Åzð¬§fW¤uï?6T¹{LºdDÛAcvÓ9ûý>o§ 
˦¸\ŽßŠÈ³¡:h͘Ì@´Ð*à•²^OåIÚ4ÍP“Ùwh°æÀÄü¬»ê|Y)eØbr•⩱ÍÍ&"bK´;,¶¸øPì”R}/ºf̬HlÓÄ" 8]ó' %úfí÷mw»¿-/]h‹Ù{ñß'¦uïß;´¡Eì™íÛNy}^’56>øç­±Ûíé¡>hhæ JhÃܸ¢Jï?ªs: ×CD|jY_½›_ßû'øÅ1Y ‹5»ý±=C@1ñIIg>ñú„³Ÿz{§->!È狟PÒV´¼."§ù)Ú{áó_ÆÆ»SSBõ¿Üi©½üKÖÜ Kœ‹ñpK´&4ó%ü~µ%XíȦUÁaÍŸ˜hKHÚ¨¶ż^~¯·¶®÷ÇÝñäè6ý†-ˆˆ5.aÇ9}»šiz@x¤÷Ôû²÷Vôè~Æ”E"ºª‰ÛøÏ|âµ"{›öíB.€„”´ôÓ|>_D¼êJôu)v{È'­ Í|D‰²²²-°ioïâ¯ël°k¬¶G Xð›KfÞ»¼®w-11¶3ÿôæ„+?Ýh¹üÝå%Ó?ZÓ3¥k¿n¡Ì n†Åjó›G&\øò·… Éiëû~»£–fö¶ºöCÆõï0d\°Ï-VÓ¢îW Rhæ Š(‘žíÞÒ!”ç ºâÖÎ"¢Õö-ùvü‚?ß¶ÐW㩬kKLŒ-Öît…2€Æq¶íØá’·võ]Ë”2 ú^—SÏKhÉ\Œÿý_+ÃÜH,rYZZš=¬€0£™€(¢DÏ ôÜWëí\ž›s8Tç¸:t9ÉžÑ~U²q`ùü‰oOîÙþ¯w6h]ú^t͘KßYl$wê±T‚4êþÔ’™÷ ^2ó Eëz׆Š-1ÉÞóìË·)'úk«ÏW hæ Šøµ1?XmÓG¯üʳ~ñ‡gÓµHu°ºÖfêªWþ4¤xÿ®½¡<@ˈ³»“'?÷éØSÿðìFÃbÍ®g¹e÷‚9?üõikjÊJŠÃPD]zS© TÓ†œ®@$¨0Û!àv9ö‰H§Ÿ?·Øâv_ùɺ®¡çŸãëZÓí´óŽ½í‰‰¡<­Ÿß[[]™—“ëÉ?\R•Ÿëñ”xkJ‹ÌÚ²%"âóÖˆ¿ÆcÊ¢cíZD$&>IǺ“8Wš5.%-6)£31³}¦5.1)²?›¯ÆS9ï¡ë×ån[?Vêù>†ÅšsÆã¯•göÚ+Ù>½á¬¥‡Œ:¶¢—”ŸŽ @$X#4ŽyO‰Ü÷óçþÚê®yÛ7ìÈè=¨g¨Îqí=ã˲÷.ÊY¿tB°5ÊbU¡:­¯º²¢ k㞣[Ö”ïÝnTæå8}•í´è ÐTÚ$J•Zcã'¦µ-tvîáOí;$1½ïð“mv§;$ûãÖØøÄIOΟ³fá¦ïÿt›Ûôù:[kú}í¿¹ïªša×ܵ¼ï׌nélÆŸéÛôþË*ª£ÝnO+//?ÚÒ€H`2Q&Õáèé7$+PÍÝ©ë²óŸ›3&Ôg®|ùO‹³¾|{¨ˆJøés-R}ÑËßu¶íØ!Ôg"2jJ‹ ­˜·+gù|_Ù¡½mÌÚÚÎ"b‰Dó?±mûCíGþB·}ÆÉñ©m"‘ãxçõTT|u÷´MÅûw×÷¹Cw=uòüqwüé—-™'wÛú¬oîp  aʘ²²å-y>)4ó…Ü.Ç* ä»ð¥¯Ž8Ûu ys]Iö¾ƒËž`~ÖÆ¢µÓ°X³ÇÝõ—ü.ãÎê³^Å{wìÚó͇ó6¬Èô×Tõ‘V9mѳ7µ÷ì“']œ’Òkpoeu^‹ÆÙòѬekg?3@´®óÚãöƒÇ.úå#/ÖÙ\~omÍ?/d“@‡J_T\\þiK DÍ|D!—Ëqy%P-­{¿%çÌx\Kmú}¾Úªò²8»;¹¥Î@Ë«Ì?”“õѬݹkw2ýÞN‘ÎÓh†%7£ß°½¦\×ÞÑ©ÛÉ‘Žs¼(ÉÞ{à‹Û§xýµ5]ëZ×aÈØ§=üò©-•ãsûÔŠˆí˜‚VW—–În©s€H¢™€(¤”Šu9ì{EIÛeß¹3?ÜŸÚµOÍ88ñh¿Ïw`Ñ—kw~òzlMYÉ9.&ÛYâ·u;gZI׳.jXcb#'ÚÕVV”þ›É;* r‡Õµ®ÓèÓžrß̉-‘ásûxD$þ˜‚’_—½Þg‘F3QÊårüN‰<¨kwmºü¥ýE©Vy]*ÂËW]Y±uösë²—|×Mk ÐæÐÃr$Éa?ë‹óÚív_bR’™˜˜¤Ýn·**,²² UYQnTVVY=žJ[UeUR¥ÇÓN‰$†*ˆ£ Í°q[\sWÿ˜$“#›A›¦9ÿÑ›—ä¬[Rçuº/¿ei¿ 5hÝ[Ï®ËÛ¾N¹Úvò»îž~ñîÔ”¦Ÿí÷¿9¹¿)"1ÇÔD_RRRþaS÷ M\™ IDATZ3šùˆRéééI^oõ.ÑÒ&P}èôÛ—õ»øº1áÎ…ÖÃ[UQ¶éë¬Z0@kínî~V«5§s—Îúôéãíß`lÿARzôèÑ>>>>¡©{æääÞºeKÞÆ Ê7oÚ¤³²v¤”—•öQÖ¦'UéF¬tÝ}}cî&7•AdÅ?[´ãë÷Ç‹HÀÆ`-RmKHÚá­ªøßg†ÕzpÊ«ßÆ&¦¶ÉhÊ™%Ù{|vó¹'>Ï8½¤¤d^SöZ;šùˆbn·ãW¢%ð•“J•^ôò×eŽ6;„9"LûýÞŸ¼¾r÷¿Þé¥ÅLmê>‹%¿}‡ö{&u¶oÊÔ):uî–%OUUåÂíüæë/ËW¯Ze/((ì'¦´ÕG‹Tf¹vè-µÄ'„là‰fõ¬§–m›óÖh 
ÐЧEV(‘Q?nÏh¿rʬïF6å¼ç}¾féßÿðŠ_-F—’’’}MÙhíhæ Š)¥ —Ó¾ZD†ªÛ’’¶^öö²ž†ÅÚŒ)gˆ&}»fó›OehŸ¯cSÞ±Æì?íôÓößzûíºwïÞ)Äñ𤰰°ðÏ¿°íã?L©¨¬ìÓè ãHßËnÞßåÌ‹i:Cì}í©Å[>kœÛзPD&zgÒŸßÚ–Ùwh£¿æ=tÜõKíYYRZîÐZ›Ýˆ4óå\.×`%æJ 2¹¬Mÿ‹Î|âõ aŽ…0óäYþçÛTænô44Ã0ò‡žõ›ßþ6uÔ¨Qo– £;wî{fÆŒ?,ø¡›ÏçkטwãÜ©kÇÜ÷LF"Ó*›dÁŸ~»ðÀŠï'þŸ‡Z–‰’€×y§œÜ{ÉyÏ|4®±ç¼uþ€lÓï ô{´°¸¤ì”ÆîD šù8¸Ý·E뇂Õûœ{ÅÂá×ß71Œ‘Fûæ~ºbëÛÏöÒZ»ó^LŒõÀ¥—O;pïî×RùZ‚išæg²îÑGIª¬¨ìÕˆW=N»hMÿ+;N”:æÚX§MSrý™kËóýô \¯ˆÎQÇ4ß)Ã8zõç›Só뜻uíöow šùcQIÙŸšˆ 4ópPJŸœöäº]ñ¾ùÁµ=&]2"œ¹Ð²jJ‹ —<|ÃÜFý¾:ŽÍ÷Ü÷‡ò©—\_‹¢œÛmŸ¨EVWªÞ+F-‡6­œðö%ÃK6¼÷âRmúý!Žˆ&ª­(+™×%þšêõ­ÍÈÈX»jíšØ3'MŽlÑÆbµZ_šõÊØÏÌ\§”*­k­¯ÚÓkþ—æúª++•/Z ¹úövJ$P#°î4îôÚ@ïøk«=_þ~ÚâÕ³ž%"‰ÁöÖ"(--]ª¬@kG3Q,===I´zM‚|MWJŸÿÂå.»q‰ˆl¬ù/Óïo»ñÝÆÎž:lÿÞ%ß®k‰¼h8_Ueù÷w^|¨!|ýú÷[´xÙÒîäääpd‹f“Ï?è's>/°X,ùu­ózªzÏ¿ã²Ý~omu¸²E£”.½»k'Q²ü'MY”³n¹õ§ký>¯wí3–þsʰ²£;6ޝ{g]¥”u8)¡O ´NJë:o­˜Ëå|A‰¾9HÙ7ñž›;=s°ˆHκ%›ç=vs¦øÍôl­\vãÒÁ—ß:.tiÑPÚï÷νí¢-5¥Eƒë[:ù ÌøÛß~–`Ç‘}{÷æœuÆ~¯×wR]ë3Ú­üÅÓïŽ¨ëºØÝ·O]R¸gû8-’%"¥JT;Ý!&!ië´÷–÷Êݼ:kËǯÞ²¦‡6}™ÜÞ/"KE©Ï½^ÿ{y-ðSZšùˆR.—ýT%j¾ˆl2êyÖ¥‹FÝôÀ„Ÿ>ó”}yç¥?VäQßþJ©¢«çlqÓÄ~KºaañÞ¬‰õ,«ùý½÷¬½áÆÇ„#Óñ(??ÿè/&œRèñTõ¬k]›Áã»ã‰‰aŠuV½òäâíÿšhÒžW)U©µv…è(¯ˆ|¥Ly­¸¼ü[­µ/Dû­×ì…233•¨×$H#_BrÚºQ7ÞLsM¼39yêksG ¸ìÆ%J©âºÎÐZ'{ÊŠë\ƒÐÛôúÓ‹ÐÈWýÜ /n£‘¯yÒÓÓÓ–­ZÑÖîpl­kÝ‘õK&üøõ»ËëZs¢òÕx*}µž`ÿ:(&„|""1"r¾6ä_.§c¯Ûí¸3--ÍÂý€ˆb2QÈív>"Z?¨¦”*™úÚ¼êÄ´6u^gé).(œûàuYEûw‘Mʰ¾zΦ¶!ŠŒÈ[¿t㪙÷õQÖ:–yŸ|ê© S.ž:¢LrrrGmúvˆH| úðkî^Þç‚«F7t¿üw,|òŽòÊÂüa?yìsËC뻟yq½×ñ"4<…yyóï¼ÄЦ™VÇ2ÿý>¸êêk~Õàß_4ÌÁƒœvÊ©†é÷g[cØl»'½øe[Kl\B8³µåGr-~ö¾=ùÛ7 ­C1ÏìÒ¥Ëjmj9˜ÝÉï÷ÕÙ€\2Ñú騸ę¹¹¹•!È„Í|D—Ëùž}i Zbjúš‹ßøaX Z}ŽlZµmëgoXblºÿ%7´OíÚ§kó’¢Á´Öso½`}uiѺVýööÛÿööÛ'„-× fãÆ»§^pa†®£Q-¹{¿Åcxá˜+¬gUÅG =ý»m¹[ÖŽ”ÐLáówïÑcÕk³^ÏìÛ¯ßÿ&fg<òÕW_î÷wýÛ¶míišfj£wVrDky¸´´|–ÖÚ AV lhæ Š$;c´!K$Àµ¸Z¤ú¿ÌwuèÜ1ÑÐ [fÿ}Ѿ¹×Ù¤wÑÔ©‹þòôS4òµ°o¿ùfÃonº¹ˆØ‚,ÑCo}tcÛá…3W$økª«͸{õ †‰èÄfnçËHOß|Ù´i•7Þxs¯ŒŒŒ:õLÓ4—,^¼ýùçþ^¸`႞¦ß :11%²J+ó†ââŠMÍ‹ „Í|D—˱\‰Œ Të4öŒ…§Üó·‰aŽ„fª8´ÿ‚{¯Ì ×&‹ˆ´ïØaù‹RJÓĉÐ{è½3ûí “†ÅrèÌ—¾tZã“™+œŽlY½mþÃ×'øj½›¸…×îpì2hpáyçŸ7yòùÝ“““ÝMÙÈï÷û?úàƒõO=õÿ¾ýû†ˆHL_õ)‘™ñ‰ö‡:TÕ”³€p¢™€(át:Ï0”þ6PͰXs¦}¸*Õj‹‹ 
w.4‹þî–s7Ô”•¶Àf³í[µnmªÝnzõ+BKk­Ç³æÈá#íIë7lѨ»gw“M¿Ï·ø¯w/Ý·ô»±"bmÄ«•))©;FU|ÑÔ©‰§Ÿ~FßøøøÌ8pàè]wÞ±gÁ‚ïûiݰiZ$K)óòâ⊡΄Í|D‰º¦ò ~û²~_7&ܙХTÅÙ眳p÷ž}•ë7n9}ú•ÃÃÙÈ'"’œœì~ãÍNÌÚ±«zØðá‹EÄWÏ+±¢õß\.ûÉÉÉÎpd‹f>¢€RæCÁj#o| D_Ó£ÌúW_£M3ØÔ79ùä“—O»âŠ‘áÌ„ÿkÒÙ“?nI°zMiñàC«¬g¦–»yõöÏo™œdú|ëYê=å”SîúqOíÛo¿;1%%%9,ë‘‘‘:wîüñË–¯ÌÎHO¯ÿ÷BË9¦é[ép8z„!Ð(\³ @+—œì«M ØP“´õŠV5ôJL´žÂ¼¼y·Oµ‹HB ºRR¸|õj–––æhø™ªÊªªAýû—ùýþÌ@uk\|ÖY¯|ÛS”RáÎ û–ÍݰðÉ;º‹Hb]ëâ·|üÉߘ‘Ã…)Z£i­åÙgg®yô‘Gºj­Ýõ,/1´º¬°´ôÛ°„€)>´rÚ”»‚ÕF\{Ÿ'œYëÿñØN ÒÈ'"r×ïïÞI#_ë˜ð»{îÞ¬î«öô:¸è«ÕáÌ*ù;6îXø—;»IÝ|¾qç\øÍû²Û·æF>¥”Ü~ûögíôtR§Uõ,w™JÿËív^–p@ÐÌ@+æt:»‰ÈyjÖØø]ÝN›<4Ì‘ÐLžâ‚¼Â›‡«§§§¯½ñæ›F‡3êvÝõ×v¹Ý›‚Õ·½û|R8ó„BéáƒÙ_Ý;=Y´š])Uôø³ÿXòÅÛoNJ²Yë›t×jdff¦mÜ´yÄ=÷Þ7_DWÕ±Ô*Z¿ér9~¶p@hæœ0”R6‡Ã‘âr¹º¤¥%¶MKK³G:S}”’Û%È×ë!WÞV­W{žÈ6¾ü§"¤\ûîïg„3æ•Y¯ÆŠˆ?PÍë©ê“»aYÐf¿ÖÆS\P8ç–óüâ7Óƒ­±ÅÆýøýеGn¹jÚ)áÌJ÷Þ{ßisç}ŸkÛ_Ç2¥Džr»†+ŒÒZG:!år¹ÜJ™ãDÔ0Ѻ§h饕œ¬Dâ,÷ˆÈ"²K‹¬²˜²¬¨¼|­ÖÚæØÇHNNvjÓwH\iX­¯üd};eX,ˆ†&ò{ª*¿ºa’_´vª2dñ‡Ÿ|<>ܹÐ0çŸ{î’­[¶Ž TKHÍ\uÚÌG„;SciÓ4ß»bÂæšò¢ÁÖ$$$f­Y¿ÙÞ63­}8³µ”’’’Š‘#†ïÎËË ús­,.-,L±€c0™p\p¹\ƒ].ÇŸÝ.ç%fh™#Zß/"SDIŸ |"ÿžÖ_D¦(‘§MC–»œö<·Óñf²Óy¶R*rÍr~ï ‘OD¤ÛS÷ÓÈ}²>™µ.X#ŸˆTÏ|öÙna „F™ùì³$Èt¾ª‚Ü¡•y‡…7Qã-}î¡Åu5ò%Ú“¶nÙ–•q¼4ò‰ˆ¸\®¤-[·õ2dÈ’:*õ¨Ûí¸3L±€cÐ̈Z)) ÜnçÃn—c·s¹WD”æ}s‹’«´Ò_ºœI{“]Ž{Ngr(27†Vê†ÀU9túmuO˜B«´ÁA¤F½º]ûvm™Ó¹K—'u:iu²eû/üÖ@”·}ÃŽÝó?¬n³Åì[½nSr²ÛöÏw--&&&fþ÷?Œ»búôE"|4¹–¿º\Žë— øÿhæD—+én§cŽé·î­‘“[î4ÕQ‹üÙPz¯Ûi0--ÍÞrgý)Çh騖Þ{à:[¢=Øt7´RG·®Ùjzk»ªi‘ÊÏÌìîLh¼?=ùW°ZîÚåݵišáÌÓPþšêªoø•MDlêJ©Âù —zÛ¦§µ s´°zî¹&üæ7·.“à }J‰¼èr¹~Î\€Í|€(¡”Rn·ýB·Ë±Z‰1_”œ'"á¼fÖ)J=âõÖìJv:/oéÃLC® VyÝ=Lo‹BY¾Z¬6zÔ¨µéééiá̃¦1rD¯ää”jZûÛæmZ±%Ü™â‡?߾Ƭõl&ÿ{￟ݯw¯îa !=þÄØ»~÷û¥"¬ñÒªÄü0ÅnïÎ\Í|€VÏ岟êrÚ׈VŸˆÈ°HfQ"™ZéwÜ.ûÜ””„ W¦6ë ¥âDä¢@5k\ÜŽ”®ýºµÄ¹h9¦ÏW[²gÿ`õ?>p‡pæAóÜwÿ¼Áj»çÌ. 
g–†(ÏÍÎ9¸nɈ`õ³Ï9gÉgL:¡®î¾ÿþÆÝõ»ß×5¡ÏeZÔ—‡#%œ¹pbSZûû+"+99©¿iQ"g6i‹‘—˜œqÀÕ¡[UZÏþ1‰íãâ’ܱ–ØX‹ßSå«,9ê);¼¿&oËZ)ÍÙ“ê«®ê.›öW(J~]\\6§Iù‚p¹ì)Qªõ<û²E£n¼B(ÏCËËY:wíú—¨æt:6¯Û´)h£Z¿ßoöîÞ£Àï÷§ÿ¼¦DžóÖ—2,áœZ§Oo>wyiöÞÑjv{Ò¶={t‰‰‰ w®Öàºk¯YøñÇO º@É—%%åçiþЀ0°F:?—–˜ØÆgµüY”1]5nЬiKrné2þœ’.§œ×>¹KÏ“E$£¾—´ˆ/§¨z­§¨ äÈnÏ[þ¯ÄÚ²¢þRÿ×ÉÑò™Ûí|¢¤¤ìÁP5{(QÓ‚”¼.¹±W(Î@xí™û‘'XíÚë®/g4ŸÅb1† ’µzÕêcšù´è”£ÛÖmIï7¼_$²ýÜáM«¶—f都”ªœ;oAâ‰ÚÈ'"òê¬×'fçZ¼jåŠñh9Çå²ß)"3› '"&óZ ¥”áv$ݤ•zBDœ ~ϰæv™pÖŽ~—ÜÔ#)­m›Æž›WV»ÐSãŸøÓgÕ‡ï|ãáÝå{6‘Ølóq’½üÊìl´i«!\.—[‰y$Й‰©™k.~ãûˆ^3Œ&ÐZÿëêS ´i¦ý¼¤”*Ý’µÝ‰hÑÆãñTíÌÊÊÎÊÊ*É=’W›ŸŸg*eHzFšÑ­{ÄÁƒµmÓ¶mf8²,[ºtëUWL猪9hÌÂáwþyb8rÔçÝËÇn¬)/x…îÔ©/|åÕYéÕ1MÓìÓ§×úÜ#GNϯ©Õ¸ÒÒÒUa €“ù­‚ËåìrÚ_Ò" nV³ÆÅíè{þµGz_ð«Q†Õ:±)çÖøÌ=žÿ¸Ÿ?KmÛvÀï_i[[”››5ë]å{·Œ•º§N©(w¸333'çææV6%Ë¿™S$Hó`ß ¯ñ5}_DJÉîí»´iöT4hЦ¸¸¸ÀÁ ^¯·ö믾Ú<ûÍVíØ™Õ¶ÚSÝEDþZþ—ÍfÛ?dèÐ÷Üw_»¾ýúvm©l£ÇŒécµZù|¾v?¯ݾ®ÑMÅ-!ÇÆÁù,ë‘¿?÷Üðpgj Ã0–/_ÙµG÷®Ù^¯·C€%1¥ßRJ ÔZW‡= N¹º€KIIq¸ög”˜«¥|ñ®äõîþÛÆKßYÕ³ïÔëN1¬Ö¸¦žŸ_â-K°º-93sÀݯŽïsë3Û kÌþºwÓ¿¨ñT}“™™™ØÔVJý?öî;®ªûüøó=w_¸ "Š Á½pYÆh†™¦YILšÄ,Ó¬Ö$Ín³Û¤iV“4³Ùš˜ÄAÜ(** *CÙëÂçûû#±?£ç0.—\>ï׫¯W{žï9çQ‘kê'Ï£én?&“iO“ªJÞ§Òê°ŠuªÊÙ&îl;oîÜDE›éçœN§ã®;—¦Ï˜–â¿aÆ4Qä=y^uǘ믹vôÅ .Ìlhhðz¸î÷7Þ(ןP™µñ ·ß×ö†ÚÚšÂ<Éïoz½áð²e÷§(ÝS7iÒäQËî`—\-³Z­•ì „ù@qC† 1Ù¬æwˆÓ·DtÆŠÊÓ1A]>vÑméW}²eDôŒ¹^ RT7;»µn^Ðè É÷½ž:tÁÍ›‰È!wŽo5›Þén?*[HD’=%]v“wŸýç¼­¾&Vª4ä€Á€Éh'O3¶ø‡ï¿OcDO·”Àìß?kêÄIM{²³½°›–’GDv©ZÙŽŸ%¯+eç{/í#"ÉðK—ÿ áÏ^>+<"PcLcµšŸæeQLgçµþ–½óŸÿ¼dÎÃOSiõ^Ÿ^VÝÔ½©|§ 3kÜØß9AŒÕËâôW›Í”Ú•çu´b7hÔ˜<&¨¼dåTíÝZ.W»úê«£”쥿ںeË›o¼1†snSàuÆûî¾'yÅ·ßz-Ð7yêÉïíM 1Ä9÷Ö{ºãÄÁ½¢Ë9LªvöYgç`"dÇÂÂÞøË“y²Dz:44Ô›Ó#ˆa>P@É”`µøo‰¸« IDATcDQ‡¡4ÆXMÒKÖ^ù~F²-&ndoôÓØêÚæ)©§Ï1E•|×ßˈx«Ì5qö‘Íf³vö,•Š.%™»c.»±O'|çªs÷¸¥® ‚PŸ˜Øi¨Õ×UUUUÞpíuAœ¸’k¤u÷ß{ož={ ¼ñ°ù .”þýɹ¥åxé1o¼£»öý¾lˆô™çþ:\É^ª;–.¸W²È(¼½ÝþG…[€Aa>èUÓ"·Ší b­uäÃ6]þït6nÑçc½ò%м¾¶Åé­çY&Ž»ùé<"rȉ$ÑýjgÏa$Ȯؚ2'¹'=Bßi:V(R2$¸Hé^ú£‹çÏ/E1Dé÷rN~W]~…¡©©©©§Ïš3gNwIÕj¨ìéó=Qš•!ù=ÎÏÏ/?66v¨Òý TŸñ…?Ir‰ó{ºÔè„ù W0Æ46‹éÎØ§DÔñÔ-•P5}éS;.xá³™:³5 Ûâå í‰Èka>"¢!“Ι–zùVÙŒ®´XΗ+›L¦`".¹Ž+v¶¶¦†©ëcÇmQº—þæ·ßÞrâDug!ß^ãv¹¢®¹êjéÉkÝ`ô3úi4š2©Z]Q¾â¿ÎMÇÊÜÎvÉ©¦\0ÿ¸Òý d'N1yÊ”Í2e 
ã®;m|Â|àuC†ø…[ͦŸ‰±»;;ë¹ýò·Ö«cÒ.œÚÛ}U6:2\.Þ+ï±hÙl]@è¹:güƘVª¦RÑB’[±{ùb‹—Z…¹ÛÛZItKN›:eªNé~ú·Ûí~éùÂûºÜÜ336däôô9‹­ZêzSÉaÖÓgwWñ¦ŸŠåjwÝu7¦òuÓ>ø0žˆìR5Nì£Â-€C˜¼Êf3Ív:U»ˆÑÌ2Öœ|å­—¼±jªÞÐ›Óøˆˆøñ†ö {»;­×ÞÀ÷§F1A]!ÙQœÕjZ*y«ÌŠ]R UQ“Ó’¼Ø%(¨¹ªü8I†¹’ÇŽ±)ÜN¿òñ>Úáp8¢ûº"¢»ïZÚã?%9Ï~⸡§Ïî®’mëD©ë*•º")99Fé~ºÐ°°à™3fʵ‡Ø››+Úø4„ùÀklËRâl=# íèœÖÏœsñk+ëÆ^u»äjYoâD­å5ö-m±×ߥñ³Xc¯{äXÍ<ü›•ÿ®ØM“:76+v®–òâZ¹Z||B”’½ô7oþë É)•}¡¹©)éçõë{´n7.>ŽK]w¶6)>Y³¶¸@râäáEJ÷â+Þ|ûíNÔ&UãŒ/Qºð]ó@1ÆT6‹ébüï$³.öWNõÔOJ®QíªÑÉÉz©ëîööÀž<·»-MnG»äô½¹çïR²_ðˆÅb©V±8QlGçüƒ#¶ÏñóÑ¿[+cu*¡mN”eÿYCÍ&ƒFˆ#¢Q]¹¯¨¦MÉz¼Â·§F]ÿXôÎå—¹Ib….øˆèc@—s¹»—ÝŒ»œ£¡^ò×VoÐ;”î¥?ÉÏÏW|*¡Ÿ¿_sdD䉃—;³zõOÎßèÙóýüd'€:íÍŠ…ùªs÷ˆR×u:]©V«•ý±ûŠC‡ŠŸzòÉ’;²"ÚÛÛF‘ä4‚ƒôÃ÷ßÓiYÛÈ‘±[þòÔ“Ö©Ó¦&vöü¥Kïjxú©'¥JþÍͦs‰heÏ~0Ø }Ý 4ýšE‹ÎJMÛv¼¢¢²£Ã·ß~Ç8"’ž¾'ÒÂn¶ p„ù [l6Sª(Ð" èà˜}ò’GvL»íQɲRÂüµÅœžùäÌ¡–ÑA†4FÔÝé]ö§2KJÚÝõ¥¨˜+îÑÉÕT*vÍ’ªýºbWrª ng»ôd>½nІùÜ.—Kt»ƒ”|ç˜1ɇüüüˆˆǯY¿®L£Ñœ1±ÕÞêqÈP§×ˆHr­²»½Íåés»«½¹Ñ$u=>>¡M©”Ä9ç¼ï¾ô‹æ/ˆ:räÈt’™tÚìhIÉ´YÓgè¾þê«,¹CF£ÑoèС9ÒO ‹c˜z=‚0t™Õj=8ý@Dþrgc5g/£ nîUÓºòL­ªîîIa[™9Ô¬›ED²89"§Š‡Öï(kÛÝ{{Sàø´qL£)”,rþ’ žŒ½ü&Åׂ÷‰n·äŸ³\nSº—þ¢¾¡¡<\u›Ÿ¿_ó'Ÿþ›ÅbsÛnÛqúYQäv»Ý£)z‚ D$˜]’«o{ƒËé°I]Žö½ »n—ËuÁys·|óõ7iD¤õÆ39çÖ–Ý?ñ¹§ŸÞ(wæ÷¿_,Î 0™&{£¼æ€. 
0›‰+ˆ˜äG""A­.žÿÊWaã¦wªS1æ¼p„-ó™™Q|¤EŸBDM4:ÞìØvý·uN´&xro HšqT¦$ý¬*#&ÎÆŠ]ÀÜ¢dhÏáp Ú?9ÚÛJ¾oÕO?Ðëõ£N¿~îùçKž··Ú=ž`Ç“¼×ít+æã¢(’[”\<,z˜WÂnýÉÅ .ÜvèС½ðhá·ß™õÂߞϔ*Þ¸xqI~»šÙ ýÀ 2hÿ2ºÎj5]ÆúŠéåÎhýÍû.}s­Ù9¢ÓP)a¦/ÎV~^´ucÌ£µ¸"§ò×wVl¿åûÃÓúÓjÝÓEÍ»>¢;çƒãÇd¿Lù‚N%H®]u:ƒv…2ŦNŸ9sddäT©š³½]úׯåìÁJ\Þ§9{CM-É„¢‡ fP¸^õÜÓOoÌÏÏïÍàû×?_OYýÓO»O/X­V‹Éä_ yG˜zaèP€År5#öu°ÆÐœuÙÛëbõÖ€CuFµPÿHJäæß%MQ16ÌÖœÙÍ‹¾Ê·üt¸N2¨ÓŸøMŒe‚º¼«çÇ^†»¾BT’ÙœN×  ó $"E&ÕMŸžR'W{ï½÷Ϙ0ʈ5ýÒŸG8—;«F§Ï쎶†º&¹Zdd¤ìjô¦¶¦¶öÝwÞ×ûobê¥wÜ^W[{Æ×ÑÄ “ª¥o¡Œ±A»Fza>e³™/æŒÿ‡:Xë7$|ëÂþ0F¥ÕÉNí#"šâ·ëÙÙÃìaFÇkÛ]bÁCë‹?–q4Õîý<}ŽÒü"cŽté VìúA%7™Ï£•Ò¾@¥V«U*¡²7ßa6™Í&ScæÆDtúê[¾ú§Õé߯\9ñôûl‡T*•GAËööö6"’¼W£×ëýÛo¾M¿ä¢‹ŽÜqÛmiR÷Ž?¾ÁÓ÷VŸ¨–¨1š: :{‹ËÑ.»"Ø`Ð÷é `o©­©­Ýºuë4%ß™³gïôüÜÜÂS¯Í›wAŒÜyAàî3 úA;äÙlþ㈠_3’^IDdŽŒÙ²à¥/§*•ìgI˜¿¶ø“Ã]Íò´Ωùõ¬Š« ëÒ<}F_³ÆOR²¢ósc¯Xbíýn@)Z“Yrl›ÝnSº—þ$5mvÛ§Ÿ|Ú+Ͼþ†ª‰(‰ˆèÅ—_ýчîxéoÏoljêt’Û¼ æy<í³¢¢¼H:¬¬µt¼~Ü[œí-²a>N¯ÈªßÞöÁ{ï "?O<¤zøá‡Ë¾Y±â¾ÐÐÐ!‚ ÔŠ¢xƯ-çOD+í|¦ÿÀo™L Ä…µD$, ŠMÞxá+_Më(È75Ô”õðÔ›F`¯t‹¼øžÕ…•« ë¦zúŒþÀ/*.šˆdW`*#ÆÏHR¤!P„Îl“\³ër¹BÝ.—lðÊ×-¾ùfÙ©f=õê˯DqΛýŸ†k¯».5{_Ži÷¾œ}Aµr÷1Æj\xáXOß[ZZÚ$óàF•F«Ìd¾¶v·\M«ÕúÄd¾¯¾úR‘ŸËÓåìÍ™`·Û[O½f6›Ieœ0™<†0üÍfæV±5D$w&pdræùÏ~8‹1éu°Œˆ_7zÈúkGMdDO{©±;w^ûm­°¾m„§Ïè/µF+hôÅ MÄŠ]_£–[§«:VVv\Ñfú‘á11QV›5§7ž{àÀˆ;þpû!Îù©Á=f2™’/^|@î¾ 'í×h4Þr÷çÚ¥®«ÔêOŸÙ]jµšÉÕDQ”œ9¸\.çñŠŠ1½ýžà!CN<ùÌ3“§NÉ;yù}ýÕW¿ùšŽ.¹–™3ŠëíÀw!4DDd±Xw¯&¢H¹3æ°a[ÏöƒéĘdhD+°¶Ç¦GmŸê6õà3&»¢9ãÆ•‡&4;܇û}PXUGõ$¬Øõ9¦°a¹Ú¡‚‚j%{éoî]¶¬µóSžYóÓOãSgÌrdïܹ‘sÞHDÄ9oúþ»ï#äî¹gÙ½=yg^~®äu¿U¹0ŸÎ ;)Õn··+ÕGo)*,*#"]o¿çí÷ßË¿úwW§~úßÿÆ?ÿâKé'¯¯úî;ǩ璒“$’LfÝ2@W ÌÄÓŒÁI~¢Á°sÁ+_M`‚JrÚ˜N%´<6#27РžÖƒVøêºÌÇ2ަrî[ŸQ~‘±òkv™P‰»>Ç?|¨lõ`~~³\m0øÝ5×Lõ÷7ÉNÊë©òò²Ð+/¿bö¨˜¦óÏ9§d☱”ÏX,æW‰è,¹ºÖèwà¢×W%jäT$£ZhxbFT‘Y«žÐƒ6ÿÞS¹õµ³zðŒ~K"·r•B“&äcÅ®ïñ ‘«åç幔쥿aŒ±7Þ~“ˆx¯þÿí÷þ½Ãb±¤ž~ÝÏÏ/šˆHÅ S¯GDFJ~&Ùíö ":êí s#’.*’pS_÷á-Œ³…¾ù¼¯ûå Ì0ˆY,–¹ã/ÊÕ•úØÅ¯­4ëü-’a3³N}â‰Q-*F£<íA伿¡õÅ'òªí“<}Æ@ ³ ÑËÕ¬á1Ar5Ø4FS¹£¹áŒ0_eÅñœsÎc}ÑWñÊ?^µwÏžíååeS•~·N§?üÑÇŸôdš(eüœ^DÒ«UÅ ø1#zúü®2Ɇ‚Km¦é;+ÍEw¯=Ûh4¶NKI™,Us:UD4œˆ·ËåR©Õj"¢¹ç1æ4õN§ÐÎ…aÄèѾîÃ[DÆ¿$"„ù¬ô¤‚ÌæxñÏH>ØÝ2÷™÷[!’§ôjÖôè´È£hO{àDõü\R•Wm÷ô…Ö$;ÅÉÙnð+0Ašhd£Ôu‘óÀÂ#…ƒ~r—J¥R­^¿6Ùf³íQò½‚ T­\õ½Æè'=q´;¾úúË6©ë*¦D­÷“ 
|y›J«7±©ZqIñ€ÿÚkÿÜb0íD$9-ó“>*%"áÄÉ ‘F£‘]Î9Ç?c€GðMƒÐ!CLnV‘Uæˆ8íŽÇŽL–œ¸§˜cyJä!šyÂãDõ¯/®:p¢5ÁÓg $j«ìZUw[‹SÉ^@9‰d'ï­[»ö˜’½ôWƒÁ¸iÛÖøˆÈÈíJ¼O¥*¾Y¹¢qÄÈüñ¼œÝ{%ŸcŠˆ.óÆó»CP«j¤®=z´÷ÆÚ)$!q´ÁÛÏœtøÎ¥K7µ·µi³²²6‘ý”rËŠoW¤?ýäS³ˆˆBCà O½W¥dÿ9Šs’ útkv!§³íMFLv5îÈs/ËyÖÂT©c$>2-r—Y«Nñôýœ¨ááõŃ%ÈGDDœD¹’«­Í¡d+ œÐñÓ‡Zù¡dmýºuî[ÿp›ÂõO:NŸž¹qÊC<˜ñÕ_L'"Mo¼Çbµî]õã¡¡aa#½ñ¼’’’c§C2Ì>å,ÅW(«õÆGó™Ã •–*ÝŠ×Íœ=Ó+á˓ªV|÷]8¼ùÖ%¹3¦M›À‰ÄqãÆï×kµîm;¶lmnI;y~ê´)§ýH¥RuæSã_˜à/š«Õ¼„»Z®·iÚmJùˆˆî¶iˆ¡GA¾ÚÖ— ª q—S6ÌÇÔ|û(ëðøôÛi_ÿ“Ÿ›ªp;ýcŒýõù¿¥~½bE±ÕfÍñò³ë®»þúŒ¬ì]É¡aa’«Ã=ñÅ?/’«E¤œ3Â[ïé*­Ÿ¹UêzéÑ£:¥{ñ¶`ƒÁ˜ï­çYl¶fúuµ®¿¿âÎÝ»ëî»ï¾]-MMÚŸþylksËoV0_}Í5Ýùýªx|ƒH@€)‰Y®®õ7ï›ûÜG“åêçD[¶Ž´ég{ú~NTÿÀºâªüêÖDOŸ1PuæSiÔXÉ裘J¥Ñú[JÕZíöQÅEEdš—;&vçîÝc^}í»ƒ‚vÉOµìŒZ­._xéÂô¬ÝÙì±'þ’*ò«Q=±ê»ï%'ªµºC† ÅÚ–ÐÈ6©ë•UUJ÷Ò.¼è¢Jo=+?77fÇŽ'ÿ·F£‰úýâų¿Zñí¨§ž}&ãÔ³zƒþЄ‰³V¾¶¦¶EîÙœó3Ç#tÂ|ƒDpp°¿(²/ˆ˜Qò€J¨¼ð•oÂT­ä§(“öðÅ#Æö Çs›åW·Æw~Ô÷ˆÜí’«©Ôú^Y) ýCPÒÄz¹Ú[ÿz«PÉ^’ù Lܾ3kü†ÌǯþÝÕ¡aaYŒ1ÙŸK¢_&ðfŸwþÜôÿ|üÑþÜ‚ƒ¡Ï¿ôRšÕjµz»¿ÆÆÆ†ÒÒ£c¤jCÆN)óöûº" &Ar"\kkKçœ+Ý·ýñ¡Æ0F²!ºîºvÑÕ³Ö®^›NDîS¯§Í™ó›ÏÉ… –Ÿ~oMMäD""Q¼Õ# .ê¾n”át¶½Áˆä‚tîÔû_8f°M”*êTB˽“Âúu%¡ø{+³·”6Nóðþ»ò“ùôóù²ÈsÊ·­—¬ý°jUÐ3}V᎖¨¨¨ð'Ÿy&üäÿ®«­­-).>Q[W×ÖÞæpû›ü4‹U=<:Äl6ۈȦD_ÿüÇ?ö‘ä¤Ò˜ó®ôÚ*ßî•ä/Sò«¨(¯ V´!/³Ùl¶ù .Lÿþ»ïÒ¼ñ1±ðÑG——ŽNNV©TšG|HòŒ •>òÈ„Óï­ªª’\ŸMDÔÜÜÜaè@Â|ƒ€Íb¹Ž]+Wž>73jÊÙiR5FÄœ¾O#0ƒxkŽÔmü2¯&ÕÓû}A{}µì4)½ÙfR²PVHò¤DÆX=çüŒépÍÍM‰¥¥¥eQQQ}ÑÛ@d °ôùÚØO>ùT2ÇÔê£qcúdi@Lâ¹ZNÎÞŠæ#"úë ÏO[»fíáöö¶‘Þzf~nnÌï]#U»óî»ùûûÏ<ýzyy¹Cæq-œs¹@‡°fÀÇ#‰ñ¿ËÕõ&Ûî÷=7K®~AŒmÓ£Æã _A­=óYƒ:ÈGDä¨.o“)¹µ~&³¢Í€¢˜J­öš+W~ç­·+ÚôØÞ½{ Z[Z${afc’ën{›94"œˆÚ¥j™™>±úU§ÓéÿûåÄkêíwfß¹tét©Z^î§ÌmÇ{±%ðqóø0Æs»ÕoÑÁ~© Õó^ú,‚1A%U·êTUs‡[“=}u«+ëµÅ)žÞïKì5e.©ëŒ±z&ø<öqÃæ\(îúú˯ûd%+xîÙ§Ÿ– lñ„+–ŒP´™S0A¥Rë ÅRµM›2û$`Ø’’“Fþýµ×‘\HºÇt:ýá5?ÿ%û "ò—ª©´º#i½*Ô3ª…úËâ<]É_Ï:¾¯¸Þãáý>ÉÙP+ùka°·*Ý ôØ‹®ÕÉÕÊËÊ'*8„‰^ý\MMMmzú†ÉR5¦R—FΚ;QéžN>.Enšªv÷™&# IDATîÝ>÷56tèЈ{v'Ý»ì¾M樧ϱXÌ9/½úÊÎ׬™a0ý::[UUUãv»Â¥jŒñÞö€0€ °˜qbseÊ®³–¿Ñ®Òêôr÷ߘ¼—y´’q{YsÆêÂÚOîõe®Ö¦©ëæ°(É)Zà{"¦ž=ž©T¥2eaùŸÿ$Wƒ~â¾{îÝÇ9I½bο¼ˆ *É œJ =i8‰RµUßW¥p;Š`Œ±;–.¹??/â©gŸÙ>,zØÖ.¬ß ÃÁ³fe¼ûÞû9»öîsÑÅOêÊû26lè ©ÎîFë¿¡îëÀ»l6›•3ö2“©G¥œ³)$qbšìý:uE| aŠ'ïnhwí~fSé,Oîõe¢£½•‹®P©ZàðÒý@ß`‚ 
Šœ~î‘ÒÌŸ¢¤ê»²²&ÖÔÔÔ*Ýtîøñã•›33%§ò‘ :žpù-­%÷6­ŸÉ¬Ö \íöQ§×~øa•öɧžî‹¶¡R©T‹®¾zꢫ¯&QÅ‚üü½99'J–:šš›ÄÀÀ@Udd¤.""ÒoTܨ0›ÍGDqÝ}Ïšµ«›eJÕuuuÅ=úAÀ †0€á\|–IÇTZÝáY÷<×áÔ¼[Æ&¢nòDNwþpd¨ÈyŸO¦êoZÊ#¢3‚5DDCFO P¸èC WÝšTšùS;±r—sò»õ¦[v}ùí׳û 5èÄ-7.>DD3¥j1ç]V ¨5ýæ×-`dÂñªÙg|Ï)..ŽåœsƘ\ÞÛg‚ Ä'&ÆÄ'&z}åûÖm[%§32⻼ý.\°fÀ‡X,–©Œø™2Ÿuÿ M‚ZsFˆè¤0MQ”IçÉŠ\ñÙÌ£•õínL“ÐT’W#SƒF& U´èSzK`uDü¹úž=»gìÝ»·@Éž sk׬ٓ——7CªÆTêÒÄ«níSùNŠN9Grâ§(Š{÷î9¢t?¾ÄétºÊËʤj" ;•î| Â|>‚1ÆTŒ¿J2ßÛG$mŠœ8{|GϸuLh%y0µu_U˦måÍãº{ß`Ñrd»ÔuA¥.Wë ’žÀwM¸õÏȸK¦¬ºeñM-Š6²Ûí­wÝqgIN³KºæÎc…¤ûBôŒód§Ñ}ñßÏË•ìÅ×lØðóι¿t•e(Û ø„ù|„Õl¾–IN‡b‚pâ¬åÿLîèþ½º,РžÜÝ÷:E^ôxFi·ïLŠs%ƒ>k@…Ò½@ßó:Ô³]®^[S3þËÏ¿ÞʺmÉ­YN§Sr‚¦Æß?gø9 §)ÝSgü‚ÂBTZýa©Ú×ß~mPº_òá‡ÖÉ”ì ™Š6>a>ìOÄŸ“«^xÓAÉbíè G""U7_íúKÆÑV‡[D8¤ŽÚŠ0©ë¶¡±­J÷ýÔ{žFD¹ú£þsp{{{›‚-„íÛ¶çnÎÌœ)Sv¦<𲎓œØ××Â’&•I]?^Q1¦±±±Ié~|ņŸ×‡H]gÄ39çø= =‚0€p8Ú&FáR5•V_0vÑR:{Ƙ ¿Ý}æÍ{+[Fw÷¾ÁDt´µˆ.é©^A 㺞á7f›\ÝápD/¹ùÙé}ÐûêëëëÝuþ$r;m‹ux\œÂmuYܼ+,2%ÝG~°OÑf|ÄÁƒKZZZ¤jœ±5J÷¾a>€.0ÐID÷ÉÕ§ßõT3T†Æ‚ šc£¨î¼·Ý%<»¹´Óà`×xdoÉ|Þ†&N´)Üô#“ïz*‰1&·²“6gfÎÊØ‘£dOð QŠΛ{Xn½®J£)š´ô‰~½^€,ÈlŽ'¢dÊÎ9¼Ø•ç˜u*GwÞ›QÒ°9·ºuTwî¬êóvJNÀRkõ…ý= ½/ñÊ%ÓÕzCž\Ý-ºÃ/š?ÿ0çœ+Ù×`öÙ§ŸîøyݺT¹zøä´M‰ÄzñäK«¥®sÎýÞü×›ÙJ÷3¥oذßÑÞ#UãÄ>UºðM’Áƒ[ §Hæ{ypâ„­Öa±³»òœ‡ØåσV§˜÷ò¶ò™]=?صV…H]·DÅT‘ƒTªé¿ªÞøØ’v"ÒI©¨8>å¾»ïMùﯤ)ÛÝà³cû޼åü)ˆ$C¸:³u÷Ä;ŸÕÛ}¸NGUþîC•²ëj s]ÍUªö¦-cŒ *'"RéŒnµA/šÃ¢yàˆm@L‚Õ6tD˜Îd±ž|NLê‚q›ÿ¾¼žsn=ýÿû«~wßsOoÿP|ÆÞ/¹²˜ˆD•Êõ_E›Ÿ…0Àe±X& Œ.•,2VŸúÇ—“ºú¬üÚÖh"©“‰­œÓ‰{VÝœãó£ DG[‹ÛÙ6Bª6>Eév Ÿ²ÆÄÇFÍœ›^ºiušÜ™ïV®HM“´eñÍ7OW°µA%goΡk- ᜛¤êLªRŸúw$TªÞxeîîü+Þ?^‘³# ½¹q#êÒô¿ªÜÝtxý7§ô©®°FEFM#ŸuAdÀˆÑÇjï?#€X[[36??¿8>>>Ú{? 
ßTQ^^y¨àÌÚeJ¯©i-U´!ðYc PŒñ¿Ìô¨á³.Ø£3[Óºú¬úvwpE³sK˜¿F6(ĉ*X_T_ÑÜ×ýn§†‚]GˆhŒT-jRZÂí@?6îægß³9ÇÙÜ,ùõBDì™§ž—˜x eúô±âu ÉÏÏ/¼|áB+ç<@æH[Ê/TémA]IwEså±òÿ~¾ t{úHÑíŠ'¢x"™oì]ÄEWX]Éá°º’ÔóùÛDDá2GÙc./ùïç_D÷àuƒÂÃ<”GDi’EÆ_W´ðiN`€þÉb±LaDóäêÅ™?ŒÍ[ñŸ-ÝyæK»Ê’[â>©Zu«3ëÖï¹ò«íòuCMNfÔuNÔ6dT²äÄ>œ˜J­N{ò½ ƨ¦ƒcÆë¯¹6jOvöAÅJKKË/Yp¡NÅ!2Gx¢۲ƒFOòZïøì¼/nž»ý‹›ç†”lY—&º]‘Þz¶ÉõÍDDëÖ­M²Ûí­½øî¯±±±iåŠc¥«üh}}óJe;_†0À$|yGuιm×^œ¾âöÛì 5…ƒþ§¹]l¼æ›‚Ð÷öTnÉ«¶o,¬kÛ´¡¸!ýuEù7®<4¹¢ÙáîÆül­ÔuÑÿˆJ£‘¬Áàe ëŸ é—•×’8çæ+¯¸ÒZPPP¬\g¾+÷À#çÎ9‹»\.Ùïo‘ÓÎɈÿ;¯¬7®=’wä‹Åçìøñ¡ëâ›+M%¢^YÙÛU¢(>þØcY}ÙCwÏÝwíâœÛ¤jŒØœs—Ò=€ïbœó¾îºÁjµNd$fQ712&œ˜|럎Œ:÷òirg8§¶£µö".R‚×Ú|ûÌ .ºÂN¿œ8aãü¿~8»/z‚þoÏ[ÏdÍü)µ£3Œ±Úÿ|üQVîzîÇU?fßuç#8ç¹3C’'oJyàÅ™=}—ËÞÒòó³÷î,Û½y:izúŒË|´¸EŠnll¬V¼13ú’YœÑƾîÃ[8Ñ—Åû¿½¢¯ûå`2Àø¸œºä#"â\²ã_ONûîž…›å¦ôUÔÚw!Èç]mÕeeRA>"¢aSæôé4.è߯-y$Õ9|sGg8ç×_síÐ5«WïRª/_òÒó/l^zÇíÉùÆgx#ÈwlWfÎÇW§Ô—íÞœJÝò9J™ Û-¨ý¶ªuÖLµÎš©Öš7 j¿mLÐíaL(&"§§ý¹Ý®ð—^za‡§÷û²Ûn[’-÷5ˆ^G¼ “ù@“)QT±ýÔ0ß©c5c®úCNò·¦|FC«sS]‹«Çø­²Õo.úæ3¤j—¿ýÓ1ShT¤Ò=ÁÀÁ]nÇÚ{/ÛÛV_;¹“£í>öxöõ7Þ¢Hcçœß}çXõC‡“-1ñ[R3…óè{-E1ãùû7mZ=‹º¾N—3A}@m ©Öø‡¨õ±ÄT†Îïân·³µÌe¯,s·V9ÜŽ† .º¨‹ÿâŽV£))-«Ãt¾ÿ—ŸŸ_œ2mJé$ÊMn‘bæè0™:Læ@D»Ÿd‚|£ÿ"ê0¡Í9ÜûÙ?ç|¹8mgcùÑ"·(V×µ¸ÆôF¯ƒ]uN†[²À„Jù 3L­Òžý·O5ãNŽêžøËãSýÓŸ28þ ;v¬|êÄÉ{; òY£c7õ4Èçj··|¹d^VѦÕiÔ• £2Dº)ê¬JÓ°¹I†!ãÒÔ†à1] ò1¦Riý†ê,1)ưi©¦asGû=§FkÝÄÝnêä³Áát{ðÁ?níÒ»‰… /®&é qbGzÂ|DP1‚ˆ®‘,2¡ìŠÒ£Òyõ€ ÖíìYm u“W.]’þâÃYÜíêZXº¥µ´ Têº%,ªPé^``RŒ~g¿øßp•ΘßÙÑO>þ$5uæÌ¬ººº:Eš`þýî»ÛçÌšíW[[3®£s¡ãg¤Ï~â= òµÖ¨þôÚ٥͕Ǧvz˜ …[Âsôa†!ãÒ˜Ú ù}ÂJ7Do5Ó4ìœñ~©GÕ†€ "rÈÿàý÷ÇWVV" FDï¿ÿÞöã“dʵDìEE€AkvÀçÄÆ^ ãF½…Ün³S,ä&+D ã‚™73Î,"çfƸ™˜`#Î-ÄÈÌ9¤xÿŠ }Ý¿œ«ùoœèRµa3ÏÿfÖ²¿-$"r·µÙ7<»4ûxÎöêBh›i4…±7É]#\³%ñªÛzô{Ó^w¢æ‹›Ï¯s;ÚFvr´No±Ok‹›AŒuuo‰®Öc-e›j¸è+UŸ8qbæºõ:ù¹òmÕÕÕµqq#¢[ ‘ª3â·×Ö7½¡t_Ð5X³ Â|ЯDM¿ÒÀjÛmL«Ñ“ÛaP3µ“h#"gÜF$؈s=q2pF6"nˆÙ8‘þÿ?z^.Š)E¹+·yñ‡ã5fÑí3C¥V«{½¹~ˆsÎßï½mÏ=õÌ0·èïè,c¬nÊ=O—„L˜Ù•Ÿ[Yíõuÿ½ñìšÎ‚|‚Úo›_øôQL¥ èÉû<åv´m)K%"­Tù«¯¾É=ëì³mtÜØ1ÛJJЧIWyv}CóιôuèsóÀ@7(ÿ‚zÇÿ‚x*nS3µDQÏ2ЩA<‘ÛN†ðˆ1½ÀÉpJ/HM¤!5#]DL ‘ÄÿgDôk™±_ÿ‹·"Éj昗åuÜ庞˜d‚âÆ®ÓýÏø?v#&ÌLºê“ÍméÏÜY±gët’˜wªöÚò”Ë/m štÞ†Q7ü9EÐè< EUï^ß$Sj 
N+S¥1ú›ÏûÇ׉ËoÚÜt¬hFgÇ¿úâ‹ÔïV¬(zþ¥—êæ/˜?A‘&û‰;väÝvë­Îúºú”ÎÎ2µúèì'ÞqY¢bzä].ç7·_Xâv´É>‡µél#³ôÖ¸>|§Òú U‚2\öêT©òÕ‹®4•”¶F?Å›ëcÏ=÷lfII±Ü¯Dv'‚|Л0™ˆèÌ Þo¦áq¦'Æ §ñþ×1…PVºöcîâ! z¾aƒ«¯9cŒY,¦ŒHjš›}áÛkvû…Nïè{·îßðÌ]FÑÑÓ¥w ªcÑ n)ˆ˜w}*1A±¾bûýó²Íug¨ –€Ý‹>Êß=àœï|í±Œòé©t2ÓÜÉÑÑÑ}öé¨ÐÐÐ'Ô tåeå•·ÞrsQ^nÞTêÂÏ1$|Ûì'ÞIÐý%ƒÒݱîñÛ2JweJ…ã~ÅNø…N­T“zú.oànWcÓÑ5"$UŸ|Öœ´0ιdˆ‘½Y[ßx›Ò}@÷`2 tó p§†ðH¥Ò«Ef8c-­Èmĸž3H¬¥ "]ßþ(ú…²¢ýßFöuRl6Óâìg©š10äÛKßY{IWž#ºœÎo>•uxÝ×㈘±+÷¨tºÜØk špÖÔîô<˜q·ÛµùŽíDtF dØôsÒÏzøÕ4å»_S’±jÇÞwÿ–@œ›ºrž1Ötñ%ïzô/™`6›Í½ÝŸ’ 9úðƒïÚ™5‰ˆºò½Í>rÁﲯºÍ+aµãûwæýøð £Hnú)£2ÿÈ9¢ 6ö«æÎúÂ-öº<¹ 8ûíww]~Å“mªÔÖÖÖ%ÄÅ69œÎ¡2GéôÆñÇoQ´1è6„ù` C˜ INÃã\OŒº²–– aX—í-Û‹ö;­¯›b³š¿$¢Ë¤jç>ùΪ¤)ó»ó¼–e•ëÿr{IcYÑ”®Þ£µíH¼ãy‹ÿЄ¸î¼k0j8´7wß‹·&JÕÒî>{xêƒjå)ôž–ª²c—/©s¶6%wõF¬yÊ´)Ù/¾ür\hXXHoö×ÛvîØ‘÷èòåµN%b]ú,Tü÷ÏXþšÑÓ¥)¥]ñÙu³³íõ5Ò¿¯•ûGÎq÷· ßIÍ¥ë³DWÛd©c¬në¶qqqÔîKI.—Ë=:1!§ªªRnjªKälfCCÃvE ÌB@ˆš~¥AhnÕs‡ÊÐýµ´d#"=YÕDŒÔŒˆ‰$þòpöëf@Έˆ±“»':5~Û•ýŠÐU¬¬¯;dŒ R_,USiõ[C’¦œÛÝgú ‰¹èµ!G·­Ý»éåGL]Y½ëh¨ž²ç™Ý~‘£6ÅßòD„!$zxwß;XTgÿ|B¦$FLš5RÑfÀ§ùGDžÿÏ•!Y¯<’~|ÏÖY$7ÿömÛgÏL™nŸ0qâÆ'Ÿy:*..nÀü~njjjúø?å¼ùæ¿,M¿®¬íüÓµDŸuáÎ17Ü7“ *¯­¯Êß“/ä#jñŸÞ"¨±ÞzŸ·ù…¥D6•nh$¢3¦5rÎm©©³j 7™Íæ.M€h8çüü¹çn©ªªœÕÁ¡§äE ̃NW×ÒrF6ƘžsÑpúZZ5‘ŽHýëï S‚xD¿†ðˆˆÿ6ˆ‡˜ýçÇúº)n·f —ü^=fÑmD¤õôÙC§;vÑÇsœ;Þ~vÓá5_Ž'âg¬…=ªåXÁÌ]-rúGŽLO¸å© ]Ht’§ï÷Uu¶Jþš¨µúÃZ?Ó(¥û߯T*Í”eM«;¼ÿÀ–ç–1w»]r*¤Cö®]³çÏ=Ÿt:Ý‘]tìÞe÷҆†ö»i}mmmöU+¿Ûÿî»ï8,Ljftç~cPÈ®”‡_ ÷ Oõjcœó¬÷^¨"¢x©²Ö62[ÐÚäCbýSÃt‰[ÛksS¤êímm#RR¦díÛ—;QAéþzÛe—^²q×®]²_ŒøšºÆæ§”ì 7¬Ù…Er--‘ˆé9 §¯¥ˆÙˆHω ¿œ£`êÂä"ŒØƒEû¿ù[_wq*ƘÆj5•§0‰ZÉï¾Ümb‚àwµVW_÷—?6+L¡®}tš#GlJ¸åi£&$zª7úð›oŸ~œ‹bèéׇÄÛ¸àùg÷EO08p·Û¹çÝç·”fþ8‘ˆû{ðwPPàÞ˯¼ªåâK.‰Šíí»‚sÎæå}ÿýª²X¥.).KDÆî>G¥3æ_ò=|JšÜúT¸ŽöœÏßÎÚ÷õ»anGû™·2Ÿ;‚ˆ ˆ\kù–LW{lð0!>acææ-ÓU*•Ïü‹@‹oü}ú7ß|ÖÁ‘Ü„)õõõuJõ=‡5»0Ð!ÌŠø_O«Ñ“Ûa8c-- 6â\Oœ ¬¥µõí|'º¦xÿ·Ÿôu§²ZMW1bŸIÕ¢¦Ìù&õáWzûÇ÷ï<ù·ûíMµãºq›Ó9bÓè[ŸTCbfÑ€¯ô†¶ÇJw.¿gü{©Zøøé«Ïzô_s•îéTN{Kë®^Ê>¼ö«8Ån† y«_T|öðKo·Z¦&õN‡}£|Ã[ ÿû¢d¨å¢—?;829Véž:ãnok-Éü1§dý·îæc%qœÄ ¾îéT‚J}Ì4tDQôÙ—é#§Ÿ•¬Òh{umw]ñ¡â•Ë®EG{LŽ‹ÑaAÛk[üšZì]]ñÛ"¨õ¹*cH«Ö?2L¥µŒèó©}œ»\Žú#bsE¥£­ZÅ-œ‹Ñ]½;ezJæ§Ÿ~>Îb±˜{³Íîjk³ÛçÏ›·3{wv‡¡D"r£[ëêÿ­HcЫæ€a>ð ÿ j5zr; ÿ[|Z ×sb†“kƒO 
‘®oÐW8ñï‹÷¯¸°¯û8)àÿØ»ïè8Ê« à÷Ùª¶ê½¬šÕV²e˽›î¢b@ô0%ºi J HèÆ€1L ¸ƒ n’«l«÷º’¶Íûýaà3ö¬­²ÚÙ•žß9óÎÎÿ† j¿ AïõD{¿qC÷ѽ§tßcŒµ]ûaI1vJÇ>Â9o?ºÿбŸ46—ìÐõµ6&rIrg'M.ªÕ•~‘ñµa¦ò¸™ ãýcâÜuòöÊÃå«o¿Ð—Ò)]QO"£B·þþÖâ°ô„¨”Ú¦öº¢{žRI>”?«>&¨*E•O«¨ vúà•ÆÉD}1¦ÊýpŠ;Ìv[gµ£¯£]êo·Øí=²÷sÉOÄ|†shAþÍ£e¿üÕ¯f’LwRwûôÓOöþìÚkü-Ë™F0÷£KÛÛ»Ö¸%Œ8ó€·C1À Nü“‚@ÎtĘþıÁ'ê¾ÿ7xNÏWì[ýs¥cIuìøß§ŸÐúî¸ø?ò”Èu:’ÝnÛ÷Á¿¶—®|!Úa·‡p‹>2~‡1ÿçúÜy¹ÞZô¶ù¶¹G¸Í’ròvÿˆØo/zé³iJd.sSmMË’ú¶Ã»ÌÕý­-¶þÞ0’4ô¢­~Q­®Wùø·û†Eõ§š($}|@и£Æ/ Еùª­òpŇ·_ès¦B>­Zuä‰Û/ï›>>5ûÄí‡+ëŽ^ûÈ ¢Õî0º(’ĉ5 ‚ØÆ˜º[Pi-$¨$b"g¢š‰1µðÃÎ\ê—ˆˆ8·sæ—l‚äè×pÉîCÜȹ#d¸{ºk廫ürs'*Ò‰ôرªºK/½¤âÀþý3èÌ?[˜@Emm]›Ü‘ Ü#ÑT4ˆ^P:‡Ë0ú²¢äƒ”Žîƒb>sVHÄtœIúÇÿ´ 눘žˆ"ˆH8ÃiÀ…ñ_——~ø{¥súß̉ýSnmâÕw¬Ï,úÙ\wg(.9?zckÉ;ÿ ´ö™‡Ô=‰ªcá“Ï-O(¸!MåêŒ#ÅÒÕÚ²ýÞE!$S<’¹äª So¼ï”Ž}ÞL²Û,–޶öþ®ö.K{Soo{k?“8·tw8¸ÃÁ‰ˆ4~¢ R3Aë£Òë´¡¾ÚC€ÎâQ£|;ª+Ž}pK¾–øi ùlÌ¿ùá‹fªT¢ZnsŸ¥çºß¾°çhuÓ@ ÉF3½ãÉ'ÿª9ÿü…ãÝq†††æ_ÞvËþ/¾øb éÏ´?#ÚÊDûÅ­­½Õnˆ0`(æð@²œëˆ1ý‰cƒO,$"?þv…‘¨ì½ð"œ®«Ø·ú¥c|ÈNíâÆXûeooÕŠ݈wVr…†’­w¼ü—®ŽÊ#ˆH3„CH*ßÞˆ‹;cν<]~¦±—ŠªûúÝoÊßyrºÜÚ⿾U–š£H—*8={ŸÙüæ•sêÖþSºjþ@YãSw]Ù83g\Î@ŽYr¤úЃÿXÕQ×Ü–GcüµØßßoßò{ïï¼öºëÆûúúúºúøß}·ãð}Ë—7nß±}²\G[YŒ?ÛÑÑsçÜêê<Ã…b>€Qêä‚@’$ŽûÉc:“þ‡±Áœ(Œ©”¾îÀ?·¼äÃ/”Î& tPn-"kÒús~ÿŠÇvås¦¿­©iÛ¿?plëég_yµ¯aoøôE1ç^™¥ wiHØóøºî=¥ûc¬íÚK‚¼ut0Àh÷þM ¿é¬«’-Ä%"òÓkö­|ì—¡áA†A?5·v5¾¸fÝÁÿmÝØÙÓ›AC+lV'êp‘ÜŽ·÷ª«¯¶.[vCVpppÐPuðàÁÊ+þYõîÊ•áf³9c7mæÄoíèè^9ÔsŒ4ó€S'’(êTÓÿ06øä‚@ƘŽsIÿÿcƒ)ˆˆBÈ Š82+KÞ? 
tŽÀÀ€?0¢äÖÎùë"2'¦hÁ£8ì6kÙ'oïÚóîŠ[O×pî‡C¼'bV‘9zÁÅYj¿À`—…†Í·Í=Âm–S:{ùGÄ~{ÑKŸÚi·ïƒW7o{ùñ™ÎÖ}tš}ÿýž]àpÏÕgµönÜqèК;{ö>Ñg±¦’çâµë4ê²ä¸ˆÚ9ÒÅùS3âB açüüÏf‡äˆváy­¶*!>¾1/o²eþ¼ù~YÙÙ!†ŸÓÛÛÛ;êëëÛ<ض}Û6óºõ_‰G·Ùlqƒ=!'þ¦ÃA·www7»ð~¸Šù`DýX¨QëÈaÕÿ86ø¤‚@b\ljé|BA`0iG2£Êb5”•}Ò5’ç8Ƙh¨$â§)*uååï~gtª‘ѸoçÝÿy²¥¹¬4‡87 ãPÁǯ48}JcÔüK" )9ÙĘಠdílmݶ|Q0Éæd,¹rý´ï÷ºŽŠ£]_g[Û;WÍ!ιlA°Nsà¿Oßà«Îs”SæÞŽ=«*·í«èÚ{¤š«o î鳉Èå£heH*Q¬ 4ø4¦Ä„›3“c…ÜÔ„ÀÜtc’N«>åüŸm)ùîÁçVN$çŇ¹ý¹w€:gW¶uv~¬t€@1x¼¸Åz¡§WÇ­¢þ‡±Á?)äLGŒéO|RA îûËé®(]à¶;ãDPÿâl½ÜšqÆyëgÝóø¨+sX-–£Ÿ¿·¯tÍ«BosC6‰Ã: š|#ûÂg-#§/É}ü†ÝMk êÖ¯ú¦ü­'dÇt.yòíáã²Ç¹# ܧ\³¾¡d‡ìóªÀ„Æÿv…… u4øHœKõÍ Õ­­5mæÊºKMsohîPwô˜u ›UÇ%l‡'.H÷eŒõ ŒÙDYDQ°‰¢hÓ¨D{€¿¾/4ÀÏB1!AªØÈ}ldp@bTH¬V£ÔèÜk^±qßÑšÙrk‚¨ª1][WúÁ«±’ÃîÊ~.À¿hïè>Wé…b>RSj¹Î@G€M *;rˆ<@àL(/]½Jé|AA†g‰ó[äÖ–<ýA¥!>ÙèæHnÕÛR߸÷ç–oø$J²Z]pH»Ê7`Hö¬¶°©ç…ÆMJc¢Jí‚ãžbï“7nè*Û;çä팱¶kVï d‚੫ƤκcÕïßtAÉ—žºëŠ’Ù¹éãÝË“™û,=gÝüÇ»CŠ•[wÞEë§Ýôà´=+ŸßVúþ«Ñ«%ÙÝx¹½£ëz¥C ŠùÆ ÿFyòšFﻯøÍo²”È¥”Æ=[Kw½ùL[kY‰ÉÙÌÁã½jß ƒþãÆ·‡åží<~Vº¨Õ»d|æ–Ûæ•I¶þÔ“·ûEÄn½ø¥Ï¦ºâà:k¼n}ýÞ­²]ù&f׿ðàõ£®ª+|?nw’ì"c]—üûk«OPX(qÎË7¬Ý¹ûíg­5•¹ìxw\e0öÛöö·;?À ¡˜@aAAþ󉳯äÖÆ-¼lÔî?¥ëÛXÀ%‡£vûúû?üOGÓáÝãÈ!…»ððvQçSæŸÑšw–6${f¸60¼K¬ˆˆgŒ1g”DD¾ÎbënoÛzÏADÄN^ËXråúi7Þ¢ bï3›_»dªƒ8?e¼:c¬í«çPùûê½î©ýê‰m­SäÖÂÓr6.zâ­ŸŒâµtwvZûv页ï©{šj²‰˜«²ˆÝ‘¸)skÂÆ™4Û_~r†Ü>œèÆŽŽ®]uN€‘¦R:ÀXǹPÌH¶°Ú‘}áõãÜÇS0Ac§.0ÅN]@œKRÞoXóz{ÃÞoã$‡=f˜‡W9ú{3:—Ñyø;:JD$ ¾j¿àj¿¸´žÀôÉZøB²Zâi ŸGëViµµÁ ãšã§ž%gŸ—à—BD)å_}´ÃÙÍDΪ‡vO”b>1ÆTÿ¥rkúÀà=úàð‰îÎ䉄¨ 3Ò¢&¯ÁiÜ·óÀþÕ/76ìÝë°öŸRd7$)ÂÖÙÑÞÙBí¥›Øjõ~>‘ƺ€äœòÀôI‘~ ™±jÿ `"¢æíŸ[åó²¶ÐÔl×ä—9ðñ›‚“%ó¯.=/×­a¼1:4~VnÚºM»Í“YVoùçïžó›çNOD$¨Tê˜ÜÙ1¹ÿ_KÉ%‡£·µ©¡«¶²ÅÒÓeíël³rIâj¯Jåã£ò ×ûGŇë ÁÁD”þý??Ñ^{´ÏY^.J5ƒ¾“ B1€‚ ßyD$;>vÜù—öº7÷ˆÈš˜‘51ƒˆ¨·¥¾áèÿ>åúånËÒ×Þ¬—ÛΉ0b¼Šù°€ˆ‡È­%Ÿu!FìSW·}s›Yœ1c‰oÄŒ%DDÄGwÕþò¶ÝëêÛ|#ô7ÕÄK[Œ«ÏÝ×Ñ:qçkOÓÎמ–4zßÒØÉsš3_žž›NŒ1WŸά£¢Lv«F-VùûêÝÇ›ùè4>SMÉÛ¿-92÷ä5{oZwý±cþQññîÈbéî6ÈmˆcÄ.xó(DdüB.¿dO[xY¦{ÓŒœS_}Kßn«UšyòE1 );) );ÉH·‘µ£©±µtsE뮯-æÚ£vsGœ³ñ¹C Xû̦ò ŸRù†OIÅÚHSÞáì‹o ΙjBa€ût7Èc ð÷i%"ó ÒM.ˆü¶äˆÜ;üÙ»•“®½Ë-Å|»5\n;'†Î|àuP̠Ƙ`0ø/‘«äÒBJôÁ¹n5 Øí¼ª¶©×Æ9Mèm4áQ³Š"¢fý¸ÍÚÕÞÑ]YRÓU¶»«»²”z›Žì½]qÄyÀpòIGLÝž­1u{¶’ еYyeã/º18*wZ6¡°`9¬VQn»ŸNcuw–Ñ 
;%.M£+¬6Ç)…Õßm']{׈g°õõô8}^f(æïƒb>ûûO•ˆ"åÖ’ä÷¸;ÏhÐÛkÿ®©Ý’LDÃ=–& (0$gN`HΜŸlïo­«ë>ZRWõéKöþæšiÃ9‡äpÄÔïÝS¿w+ *uUÂÔ³ŽN¼úWÉÑñ à ²$‡]¶`– ‚“&©p&É15*ëN)æëª«ŽpÇù{êkšˆÈOnsŒÙï#(`,â8Yr¤-¾2Í­a¼oiï_ßÔn@.(ä;]HttØ”óòDÞîÊãJv[BÅæµ Þ»é‚ø7.™¶s×ëÏ®·õv›]y€±Nãëo“ÛÞÝÛ¯uw–ÑbjvŠìv‡ÕåŽówÔUµ;[9:ó€÷A1€$’/æÓú•ø‡…»;·â™kû¾íéuÌ%â²#4GBo}y’ÜvQï[â³…ëâ¡™µ·{âîwþ9÷õK§ÓïºtcӥÈ ßó pÈmïêé vw–Ñ">*ÄI!$÷µ÷™G¼ ¹£úhŸ³5.J(毃1»n& ”.·–8{a§»óx+«M:VßÜoáœOwçy{«ª¸ä…›0mIgÒ¼KgqÉáh9ºëpí¶ÏšÛ*÷FIV‹lñßiqîÛ|¸döÇ÷^Ej½OiVÑÏÚ³/¼n²J£Ó ûNŒAA ©bõ¶u§l·ZíÆns—¿¯.Àý©¼›^«uúY‚ññÅÝէ钪Ƙ]ð:èÌàf‚ó»”¶øòÁ}A=½öïêšú8ç©î>wó¶O«œ­EæÌ51AÃRóÆM¸âÁ™ |'iú­ÏVEM˜·NTk圶¾^Óî7ÿ1ûõ‹òú¿|ôçÌÍõ CŒ0fÅOžç¬ë©°fÃwè€9Ýf«³5&ªGü¢¡î¦gY»ÛÚÚP^ùÜŒ1Ê'™~E*­þ°TÜ8÷'ò*¼¥½CO¯4‹ˆÜ6V÷D-{¾–}íd‚Xí'·æ›Ut{BVÑídn©=V¹ñ½ò¦}›#6‹l‡Fg8çÕ;6Ì©¾îl›!.iÓ¬[ ÏšdÊýkBÆ™’‰±âÜïäµ7ÖnQ_qÁL%byµ’#Çúå¶3AhVi´a#}~skƒl§RN„»àr±KCDµ` ‡UO¢¨cÄ4¢]òU:ÀpI*ÁV±÷ýõJç "JÎ+6P¯u²Ò9\… Bsyéû{”ÎÊ2fæOIV:‡«˜Õ}ß4ìù̬tŽÓIM]àÐhÆ+ÃU¸(5Wì]3¤ áôré‚H²ëH`z¼¿ƒÑB/h6”–®tzñ€+%çìV)äÇ×S|”ÎîãP fNÜJï#AÕ¯Ò­Gw¬Ä…Öð£”ñKÃm\ðS‘Ý×NL£’˜ž$ SÈÆ‡À\º˜ÝÁ¹JÓ! K„*ºiÇŽ6¥³y:󸑟Ÿ_„Z%ÈŽ…ž4«žˆPÌç'ÞW×dÙi³9æ*‚ó¾Æš¹%ÿHcÉóÈ74&>«è—ñYE¿¤îºòŠŠõoÕ5Þ•Ì%{ä ’¨;«Ëg}|ßÕ¤ Þ3ýæyâìó' âöcŽ ªT†ØÄ½Õå3N^kjíÊÛ_^[–™ãönŸÞlkéQ¹íjÿz"ñb¾þ®® 'KN;¨œNBÎEIŒÛÇ3‰§3eqN)œ(†…«EÒ$1HâDÄIBŸ{ $©•ˆB•ŽAD$õYÓI /”Îá:üC"*T:(‹ âãñJçpä›ID”Îq:f"jàÄk‰ØAF´Ÿ1¶OT«w•í\Ù¬`\!‰EF®â¹LâéʉˆRˆ(†ˆÂ‰H-D ÄI"Žù¡c#FLâD‚@L²“ ‰Ôjo䉦Â&"j"FG‰Ø.IûöUfªKøÊ•¥s{ó¸‘J%,&'oQ2]áæ8^ÃnçuµM}œsE[gõ6TT—ŒrkY³ýÖÓ?:)1ç²9çRÓov–¯{»ÏÜX•KÄ|嚥«müºÇî¢-Ïüº4oÙòδs.œAŒ±3ß`ìÉ^zjÓÓÉ-±_>ñZ÷ÿXÎ?ÒÑcnkéèÉ–[ Ï˜ÐæŽ «%J~…U¸ãüàý’Ƥq‰ÎçœÍbD³¢ã…õŒÿ¾‹2ž¼GêÄâ0‡Õzçl.1>‹H=Žˆ¾o‡wv“—ï£éÏ%Æç£ÙD”ÃH9^Sa`ôD”Ȉ%Ñ,""Î9Ù­VJ4àD›ˆóõ¦ý´ºt¥[¾K×JÌ(2r‘/dD³‰h6‰Ã8áçFDDAœ²‰xáñŸæ8÷[»M…ß0b›¸äø¢òÀ¤mœ?,)œW(æp#Æ©@îsTua¹iîOäùzûl{šÛ¬±œ(Cé,­{7Õ‘Qn-"gVòPË"2gLŒÈœAöþžÎ£_½½¾fçç±Üfð1­}½¦-Ïà_-û{¸WjêB­C£Î—v#ãÂYÄp]4¸]»‘ˆn4fíªK4¾& ªªö®õ¿yDqÀX`±ôžKÇ[KŸ"mñå~nŽãÑ8§þúæÞï¬6îQ…h––Ù‚:_`¹¤˜ïD¡i“LJ¦M&kWKã_<Ð|h[q<ÛÚ­–Œ O=@[_zb÷üå«£ÆOËru>o”{Ù/¦í[ýï#kŠÌ²ú×Ï¿›SßÚ¾ùº‚¹#2Ö»º±­fÃÎǶﯰ®ª÷mí0Ç8$G4ùžá¦?½f_r\TkѼ<ý9Ó²²µµn$2žÉ­ý{çòÏÏÑÙS÷ ¢(۱ϕÚ+v9_E1—hZI¤¾ˆn#¢°Ïd2kzÉz©1‹”ˆÆ}??1Ƴ 
s¢û¢é˜YBdœÎæŒml(OÌ.ø;õt­¨¨øº_é`cQjê»Ný3Òjî#¢H4¾MDËÉ~ÑTð‰@ÂïÊK?ئt¨‘‚b>w‘äGìcÝqSÈv8‹ìv©º®±¿O">"…4Ã!õukå¶ëýƒGôÄ& 4büe÷GH6koÅ×om¨úö£8ÉaOÈm-ÝmÖ>t=ù…Çn=ë7φ'¤&dVO'C‘&È IDAT¨Tês]aýôþkD$Êì¢yîÝ/g~´q×7OÞqyTRL¸q¨çr8$ûîÃU‡ÿ·}ó7{Ë4µíI—b‰(v‡ÓöôY³ö®¢=‡«è‘Þ3'D…mùùÅ ôgMΚÀÜ4VûõO6oÙw´ÆY¡µeæ¯~;Î9Ú*Y­ ‚Å|c\²)?^bÂ#Dê+‰È#:ZÀЙLÅ3Yo!¢{ˆ(Jé<Þˆ1Æ3ó/J0ü†sÂoRàÁXqúùî2š ÿ\Öù:õ¹Gr^±Aê·ÞGZÍ-ÄÉ_é<NŒØbN|q¢©ðsâô›Š}«·*ÊÕPÌàŒ11Ðà¿HnÍ“P*ªÔÓÝÉõZ¤=M-ý1D\v¦Ò$»Cöu“it’;Î/¨5>Éç^3'霫¥º_n=¼ö_kÆ@nÛÓT3õÃ[ mÑã§m˜wßßrµ~þxcV¤)/3e~þº#_¯™çlŸê†ÖéÅËŸ‘¢Ã·Îd[<775<Èál«Ín9TU_¹}_yËŽýöÃUõ†ŽžÞT"ʉû@D¾UõÍ3îûû;¤Q‹ùóòŽý|éY9}Ð>ÛRòÝSo®è¬j02kâVß°è9#uþuTW8‹ÑÒÚÚzš®}0šÅäåûhú…Û8 1Nèz 0 $gž- ô4Üg+€Q/Ù”?Ù˜Uð$'š[àEâÑ?ŒÍ†_&eüº|߇ï*h´bìQ!1kוœè1"rú;€:—“˜]´Jbâ}£iü.ŠùÜ 0Ðo&q •[K>ëB|t"¢öNëÆ®Û4òà:’d—ëâE¢Z¶a߈aŒ 1“Ι3éj,Ù¸ÿÀÇ+¸½¯g ctÕu{¾óÖåÓ[LE×mŸtÍís™ ÈÞ'€Ñnöœ×Zqps{åáÓuêš;¦>·êôܪÿcÔ¥SktZYˆ[lvM¿Õæï°;B8‘¥}ÿ[YmŽÄU_lM\õÅÖ¾qq‘›î¸raÐä¬D—ŽÖ~ùÃõ›Ÿ{÷Ë#ö:#†"'åøAÅö3 ¯©Ü¿ú¥Óx Æ3šòI$<ƈ¼¶È `ÎÙ]ff›Ÿœ¹äÒ£û?*S:Ð`¡˜`¤Iò]ùˆˆÒ_žäÎ(ž¢»Çömk§5‡ˆù(e0 ã¦ØzªNÙÎmÖä¶²ïJ‚S'e+ëŒ1!afÑ̸iK¬‡×¾´¾fûgˆsÃioäp„óÏ߆—¼ÿò®…ùwœoHd¨›âx„ø)óÇ_±r[ÿÎ×þ¶áÀGo¤H‡+¿ðrøøøÍ2e5œ}ö9 Fg™L‰Œ1§Å³Œ1‰óS›ó]¿ì†MÑÑÑÚV¬Ð466dÓ :mK^QÓ^QÓDoþÍÉ˃¾¿†¸¤-EÏ®žÊÁ­ÝI;Žm&'Å|Œ¡3ßh—’]çÐiWç•ÎÃÇŠ‹Eã~ËĄەÎà­c̘Uð#öa2ŒŒ(”ú¯ÑTô‡ª}«Ãå¾P‡¥§ø³ ^'NùJgpÎ'J‚øÑTô³ÊÒÞS:Î`àÅ`„qb²Å|*­þ°d\¬»ó(­¥Ýº¥µÓ:•Ž·Âö*Q³—&8[;¸öåwfATiÒݱ(èL¼¦êÀÙlâb:þâpŠôÅW„»9Žbœ·U×÷î÷æB¾Ä-¼¾ßÉ’êàšçÊÜfcÓÓæÞûŸœ¤—obŒ±“WWmåäw®žg;ºnÍ>wäðT>!IsåM¾æÎÙóî}ržF«o‘Û/2*ÊáëëëëŠs2&SÉGD6û© ût:îÎ;ïžYUU½~ÃÆŠ9sçlfŒu¸"Ç ÌÉs­»bå¶(ã¬óíŠÖ^]î¬`çüŒHÁû$e]%1þ¡`THÌÉO'‡°PÈ0d 9%98_G(ä ""N4‰H½%!ç¢$¥³xŠÄì¢yL’>C!À.±k5kÓÓ ü•r&(æAŒÑÙ킪><=7ÝÝy”`·KÕ5õ}]‰ŸÒÍÎEÍ*šDLh”[k«Ø›Å%‡ÍÝ™…1–4·xÖì»_¢¿!"~ºÝ¹$…lxòþÌϺn³ÃjuÖ `Œq6!—Ÿ~tîàÈŽÙµÙì§}ÌæäŒOüðÃÿά«oÐ&%% {\¶¨ÑÍXråúËßþÖ>çžÇæ‰jv¸Ç®žÆZÙYŒ8FìŽBFSá/8ç¯2t•ŒY¹$± Œ(Vé,Þ*!+?K쉊–~*Qì_'g.IU:ˆÒ’³—.&Î?%"/Zp³y•ð¿ØŒ¥!J9󌘘ât–ÜZpjÖbÌ•EÉj•ŽÔ6ö«8çF¥³¸ UêÐ óÈ­qÎCK7íuw¦¡Ðø§ÜüÔôœK—ïDUÕvgu{¶Î|ëÊYÕåun à…8ç§-´ A`²ùì2ùäètzý¼ù Ý&œ1¡Õ/¸%¿gê²û÷dæ_1Þ1Æ&ÈÙµZ\08a§¯5Ón~èÛÈŒÜðÞö³ÆÇO£õ×…‰]yìÕHvKŸ™KöùU^áÞ40’ŒYW0ÆžQ:¸FŠ© E"ö9Âï‰Ü%9sIª$ˆŸ‘à!b˜CøÂ˜yáìÊýïÕ+ÆŒ¦üŒ„wˆH­tÏÆ'[TìÃÄÄù +*¾îW:ÍÉÐá`¤pù»DdŽŸ~ŽÉ­Yܬ»×¾­¡µ/…øèü‚V¡ 
Û%·ÖU{D¶c”'TZýÄk?#çÒå»™ ÖžvgÎý¶¾øÇñkúÙÉnw¸)"€WàÜy1ì`1b²¯væ#"ÊÉÉqúì°öIAIéI1“fe‡eLHó‹ˆ5:/4o¯<\KN‹Ž1fw´0f.dŒ½Bø¼ 0*¤Ž¿0ÖAì+"òºïL#òS: €—Y’xÀò¥Cœ Å|# 9¹Ú8qþb¯, ˆ–ËúöNëLcÜÃóÎw²ÄÚ+J}¿{å!úæÙÛÚ›låœxÝ˨´>þSnúë촅˶c]§Û×ÒÝnZù³³¬õ%Û1Ú€ˆ\ù`gŒdÇì¦3ŸN§×‹¢ªAn­»þØ“)«¥¬ÔY«sG{»¹Ì­aÀåâLÅÁ\>""ƒÒY`øcLÕe}…ˆOV: €·bìQAÓ/¾Éˆò”Îà¥Î7î·þ^é#)nF±^T «‰(Fé,ÞˆsvW’©à¥sœÅ|#@œŽØ•RϽ$Í­aܤµÓ²¾ÇlŸKNÇŽ.~ ÉDtÚªssuО·ÿÄ6=µ¬³v×—’äxŽÒâ¦.ž2ý—Ïõ¨t>ûO·ŸäpD¬}àÚèÒ÷ÿ½Û]ÙÆA³k³ÙU0èïï×,·ÝÜÒ J.¥µUvV,^Å9wVè^€±GYßdD)Jg×HÌʈ.Q:€7KÈÚù/P:€—[žd*¼Hé#Ð 'öl²)ßcG(æ œçËmVûøí÷  swœ‘ÖÚiYßÝcŸ«twTj Skª²¯¥³Õp`õ³Â¦'¯3Wo[k÷–¢>ßà¨è9÷þ'549w=~¬¨vû+ßð×û·’u!*æä4sa13äÇì:löA#  On»¥»Óg(¹”ÖÓ\ ·twp-cÖè<¥s€k$eΖˆ=¬tof4ÌgÄP:À(À8Ñ‹‰EF¥ƒ¸Z’©èFÂ4® ‘Hx+5u¡ìïPî6êÇ ÀÈ1™Š5ÝB¯¯h×p•M/‘èË$)€¸ Hò“ò'ÎtŒÈŸó#ÎuŒñƘ/qÒIŒ Ä©²²tµGµ,®   "Ê‘[‹Ÿ:¯ÅÍqFÜX,äûÆ/¸ÕÒÞ<Ðý­æ.ßC?Oå_½Þ•|ö•úèܳՂèÙ/Å‚¨RO¸úá¹5;Ön=øÑŠ,"îçdWvôë5SÛ«o[üÄ[“DµFtkP€QFÄaÙ%" ²;vêH]î°«‡–LYÖ¾ÙöŒÓ!wg×IÈ)š$ÿÒ9À5r LõÃïC–2~i8#öá»v×$‘¿“—wÓ¬;VØ”ã YùYžR:À(’lר_$(Ňi€1 nF±žµY‚˜F­#‡UO¢¨SILÏI "ÎuĘž3Dœéˆ1=I<ˆ×qbz"DŒéNzNDÄuDLODÁD¤$qÁA$ $§ã}‰8qbÄNè)Ĉ߯ˆýظ‹q"bì]wÿyŒ4&IK¸“¾Ii‹¯ˆuoš‘ÕÚaYßm¶ÏQ:‡R4ÁK{àogëë 8øÑótôË7:SιJ5aÆÓ‹úbóΟ›^¹í¥{š%›-ÑÙ~må§¼{ÝÙû Ÿû(Yçoй3#€ûˆNFàZ]vƘl1ŸÕ>¸ï‚ƒƒeÃ%îuº­½Ý=äÂåÖ8ã(æóR©© µ‚Vó:i”ήÁ$Õ D” to&9¤—‰(Jé£Ì”ÖþÆ»ˆèÏJ®¼¼›ÔÞ$"¯œÂà±+NÊ*z¯|ß+•ŒáÙ•cHÜŒb½ÐÓ«ãVQÏD䢂» "Ò©ˆˆTŒH²Ÿ (q’~˜˜É¾¯:û¡úŒÿÑÝñ-ŒˆŸ8_ÓeMâe.= '`Ò¹?'ATU'e¥(hD´wX7|_ÈçÚ¿^Dî ªýC¾½­¯Ûp`ÍstäË×:RϽÖ'j cžûÇéi4ιû•žoŸ»s{gÓdgûõu´f­ºîœŠ¥ÏêáïÎŒî ªT²Å|«Åe`Qt2f×! 
ê«ë •Ö^y¤†ˆÒåÖØ]oeÓªdNþ¿€÷I2.aD)À›%e]ʼn)`”z$1'uÅÞ5^ýrk_ÃĘì¤8ÎøÓ 9‹¿¨Úûßv¥2 ˜`F¨àNGDAÇŒªï•Ì£ îF':¢tW óçĿɮ¥/'¢8÷&š©~ÛÄ,‡Ú{¥ÊÛgq‰~ÒÁ©§×±¹ÓlŸIÞô—m¨´¾²÷Ÿ1¡Å'8´ÂÜê¼àíD¶ÞîÀý«Ÿ¡£ÿ{³5cÉ͆дÉûú¬ÒùùͺcEÞî7ÿ°®åðŽyÎö³õ›W-;¯¦à™z ±‰nŒ0â•|g>«Åæ²çDA Ùs8ìöA£¿·O¶Ÿ¨Ó÷ %—’ZÊJÚœ­©mèÌç’Æ¤1b÷*\#9¯ØÀ‰þ©to–––ÊÕìÉ1þó ÀHÒ’ÄžgŒÍçœË^TïésòÓ‰ ÷)`‹%ñq"Z¦T-*WÜ‘ŽýØÝŽˆUDl,Ü$6ÊŠùl¶þó1Ù1qã^æ樃 [—á;ˆ(ÚÏ@qþÚ Oïjøq”n_Ÿ}wK»e‰JeôL#?‰Boñ«_On:°»lÓßì쬫œHDgiién ÙýæÈ'4¶!«àÖpC|ºgŽÁdŒM¸â¡yG¿zksÅúw¦‘Zn7‡Ý»ú¶‚¦%O¼Sœœíæ”#FPke í:;;\öÞš òùìvû ŽÓÞÞ.ûøÔøøö!–¢Z—:k3ØÑÔÓ3ø™ç 8îž#âZ¥s€kHý–?±¥sx3«Z|‚ˆ‡)`tcsY…—Ñ[J'.¬ ãõ 0B8±ë3ó_ªØ¿æ[%Îb>PÄ  î8×'=gÄÓq.ébA4âw0Ò8³ª1»Œ³%²ušŒuÅN™or{ A:'Þ°9/Âwî‰ÛRu³3‚u%Úú³-V©¬¡ÝšÈ𑈈˜àäe”s‘ˆ(•PÙÛR¹ý_÷Q`|fmfþ/b|Âb]šÛU’\6Ó/2~WÉÊ'’ˆsƒÜ>’ݾæÎ⎠þôïúˆÌ‰QîÎ0ôAaÖîúêS¶746º¬h[tRÌçp8u5@uMµìi¾Áa^7·íX™la"#BW>/”d*\BD ”ή‘`*ÌdÄnÀåËCgÌ*ÈeŒ]¥t€±ÿ15uáûeeŸX”N2IYcsμ' #QxŠ16C‰.ž(æ§RpGÄtœIz"!hw‘0¨‚»ïwcßÿ7#†‚»Q€õTíû¤Qé®Â þ åÖ¢â÷‰*õtwgŒ`ذ$9h<Ú6’-I ê,i©­­oî÷e$_´5ÙÍí²OELTý¤ãUHRfÊÒç>Jé¨>ZùÕ_î¨í¬::Щ¯ãØþ˜-ÿ¸Í=a~kêy×E¨õ~.Jî:™3s}nŽ:²mÅ=f.9d»ïqI üôþk„ þüŸºˆŒ\t诗ěöïp÷©ñáÚËý¤àN£Ö‘êW1UI’Ž ?Ð ©à.@E$ž±àŽø÷Ew(¸ƒÁcDeJT1”  ÿ\¢P¹µäE²EžäÖ Q•ŒhšÜšA+êkëûú‰(Ùͱ<š­»Cöý²¨ÑöÉmŒK6.}v±­ü@ùW¾£¹»¾z iæ6窺]_E4”lìN9ë*[Ü´EÁLð¬ Çþ‘I)3~ù\Ý7ÏÜR!9ì‰rûpI øtùÕ´ð±×kÂÓÇ{f«A€ŠŸ<×pø³U§lgDº×^ûÏ·Ë–Ý û\:*•üã|0ùžþŸûˆHö ½Ø¼¹^7ªÃÖgŽ—ÛÎ:óyã~ÛUD”¥tpä̳I ó”ÎàÍ’LùDl¶Ò9ÆÆø&SñŠÒÒ•=Jg3Ù~ÁˆR”Î0–0N¿gŒ­vw] ŠùÜ$nF±žµY‚\YpljüU?TÛ©ˆH²1$’Nºfwà¡;¢t—’h‰“{ÊyK3ÜšeÒ‚ô¥az•Óâ“®^»…£ïvs—ìrjl1ß‚“2’.zamRã]GÖ?qO·¹©>÷Lç’ì6ÿß½L•›ÞoÌ\ú«€Ð”\ýPs}PDôÌ;^êøæï7²[ûÓäöá\ øäÞ+ù¢Çߨ KË‘ý à ¢rg¦s¢~¹‘ã¿ûÝoý®¿~gŒ ëâ8ÁIg>É>ðb¾—^\"·1Öž>~ÜP³)ÁÜRßHœGÈ­q†Î|Þ„1ÆŒYw+\GèA¥3x;Nl¹ÒÆ `3Ù–Ñß”r&yy7©‰øJçƒ2³ò—ÑwžÅ|'©‚;F¤R©˜K îÐ>¼“x™Ò\I"Z"÷¸ÔúîÓúŽw{ A¸*=¤ÿtë¥-}6weñ&ö¾nÜv•¯¯u ·ÈÈM)þ×—Tûݦ’uOÜCÖž®ì3ÝÆjîˆØýÚ£—^mºøîX!Ôc^´þ3ïú—zËß~^bë“¿/œK†ï½‚ç?µ²98)Ãë:ƒ©4Z!*á›®úªSƧwuvšùͯ7<ú»ßËvÄ(Qd¯·hg¾>ø`gwwÏD¹µ ã¸ýLf'Ÿ»5Ü[KD²Å|‚£3Ÿ1fqtå%³ §£yJçðfßw¹¡t€±ˆ¿Ëd*~®´tå€~ÛSJ[ÓψHvz Œ,NìB1ßé¹¶àŽëˆ˜žá³ãy IDATˆB‰H‚;÷âD£¦3ŸÁ`H¥Ë­ÅM?»ÃÝy#ÞO[fЩòN·ÏWG;d 
(Æ:[Üv­ŸÁ1˜ãÄLš•}Å›[øOÞþvÛK‰’ì¶„3ݦ£ú`Üæ¿ÝП8§¸Í8çâhAôŒ—tµÎ×wöÝ/¦lú뻬æNÙŽƒ\’×ÜuióÒ~ÜkpwFW˜tõ¯´_ÿåNÙµ¿?ó÷i“&OÞ•Ÿ_pÆ®›Î¨ùÇ´C:s1_oooïM7.sZ,;á²_È>wy²Æý»œ9°·õôŒš÷c‚Àå8ຣ_†Gè¥3ŒUŒ(¶‡[.&¢7”ÎâÌñi'…˜v œ©Æœ¥Ó+÷¾ÿ»N8¢¿ü³âb1þ`o€Ú®‰‘˜DÄ#‰1×qb¾ŒS1¦ã$ù‘ÿñB<ògD~ÄIGĈ˜/afø1°Ë îPj $6ŠŠù/qöÅeÚ¢+ÎX˜¥¤‹Ç×Qª³uΩé`s¯ìØÔ±Nê퉒ÛnˆNüרŒ±ŒE—MwÞE¶¯þuË5¯ep΃Nw.q]ùºw¢«·®­é½Æ,¿AŸw*­~Ö/¦müë²Ý¶Þ® rûp»=lõ-ùÕÅÿú\¥ õuwF€á2ÎVòÁQ¥ƒÈa$Ü t€±N"vIzzÁ~ØíŽó)2“¯zËÊ>"êsõqãfë…ž^·Šz&ò EJbzNÒñ¢@Æôœñ âLGŒéIâAĸŽÓñ bL'pÒs¢ FðÓŠ(`ð8QOÕ¾O•Îá AŒä»Ì„eä#"-æ[”hØÉˆæœnŸÿióŒù­Æ\[VMDYrk¡©Ù!Ã=¾Ö×ßpΣ+ævTWûü7ËÍ- “ÏpÖrä»”O\×™Qð kÔøùNGlº‹ ÖøÌºóÅ”]¶ÇÖÛ-[ÐgíéJùðÖÂKW|š!¨ðW ¼KdÖČԳ ו}¹zž“]Øw;¶ÏIN2¶-]ºtÝÃ<š9c3'Å|Üqjg¾òòòšGyøÈ?Z“Ç9Ÿäì˜úÀéç_:¤NJ³ötÅÈmgœº; +.Ä®V:¸ÆñQ_×+À›±âbÑÈèZ¥s1K×Ño”r²8Sq°Š¨Héc#ò³ˆì""zÅçU¿šŸP$ØNDu®<öéº  ­› 'òg£ìÿÀ@1¢#œóÁ"õ@ç 9“,§/¹â´cR•63& ú »8ÖUv§»%Œ—é©:ØédI ŒO’-:ŠÀ¸Äÿcï¾Ã£º¯üŸï¢Þ…„$Š(„èL³Áq;nqìØ)NϦmâÄ?;§g§y“l§lât·8îݦ™bЉ" BPïÒ´ûýýáà`¸£:sg4z¿žgŸï¹åƒ-FíÌ9ãnú¿—ÇU¾òϽ[~zo¦é÷íí|3àK+{ì~©ÙöÌáÙ·Þ3ÙœîU–Á0\qI}îW7Ý÷ÁßÇr]s{]mñ ÷~tßåßþÕºR ª-ÿÌ·VÕxskkí±¥ÁÎ1M3ó‘GYýÈ# ƼqՕ뺮¹öÚœùóçOŒ‹‹³|ÓˆÛmý%b·Çã~þùçöîÙ³»}ãÆ æ¾½ûr:;;§ŠÈ˜Þr*¥šÖýèo–«Á£]O{s“Ö¦å$TShæ.&”z.Ò†êW3+¢ßÄ’u Ì(~+ÃAá!ÿréë÷4l D]/QØÌçÔÞkD1| ˆZéë„f¾èb÷4Á¡6 ŠHºˆ°Â ÑK©ØY±«¬Wì*¥šòç.›nwœþš“¸Çm¨^Wì¶{üem? 
VZ*vy­ŽÇi‡;>äß|O¾äšÙ…Ë/ïÙü“»_«ÚôÜbyë5?¨ÖÚŠ©›~xGKñUkÌ_ð®I¡Î3θĔeŸýEö–´Ú øÇYsjßöYoüö¾Ò…wÜ9Ãî|À(¥®ýù?<úñ«¶wœ©YÜÇÙŽ“5'>ðÀ/å~)"¢Ç™¤¤Ä†¸¸OBB¼Ïét™""55Õ–WÚZ[gÜüž›Q‹ô¬¹û'ÕI£òç èÂ(Q_¾¯FÞš–}­U©Íq0H¦Rëøæ vh1®‰t†=Óäó)=¦Oš¾®èÈ'£ëwØJ]#óq€aO‰Z3zöeI§÷>ßîgÑÌaáž&†µÃY"âeNŒX•‘ J)WzZÊåVµôñ“*Ãy‘Ý™úëÝ“2=}³³¶£ÉŽ,ÃQGõÁ4«ã £NJ˜ÞIçtÇůþÒWÏ¿ýs5ÏÝsÇÉŽº“½6i3~àÉÿM?±ã™ƒso¿w²;9ÃŽ\ý—’™µø“÷wmý٧ψ6s­Î)}ü÷Ó3'™´úêˆ6e8®ëõôüçÿëÃNïcÕ.U@ ·­­=W¤=Lét×ÊÏÿwÙ¸%kûZÕµêÊö„Wfk šR²>Ò:ZäÚHg`øÓ×0›ˆÃX'"?Št޳ ¬Ot‹qi¤sx[B‚/ñRy<Ü¢™/† ·µÃ"’*"] ™:ºÞÕ0HiiI+EIJ©kÒ¥76Çé·Œ8穬x炾Î{æpó°\ÉnZ›Ú×ÙbÙp–3uvW¸ŸŸ2zì˜û˜ÃÏ?¼óõ_|cœÌœÞÎo?s¬xÓ?ÜV|í§ÏäÏ^Ýë*ÎpJÊ.»ø£ß+ßþ«/%Šè‹SŒM?ºktæøÉM¦YNᢕáp:¯øÎïW•ýó[ßøí¦hmfEA¦—}ë·í£g,¶|""õ‡Þ4ƒ”ªêêê:l ƒA™8ýúñbÈ”Hç@h;~Œˆ”D:ÃÙ¤â릈CMŒtÿ¦D]*QÔÌçî1VˆHb¤sø7ehšù½ì^;<Äi‚ñòVà ¢ˆŠ•É|Z­ ò¦)ïÄÕë£v]è{§f•‹H¯zZäÌá†.~ñn¡«öX•hmùMöØE«“ìÊ1õ².¿¼íÅ{?º­þð¾ÅÒË[ø´H=ðØOROïyeïì[ï™épÅE¤Ù4%¿hÊŒë>·»ô±Ï‹n­uÒS_¾­éæß¿–æJL¢ÁÃNÉ5ï_:ùâkš_úÆ'6Õ½õ÷2» æ^´ã’{î_àŒKç‡TëÉªä ¥ý¶Á ™Ê¿\ñ.s€˜á˜+"€á.àÔ+›3h³LÝt“C?ôP ÒADD”–•š«QE‹¶e3#Í|ˆ*Ãpíp¶ˆDlmåp¦•/&&ó‰Rë¬'¤g•º“çÙ§?âª}ZVÂܾÎ;ÚÔ]¡E,סŽtM¥›O‰ˆe3_ÞÜ¥…vf‰KJI½ú]R½í¥ý¯~ï‹i¦ß7®·ó›ªöÍÞøý÷Ÿ˜ûoºÓ Š"òßwôìUóZjm¨Ùñ¬å:Rwרg¾ü¾²kþç±Q|—‚á'.5=ãªüuEÇ™šÚÍ÷¥âtéΙZë°O›4ŽÚ±‹/._òÑ»‹³r²î7ªu·´X7Ÿ+E3ß0¡”ØòÍ-ìb®`% C£´ÐDŸÔ‰‡¼Ó%JÞH®¯@ôQ3'-¸)íÈ·ZÃùšù0b õÃi"µkZ‡B‹t/{æL¤s UVVJ‰ˆõôq˯ŒÚµ×MÎÜ­Dúlòx±²™ÏA4íÛ`¹òÑp8j3FåÛGDdÜ’µ3ß÷·m]/û3NîÙ²\zY[ðöŒÝùÀÞqËÖíšò®ÍDÃÜ´«>¶ªãd妖“–߈4+/Ùþ›ï•-þÈ]¬íÁ°•œ;&ÿòïü>ßôû}G7>óÆáçên¬<ðz&ÉÐóåsºã«’ò ês§Íõ­}wnÎÔÙSE©ˆ¼…‹¯»£C›~Ëf>­ÍR»ó`°ÔÒH'@H-‹tbŸO(¤MµT¢ ™oÁ‚¹Dda¤s¸€¡=žE"òb8B£ÃlíðÙfÁˆQ"•Zëa?<\¬§ò‰ˆL»òæ vfé/‡R¾Å£SŠúqªñöâ°¦:j+ ¬Ž'ç?&"k¤qÄÅ'¾ë¬:µw{Ù _ûX’é÷örº»úõ'çרV¶è#÷Mt'§'Ø•ó¬yw|wɦ|èM_WÛ«ú'þ8uüâKjFÏZ4Æîl@(N§kò%ëN¾d½ˆˆx;ÛÛ<ÑppoK[Ã)¿·½EûÚ[•·³Óðvµ9 I~G|¼Ž‹OÒqi™Ú™¨ÜIÉFJîXwJN~RjÁø¬„ŒœlÃéœ""1½½áðþc"b¹¶ÞˆüTÐ7uÓMŽB‘i‘΀ÐP_ì,”4¾¾` ¬Ot‹1)Ò9\HkC&ê{Îþ>kZ%B30rEó4Á!­V*&Vìš"×XUr¸ãËSòÆFesÅeãÓ¶J–÷u^‡×,ëðúgÛ‘i¸ñµ77iŸÏ²YsüÂU~»óXÉ›½¸ä}ßÖý½ÛxºtçÊÞÎíi©/Ùøƒ;êæßþÕÓgÛÚ„j8œ®¥ŸúŸ±›~ø¡ÓÚ4G[œâ|áksÝüàŸ;9••æˆԼ™‹Jòf.Št”¨W³wk³Õq-ÒÓØÑQnw ܤž QüÐ F6¦L‘¸Hç`8suÉ41bs;0ÜiC¦G:ƒˆˆCdú°ŸŒÄ,3ìC‘hæF¨pMœ1ã¦dŸ¯'¾'ÎêоdÑF¼2uª©Ì$Q*A‰‘*¦>êçÚ-%%%ÇéP–]ùó.ª•(œ”¤DôÚñi–« 
Ï·­ºµ%Üy†«†½¯UˆÈb«ZáÊ+¢f½¥ÃŸpÅwÿ°ò膧wnøñ]c%`æ=Yrvýá^ÿ˜ÅWívåGlmât'§gÍß×Jw>øÕLqŸ_ø¼¹ÏÝsGÙúûŽŠwB°WÝÝ–?¯P¢h­vçÁÀ™¢˜ô C´VÓ­ÞÜ úÏp¨éÛ”–¨øy¦–èh*`%ü¿÷ ™@H•–>Ô!""Òé,áätª«D[¿kjê•ïͱ;O,+HÙá2”eÚùž­léWÓßHÔ°ëÅ«ãJ©æ¬‰ÅíÎÓ—‰«®Z0zÆÂ†'¿øž]] uó{9ÕY³ýéÙ-ÕeÛ}øûs §Û¶wÙ§Oš5£`þeNîz~•U½ñè’²>x¨äšÛYãŒ0-ÕGÓ, Z•Úƒ¤•L~8 3 Qùò€¡ÑZEÝï¼­`„‹ã«ª^µü} }ÔDá«@´šî0¾CË:Ëã£.wÆ‚¨l8Z?!#©?ç™Z×U4t…;ÏpÕvì@Õñ¤ìÑÊ0¢òójbVNö{~÷ʼ’koß(}Läì8ulñ†ÿ¾ýxgCÍ ›â‰ˆHñº¯ŒÏÈݬþÆoï+訫mµ3€ÓZ{;Û -KJhæ.´ð€¢µDÍV†-ÅçS ªÅ§Žt^'€h¥Er”úzXû¢²é¢™R*^D.µªeN+W*úºfJ|3ÑeÌèϹUÍžÃZ„)¼ÍõõÚç±ì´»ä’¯­)¥Ô¢;¾¼òêïÿ¹Úp:«{;5à뙲õgŸN=Sºy¿]ñD)µäã?œª G­UYk3å¹{>pBØ=ŒígNž­-'ó9´²ïõ C¤hæˆ-|}ÀPñæG º©@Ä›ù´Ò¼NQJ‰8'ÍÚ›ÎgD]à D»ŒÔÔKD$Ùª6åŠ÷¸mŽÓ/7MÉ ô÷Ü*[\áÌ2œÕíx¾\‚4:­½Örb_´UˆJ,§ò‰ˆžzÅML‹´[¦dUʦ°>WÞlÙ¬‘†íOV›|ñ5ìÌjIÙy¹7?¸qjfá”^pí÷tMßü£xÛO­ w¦´qÅÓ²‹æ¿¬¾éþ{’ž3Ü9DFÀÛÓmz}ã­jZ Í|ÃK¬5ó®*÷ÃH‡@d;ðè)%ro¤sYÚ顤D«rÿãAæ@xÄÖçSQúÞªÒ§OG:"ãØt÷DTE¤s„’©ˆþLÓápÅÖk„ˆh‘ÏŸÞû|g¤s 2œß"[C…áýÝÍ|0JY6ó9%fç¶;No½ebz\¿§ò™ZNU6u³6,ˆ†½-¿pv8]ÇRFØ'Ô.—ûšÿylyÑšk·J/ûµ(ØþË/d6Uí{3ͬ›¿´Èp8, ðy ¶üüë¬ÚbTCEi•ˆ8¬jSõ:I³g´¾—fòÌpOD:[+ÔLç“‘Žy´¨˜ú|êrÅñùtÓ=1ŸŽtŽP2 Ñ¿£€/¦^#D¤õxÙ¼ ‘È©¨x¦M”¼é¡Þ× šù Ÿ²³ DdžUmÌüglŽÓ«@@¯”Ù­”Jêï5•Ý•:Ö~ "Ú4=µÓ¬j™“gÄλ·•’åŸûöÒžk‡ ¾6ZëÌÝ¿ÿ¯ñ ‡wì g×PòîÏÕ«W¾ödqÇéš¶pf'÷nk R2¸8šùAŠw™p•{«Ó"þHç¡¡»›¢êgz CÊ=7D:"K)©tD/-RϤ!ZŸŠt„á„f>è'ÓïX'AšÝ¦^ý¾<›ãôª¾©»|EaêÔ\óâ‘–˜Ù*mGö­Ó­j“×\í²;O¸M¿æ¶ÅW|ó7•J½}žòæ_¾3ãä®¶‡3KîÌó³ ,×í*­ã^þîgþò€ýΔî´üÆ^‹”×××·Û8K)3èôZŒª—)Æ`xIJÅ/-¢µÖ|Ÿ<â)>”áknˆÅÇÁÐÌý¤•ºÚê¸2œµ£¦Ìšbwž`ÚÚ½[–¤¸JÀeÞ ÇZJÂj˜;½åñ`Óá–]> ¦Éá"oÎÒïþÅS݆;îh/§Å|âçßöÄÖpfYðþoNQ–Í;MGÍ=½ÿãá|>û5;’f]Ñ»íM؇f>è‡Ñ£G'i-k¬jYE3*E©¨XOë7õ™¦6_Éûæär]«Ç_Öí3“Õk¸k*ÛšeuÜ•X—šžaw»¤Œ{óï_Ít'§îëå4Wų¿[tôµ¿mWwjfNÞœ‹ƒ­ôU¯Þ÷…á@ìÐZ{:š'X•”R4ó fÑÌýàñt]ªDâ­jÅWÞœhwž`N×õËNryr’\órÝÖímáÊ4Üù;Z›ÝíÓ­js—Ÿ±;ÝâRÒÒoþý«“²Fíìå4ÇÑWÿ¶ìØÆG¶„+Ç´«?¾DŽ«ZOsãÔ#ž>®g°WKMUµhl2_o¯EÀ°F3ô‡–uA*c–¬™ak– Ú;}Ûüsñór‰ˆs ×>[Þ4&L±†½3Ûž< "«Ú´+nʶ9ND8ââoüõ‹³ÒòÇ÷¶Nר|ùOKŽo~üõ°dp¹ã‹Þu{u°úÖŸÃ¥M3`³“»7Õ)i¥\oÚ°Í|Ð¥”!"WYÕ’sÆ”:Üq–ûìк½±ÅWh(#°tljÑ@®5µÔmê™®lÃÝ™íÏXWJ5žµ¨Øæ8ãp¹ÜïþÅ“‹²&Lëmúž£âÅß/ªÝóòáÈ0nÉú¥ŽøÄR«š¯§sbù ô¶À0Qûæ6OÒ‘¦¦¦V[Ã6¢™ú™šºHDr­j“Ö^ã³9Ž¥ºÏ=úÒIé;†Êȵ•Ý•áÊ5ÜiÓ t>fÙ°—>nÒAe8,'öÅ*e8ëïdÙè™ 6ôršóÀãÿSÒTùæþÐPjÖ 
_:~oÇo¾ŸÌt>`øk¬,Kµ:®Dí²; `'šù ¥×)™“Ö\?ÕÖ0zzÌR7°\Dä½³²Ü\ö|ESÄ' F«¶C;ˆÖ™Vµ¢5×)»óD¥Ôßùê¼Y‹{ièS‰»ÿøõñíµå¡~|VÑüY=V5¿§{bÅËõ3ØHkÝÓÚ4ÁªdЦ™1f>èƒYguÜœZš˜™=Êî<çñžiìN#7ÙU“•èš;Ðë7o+ G°XP»ù±¦ ¥À¤µë§Ù&Ê\þ­ß®Ìš4}Sð3têŽ_ß•äéhiõ³g\÷yg°ÚÎßýÐêç°Okíñ­u†u•f>Ä6šù ãEd†UmÌ¢‹›mŽs†ÏëZd’ˆÈGäUŠÈ€&óµtûK»ýfRXÂÅ€–CÛ-×+»“SħdXNì1”Rë~ô·eiùã·;E›‚×ö©S¦ßç å£3'Ï™éNJßmUót´Uïx5äØãäî-5AJZkãM[Ã6£™zcšÁVìÊ´«ncg”󺶽÷PDÄ¡”aArñ@ï±µ¦½=ôÉbƒ·åL]ÀÛc¹FyÜÂÕvç‰FÊp8®ùéãó3FY6Ö‰ˆº;g¾ñë/¿êgϸîsA'ðmûå7»Cý<ö¨}óõ`Í¿ÇZ[[ƒMKbÍ|ÐeZ®Ø5Ι‹'Ùç\µ ÝÕJ©$‘Ë‹ÒwJYN‘ëÍsåcCŸ,6œÚòäaQVµiW¼'Ïæ8QËáv»¯à¹iî”´ýÁÎi?}tyÅ‹ö²’wà2'Ï™éJLÝkUë¬?3»ñHÙÉP>€=ËKS‚”X± €˜G3‘•••*¢Véì IDATYÕr¦Í=jwžsuuùwüzÉÙ¾qf΀Ö늈´>q´Ù31´ÉbGÝö'ã- £nÔ´ÙSlŽÕœññ‰7üêÙ±†Óu,Ø9Ç6?¶°¥æÐáP>wòÚÛz‚ÕvüæûÁVuˆbÝ­¬Ž+šù0ÐÌAÞËEÄmU›rõ{SmŽs.o}³7ãì?d&¸ê³œsz“ò†žªÐÆŠOWgOSÝL«ZÖ„é‡E)ˉ}#Y\JZúú?dŠR­Vu%¿ëw_q™¾îŽP=3ÞÚE†Ãyܪvºl× og{Ðf?ѧýtMÖ:ÓªÐjÝy»ÑÌA(¥,WìŠR­c¬žasœ·5´x^עߞ\ôÞÙ£ˆˆs ÷y¾¢)!¤ÁbÈ™mOïW"–“ù¦¯»Õzb$£pÊÄU_¸¯RDVuðOÜù»ÿ YCŽRJY|¥e3Ÿh´÷ï¿z3TÏ~µon:QSkÍd>Ä<šùÀ‚RÊ)Z®°ª¥0œN—Ý™DD]ßÑå_pî±K Óóq+ÏëÕ­ÓC+æœÞô¨e3šˆê(\q¹åÄ>¼eâª+æOXqù¦`õ¶ÚÊåu·†¬¡oò%·Î¥Z¬j‡žùkn¨ž üNîz=È4M]ÝÖÖÖ`oÀ~4ó€…ôôäe"’eU›¼æ:msœ·nê9,Z’ÏþsINÒA·SMè}Zºý¥Ý~Út±Áôû|]gª-öÒÆîwºÜLæëê/Þ·2!-+ØTŒ4óÀ9J])"N«ZñU·dØGZzÞТ/h"»tbZî`î×Ôí/õøL&óYh)ßu@c­jW_ÅŠÇ!0œN×¢ùt°úá'ÙŠçä͹dŽˆj·ªzê/Ö+<D\íþí"Ⲫ†I3F šùàJÉÕÖÇUãè9Km]³0͆ΞÀ‚ófÄwÓsÏ-ÕíCO›jžÿ}C’.Ywûd[ÃÄ â+o^ìJH½°³ádõPŸáp¹ã³ó-Wú¶œ¬* ø|Þ¡>@èØöZSR{ssç[ÃD3ü‹RÊ­E_fUKWtH‡yê½DKòùÇo™•3覧çÊ›Æ-UlÒÚ4[+÷Lµª¹““ËRòÆØ)æ(¥æÞúéæ UGÅ‹8ŠÇ]µ«uRõŽWYµ D¡Óû¶Ç)½¡µØˆ šùà_ÒÒÒV‰HšU­è²l}½ôù'¼ÞÀR«Ú‚‚dËU°}ÞÓÔ'ª[{ ‡,F5—nÙ¯Ms´UmÒÅ×6Ú'VM_wëbÃá#ÎÅ3EÄcU«|áQ˼"«íôÉBË‚R¯Û›ˆ,šùà_” X®ØÏ„UëfÚ™¥®Ù{\‹¸Î?>-;ñ°ËP“sÏÃõ]!™|‹N¼ð`kR`ÖõÔJc\HÇÄUW±ªi3PÐxtï'箸$wRšåZÎÓ¥;™°D™®†SuÚôçYÕ”Ò;ìÎDÍ|p–Vë¬'df—º“.Xw.^ŸYíóšK¬jï™9êÔ`ïû\esÒàSÅ.ø;Ž••XÕÒ2÷'fåæÚ)–͹åSÓDÄ´ªUmx¨)Ï5mQ›Õq¿·gRG}mC(ž 4NìÜ´ÑÜçÓÛíÌDš3ÒÄuÓMŽI‡ýY~S²Z²LíÏÒ†Ê2D²L­F)­²DI–v>lïU‘Î{¾ÌÌä™"Æ«Zá²+:íÌÒÐÜsBDÆYÕfç%Mäm»·Õ´O|ªØÕ°ë¥7µÖ ¬jEﺞµ¬!–’[0:.%óMO{Óœókm5å–÷U°à²ü“»^´¬Um|®bæõÊÅs ]õö =AJGÚÛÛël M]|±³°>íîH玕Íý–Ö÷Z¾¡`8 ™€¥ñ³®Î€‘ïTÎ -f†ˆdh¥3DŒ 1u†*Oi¯E2D$£P$' âPò¯‘[Ê¥E´ˆ(¥ßº±Ïõ‰ˆüúb–SùDD¦\uó`èÌçÓǼ>½Øª6;7©Ìe(Ë r}iìö—z|æÂ¡¥‹M5/ÿ9X#‰¯äÝïÔ¿oônâê+[>ù§ 
Ž›¡§µîT\ZŽåÊÍþJÍŸ\¤ Ç)m.¸Ïñ-Ïûf^ÿ¡¡Ü@ÕÞŸnYPz›ÍQÀŒñŒr‰x¿éÀüùµß• Ûa†šù€7vÙM ªÉ“¡:ã‚Æ<­ò´H¾ˆÎ0DeœmÌ‘\Cœ†(óÜïw´ýVwžÖ¢G'ŽÆüÁBo½ÕA§;¡¯Ï¬òúÌ%V5%bÎÈMš<Ø{?_ÙR8è`1¬ndzoжž„X|Å-Q÷±+2 §L0Î3à¿ ±®ñÐsÂʇüŒQ“ç:ÏT_pÜôz&t55´$ff[O`›Ó¥;‹¨Vµ€Èv»ó‘F3!ïhÌÓFž6$ÿÜ5¶ïlÌÓy"ªÀ)wacÞ¿wÙÚܘ7ª>Ò ¬¸\ÆU¢Å°ªM»êÖ»r46ûjDô«Úü‚äRC‰å¹¾øL}ìD«§pHábTÍ‹±µ¯Ez¦^õÞAýûFÿ$ç«n«9zA3_{cmF(î?ªxÙèc[þiUR'wo¿>åñù-'ĉˆÜ8#»e°÷>Xßu\‚¬îÉ|]m-=õÕó¬j©ycßt'%[NIDhäNŸëk«9zÁñ@O×xÑZ‹RCzáK3e²ˆjÑ©ç×NìÜÔV´öº¡Ü@œÜ½%Á² e—Öšé¨qhæÎ3~ÖÕ0òÿ½ÆöœÆ ˜¹Ýu Y9– ¯jËs5"bÕ̧Ùdw ÐÌÛ½£1OyÚü·×ؾ³1/O‹ä‹H–SÄ}acž~» „ƼaBG×d>¥”JOK¹Úª–”=ºÔáŽ{œiêŽ.O`îÝLç(ÎIœ8Øûo®n IST¬©úç/Ž‹H¡UmþmŸÍ¶7ÍÈätÇÇ‹Rm¢uêùµž¶Fo(ž‘Q8#èšìúŠý'Æe­¡™ˆší¯Y×"‡ÛÛÛëìMDšù0$c—Ý”`t˜yNÓÌ{íùyJò”–|ykb^¶SÄõvcÞÙ¦³kliÌ‹iZë¨úå|ZZÚ<³ÀªV¸z]šÛ¼»”Ö«‚Õ§å$r(5m°÷¾¢Ùr…áH¦MÓl9´}ªUÍáŽ;2zÆ‚évg± Õ# ›ù|­PÜ>9·p¼ˆxDä‚•ÙgÊvµŒ[²&0ÍÕ•c­Ž"íÎD šùðg^“+™*†ÊÑJRÚÌÒ"YZT–R’%Z²´H–%"ig?€Þ±ÄöüÆ<ºòð/JIc¤3œK)½>ÈǧžzÙ E6D0;ºü½6Û];5û´ˆ ª™Ïgê£5­žAOõ‹U »_Ú­MsUmâª+kDd’Í‘F,C”ß´®ô6¬²ß”áp:\ñG¾ž š7ëË÷[?@Øu5Õ5˜~_¡eQ+šù0bÑ̇w8ºÿŸgDäŒå*\Ñ¢Œ e½ 7[D\Ž(§µœ‰t†wÐæÕggAžËt01;/ìÓÙÚ;};´–%½³  )°÷/«ë""ƒ|Dçö“í%ƒ¼6fÕ¼ôà^YaU›ùîølŽ3ll{à;=õ—™ZëtwbrÙ»¾ñ€kÔÔÙS†rOÓï÷‰hË†Öø´¬½Î'åŒq4”ï¸àxÀïË­MQÊÕ³ôOcÅ~ëfZ%LåÀˆF3"ê¼À~@`†'¢Æ €Ñ ½ªêÕžH‡8˵Ϊ“Ïp8jGÍô4¼þjnñ5õuÎÊñ©–SËú£®ÓWæóëEƒ½>VÕn|4˲ TëÔ+n¶œØ7Ò½ù·_n>øäŸWýgoWGÉS_|o×EŸþæö)ïº~qo×jÓ4wüî¾ÍG^þgš(‡ž²öš¶y·n©átºÚN?)"…V×%g ÙTÉäÜ VÇ•H|WSý©Ä¬œ°7ïø7¿§»ÓÛÝiÙ ¬µl´;MhæÃ°cS`ˆÄ…í1"é¨Y±›“““,"[Õ²ŠfUŠRùá|~ `žöù½6Ú¥ºÍiq΃}ÆæãíQÓ8-ÚŽî/ôtM·ªåL™¹Ï鎳œØ7Ò•ýóA‹¦R•¸å§_]ÐvâÈÆw|ie°k_úú'6ÕìÞüv#àþüŸ|úOU}ê›@@Kf¾Äì¼­ÀMÍ›¬ÖZ{¼žf>À^'vn*ËæiÓd2F6šù0"„¤PÌ •§•äÓ8*jšù|¾žK•H¼Umê7'…ûù -ÞÃZdUoç\6%ó€ˆ\4ØgrªÝ;i°×Ç¢€§«³½ªÌr”3>þPÞÌE%vg.Ò §6ÖØ´~bÇk«Û~þ‰ kn»ò ª mmg‚5ÍŠˆhÓ?Úê¸+1µR)5{ÐÏc¸â’ä­×ò Öív5Õ1Á°Ù©ýokžÞ¤µÖ¶†¢ Í|@ÑœÒF]¤3ˆˆ(¥Œô´”«¬j)ycJ.÷’p>¿¹µg§h º–TD$Îaôä$9ûlø ¦¬®³FDhæ;GÍóîÑ–kt§_uKˆL³9RÔó´µ4?{Ïû4«ìsý𣯼1þ?o»òDz2Rrãd ` ä¹iE-ŒÚ'e­Ú4/hæënª÷…úY‚ó÷ôty;Ú¬_oµÚhs êÐÌD˜M €ù"’®?C¿¨è˜Ì—™’²ØɱªMºøš@˜¯Û»cú:iv^âá¡L&{®¼)u°×ƪÚYÛª:gÞø1ˉ}#YcåþЧ¾ø¾x3àïתg_À¼àcN)¥¾ýÉÏÜõÓ¿§‹H¿×W'dZìÃáj ˜ž nÜÙÔî¿óÎQýÆ«DdeÑÐÌ€f>` ] 
dh‘<Ñ’oј'oýÿÐZEE3_Àõ*Hiҥׇu:[G§o§Ö²°¯ó–Oôd2­u×µ¬Œ=GÛ‘}‡=]ÅVµQ“§ïq'%/·;S4«;°çð3wÝ–«µîwðø¼ÌC"²ìüãk•Ì{øûŸ®þÌ÷þXzª±eqîubûÓ+ìÙ4ïö¯MˆOËé³ùµ? W\WÀç¹à¸§­ÑÅýôÏ‘WŸìRjiiéÜok ÑÌŒ65žm¼€%“ù”–ubÑÍçNN=ž5èÕ¶ý [½)ý9qÖèäÄÁ>¤®Ó¿ÏÐa]<ÜT=vÐ5º îøR–Íq¢ZëÉc'ž¹ë¶œ4ò%Ä»ýï]š¬^˜7jÜ?ùÏq/m/Û}ï¦y½þ¾V@«®†Ú›ôÑμY«_›¶þ“K.w|¿ÿ\q‰=¾®¶ Žûº:ÝC¹/€9S¶ËzJª’ÍZkÓæ8@Ô¡™@PƒmŒï dùf¶aêlÓÐÙZT–a¶†/iÿ¤§§OTJ,§ÖYtIS8ŸÝÖéݪõ…“ˬdÄ;újv êå£ÍþÁ^‹ž®ÎöãçYÕœñ‰‡FÏX`9±o$ò{==O|ö†n­õØÞÎs:'æM+» ƒI§ö½¶úté¦E—ðè¸ÅW¯Ðè\‡e“¯»;nÐ÷0 ¾îŽ_W‡ecµÖ²Éî<@4¢™@Hý«°æ_ÿU”ÒëD[צ]õÞ¬ó ÂÛÜâ·œXx¾—Ña(•9ÈçøŸ­h¦9í5Ïÿ~·ˆ^aU›¾îÖz 2±o$Úøƒ;·ù=Ý«ƒÕ eœ¹óö+ª®_³h±Rª×†¿`Ãù…Û¯\õÁkW5}áGÞTz¤f™ˆ8z»F›±åÏüfì± ìž{ÛW’Sò&OèsUf¾@Àdë6€P«ÞöêAëUó­6ÚˆJ}NÒ€˜¡eÕaÃé<ž9±xÐÓðúRßÔóº=¡?çf'ºûœÆnÿž–n?kcÏQûÚã, Ju̺ñÃsm޵ZO;q|ëËKƒÕ“ÜeÏþìN× k/QJ ¹.35)óÿ¾öÑÿwïGޤ%'îíÏ5ÞΖyÛùʼn;¸ó%_w{Ã@žg(§e¯6ý¬õlräÕ';ƒ”:›ÚÚvÙˆR4ó233ÓDôJ«ZδyÇÃõ\׬èìôk½®ˆˆdv`ßþQÚÀŠÝs4ܺ/àí±œ¼7ªhÆWBr²Ý™¢Õ¦ŸÜS%"–+gã\Îʧî¿s|fjÒ`'F5còØ)/ýâ®Ù_¸íŠ­‡q²—8ÛNV¬Ýð½ÛŠç~÷š6ýú˜W.—õß+­™ÌؤþО«ãZd£ÖÚgw ÑÌ`D0Mßå"ⲪM]÷ÞÔpÝÐw¯éòé„Á<+`êãO•7Y®/©Ž<ô“`S dá_¶žØ7ùº;:ê½äcGwýñ›Ÿt%%Ä…µññæw-]ºñ×÷d®]Tòšˆîêó­3o}bõkß¾¥ªîàÖ=}nÖ_î(­ø:°§³½ÍÛÝ5Õª¦”¼dw ZñKl#‚RÖ+vE©Ö‚y«JÂðH}êL÷>­År2\0]¾­uÐ&´`þZZ" µs ×Ūž¦Sµ=õÕ– jîäÔÒÜésôß%–•?÷È^±l"]³hÆŽ ÙãíÈáv¹¾ûé÷¬~âÇ_h)ÌËÞ*"}N© ø=Eûþö½¹[~ü±×»›NU=1Ø6]¥ù:°AõÖ—‰ˆÃºj¾lk ŠñKl1O)å­®°ª¥Œ/‹wûfe%¾9)-þ C©¬ú;ÝгÑë׋zQ ]þ¹¦¥Û·û¡ýõý^å;yøGå"bÙÜ8ç=o³9NT«xåŸA*ºëî;Öϱ5Œˆäe§ç?üýÏ,ýŸ;o;”ç.ïÏ5Ý-g–m¾ÿ9oþé›/š¾îŽóë>O&¢ #û„Ô‘WŸìRªkiéÜgk ŠñKl1/==y¹ˆdZÕV_y]Ë}Ëǵ}dVΜÏÍ]üãÕã›?9+wCArÜÑÁGÛÉã¹VÇ ó³ßLMLH·;ÏYKg¿úë{&`ÝÊMJ©Æ¾ÎW"ñ ».}õ;ïk?¶ùñÍ¢õÛ“ýtÀoý÷ÃáIó.€ÞÕÞ7Úê¸õŠ>çï*0ÒÑü ö™ÊzÅ®ˆïÇw}fŽ¡äí&%’Sœ•°ê®…y¿·|Üþ+ Ó6ÇF¿ÖÞÒtòT÷Ï\1”¸/mYXÓæ}½Ï発ösÏikèôY6cT'^øãNÑÚ²ysâò+N§ËîLQKkðy,×讞7-`wœóJŸºi튗~ùÿœ‹K&n‘>›ï´È«|ñ÷Ë_ûï[÷7+Û/"bz=–“ùⓃM "]- M~Ow‘uU³b8Í|bŸËf¾´ŒÌý™™™ùÁ.Kt3¯˜±ü«Æé»æ½:=3é ±ÏNè IDATœ ÔÖáÛqâT‡øM’©oŸ~ªrAyC÷Æ eES÷¦Û-O<ÚÔ3)Ï‹%µ/ÿÙ²¹Q‹ô,úÈ—KìÎÍ<í­"bÙÜ8}ÒØT›ã•šŸö³»>°ü¯ßþÔÉQé)»ûs¿§kÖ®ÿ»§dǯþs“ßÓ™luN|J:“ù€0«zõ©ƒô{ƒf>àÎH€pJKK+2”XNš¸lmƒ¼Õœ§ú¸Mr^rÜÅŸ˜=J:ûèöÓ'ž>Ú:½µÇŸØÚáÛÝÒáÍ-‹Dõu›þó›Úý…玮œ–xøº’ìSãÒâÜíÞ€ÿ`}§.1Újý»—–ñéƒ@¬«|å‰`kt´´´TÙˆr4óˆi‘Ë‚ud®¸!¡±Ë·-+ѵ´ß÷Sjâ²¼”‰ËòR|õ]¾]ÝW÷J‡¯(`=°oÈ5tMýΆê©a¹y 
:úðý=Áj‹?vw–Y†§+>.X­Çë÷Û™e Ö­˜»èŠe³|?ùËózqÛ\­%e0÷IÌÊ1C À;5W™hu\‰¼dw Ú±f@LÓ¢/µ:n¸\Užä±ËÿãÉ#%UÍ=›Ed ë6]£]K>³¤`Ñ?nÞøÕÕã6ŒKsr` ZwCMMOãÉV5wrê¾Ñ3Û)Ú¹“S’‚ÕÎ4´tÙ™e œ‡ë‹·]¹òéûïìž2vôfpc^jþX<0j:rðˆ6ý–ëìMѬØÎC3€˜¥”r‰’ÕVµ”Â’jQ¾@êgž>²ü–‡w>RV¿Éã7Ëü‘Ü…cRVý|]QáÃ7|ÿÜœM NÕ9Ôü˜Ê?~爈8¬jóný¨nL‹G\|¢RªÉª¶ëбañ1<*#%çÏßùäò_þ¿–''Æ• äÚ£¯>¥½ímáÊŒtå/ÿ£&HÉ4Mõª­a€a€f>1+##e±ˆ¤ZÕr—¬{ÇD²v?ý{êVÜð·ƒSîzáØÁò†îMZKÇ@Ÿï4Šo(µâïï™®võ¤-ËǧîV¦¼x›¿£µ¹õÈË©|ÊpÖN½âæùvg.ÜÉi'¬Ž¿q *Ñî,C1ú„i¯üòî韼ñÒ-†2ÎôçšúŠÒÅyï2ÏÎ?üäemšpgFšê×_ ò:¢÷¶µµ5Ø›ˆ~4óˆYZ«w)y²æ\tåjY]gñž;ºâÆ¿t>°óôÖ–nÿn`CžR’<>=þ¢/¯;ïÑ[¦W}vIþ†ÌWý@îþ;òÈöŠˆåÊØÉkÖU‡ËæHÃFæ„©-VÇÏ4¶Ìðú|Ývç ¥”úàú½úÀÝÉËçN}MD<}]£µ9jÿ#¿^ó§›®Ùµe_øS#CÀçóv6ÖO·ª)Q/Ùhæ³´Ök­Ž»’38â,¿Îå ˜ñOj\zÛ£‡ç}êÉŠã›·n˜ºv 9\†š¸vrƪ?\?%ó×OݽnZÖV‡RþÞÖL_OOîWfX•j_tÇ]slŽ4¬L¹üÆ T⟞ݺÛÞ4¡‘ïNúñÞºú‘û>Û4vtö®þ\ãïîšþâ×>:ó±ÿX¿¥³þÔépgb]íîMD´åëK@«—íÎ 4óˆIJ)—™gUË,YÖ6ÐûU·z ¿·©fÕu=˜sߦš­g:|[DÄ7ÀÛ8Òœó>º`ôÒÇn)nº÷âñÆ¥Çh¼Óñg~û†Öf¶U­`βÝî¤ä4»3 'ã—¬™)¢:­jxrc¦ÝyBiü謼ÇîûÌüÿþô{v»]ÎÊ~\¢Z¹è¡­MÞôÃ/½ì÷y{ˆQ‡_ø‡åÔOñ$''o²5 0LÐÌ &¥§'•ˆHœU-{Þѽ¯©µsãñÖ¥~¼ü¢†R9 ’WýüêIã¾¹øàûçælJpZ7T!8mšS¯üml²wÙ§ï-²5Ð0äp¹ã²&—XNàëêñÿãÕÛíÎjk•ÌÛøë¯Þrù²×•RÁŒÎ•\ùÚÓkþ|Ó¢ÓÕÛ^z3ìtfÿöœ ¥­'Ožì²5 0LÐÌ F–SùD$V4R(žÐØåËy`Ç©•×ÿíÀ”»^8v°¼¡{“ÖzÀ yñN£ø†’Q+þþžéæÏ®ž´eùøÔa¹Ú4NmùÇ3à/´ªeN}#yTA¾Í‘†¥EwÜi9ÙPDä>ëñú‡ý„:‡Ãp~þÖË—½ð¿_2gÝ("¾®1ý¾Â—¿ýÙÙO|þ=›¼ížè ŒT=íÍMÞîÎi–E­Y± A3€˜¤´[wÅUîø¤P?¯¬®³ø Ï]qÃC‡Ìßí>óJ»7ð†ˆèÜC)IŸÑ—WŒ÷è{§WýÇÒüÍ ÎÆPg%ÇŸøEj’¾è³ß̵5Ì06zƂ⸔ Ë t^ ð‹?ù˰ŸÎwVzrRæï¾ú‘•~ãcÕ©Iûúq‰j¬,]ñ×[/j=½Çþ°bÀÑמ>,Á¾ÏPÆêôôÔ;Óӓפ¥¥ ëUÞ@¨ÑÌ F)ËÕ«q)™amŽóúÌ”h¸ä–‡-üìÓ•U{NulÐZêz·CM¸lRÆò¯Ÿšò«kŠ¶æ§ÆGÞᬩtÓÞ@wg‰U-){ôÎìÉ%“íÎ4œ­þÒ}î`µmû+—?¿uÿN;ó„[ñ„‚ /üü˳î¼ýÊm‡QÓ×ùf 0öÙ»?8uûo¾·ÑŽ|ÀpVùÊ“þàU½F‰|_‰ñ’¡tcFzjUFzêci)_ÉLK»jTRRž}I€èB3€˜dŠgu<~TmëB6{&~õåã«®ÿëôŸo?µ½¡Ë÷†ˆôÒà`ÉŸâ^úËõ“Sr’ݵáÈ9\Uüý>o°Ú’O|%hc¬åÏY:=%oì¶ eÇýâ‘¢Êê3Gm eƒ›.]²dã¯ïÉZ½púk"âéãt÷>¸ò¥¯}|ƒh= É›ÀHÒTuh ëì EäÝ¢Ô7µÒOù]ŽÚôôÔSéé©O§§§}+==åúôôô aŠ DšùÄ$%’ou<>w¼ÝQÄgj×sM‹?øXùÂ>VÞôrUËk>SÈ=”Húÿgï¾ãÛªïý¿¿G²<å¯ÄN<²ãìEvXeC€² J––ÒÒö¶·ŒRÚKKm/m™—ZV€,f M¦g†wlÇÛ’,YÖ<ßß–‚䨱}Ž$¿ŸG<Ð÷HçåĶlùãïùÁÂìÊaJ ;öÚ²#^Kû¼@k¦Øøò± NŸ©uS$8ÿ¡!,Ö¤”I×ÝûX\UCk•Ö]ÃÍûÇ;¯^ùÊC?l•d>p²ãöm]ñÖϾµUªªªEQ8é8ZzTú}ŸƒûKY8_@þJ@¼. 
Ö¤$'v¥$'~”šœøÇÔ$óÕÉÉÉ$"""""""""""" æ#¢È$pg¶¨˜8]?ïu8½ÿ»½iåå/UŒ{`s]I­Õµ @Oî›oânsŸ;²ú×"ÐÚœî²kœ1âÒÒ3æ^WE°uU•Y×Ýû˜9wè€Âœôqïþýg³|ݹ;ƒhíëØö#—ô›ïoÕª(\zûåáÚE6Àø©âeµ*%9±-9ÙüRJJâM YÃt^"""""""""""""Íp˜ˆ"“„1ÐÍÆ¸Äø¼'±÷xÏŒ;ߪ^zÕšCêk¥í[^µ¬¯û”¶ö½¬ìHâ¨?TéêhZhM1FÕM>ÿªùZ7E’éWܲ$}Ò¬Oƒ­«~™yÝ}%”T6Ö²KKמ»xñ¦¿ÿ<*#-iO_Ç5ìÛº¢lݳ;´ê" Ç>û8UÃÓ¥ ˆk ñ¬×¨4¤$'¾‘œœp–"à°7Q¨ ‰¡"¢ap˜Ïçê ¹Ëb:=ªù…â¶eW½z¨è®w««¶8>Q¥ìüò1v·¯äÿö¶.Ы1”]ýëÙ•oÆ•·6 Å`Ð8)âœ÷Ðó§™K‚­«~™ñßü߸ÕonÝ®e—–’âSßúËOæ]¸|ö'd°ãö<÷ç"{KãpíDFVÜv›Õã°M ´–™–´gú„ÜOb£ËôÃé.P>HN2'''œ9 ç """"""""""""V‡]ˆˆ"@'N\’ï?xºZü:´ô[u§kü½Ö7*³,/qOQF¼³²Óó~•už*Õ?¤ÖÓp´ÚÙv,ð®|CÃÌoÞÊÇ!`ˆŠ2}ó©wÇ­¹é¬JŸ»wBÃbÿ¾æƒ%[‹l~â—7-3 QšFj@!îÿî¥+Æf¥îxlÍG èë&)7ÝwKÅ7Ÿzo´ö…D¡¥jó† ‹­ýê–KbOŸ0ü~Õ_ÕØR½»´¦sï¡õpms‚ÅÞ3VJ™8D)Ó”“““ž0›»ÒÐ ‡cxˆˆˆˆˆˆˆˆˆˆˆˆhÈq˜ˆ"U#€ñ_½ÑÙq<,v$õ©Ò´¥Æ6KMï”rdõ}­ ­]~seÌÕ8)bE›““¾ù仞5·ž]£z¼ÁŽ+>r쌳¿÷Pù“¿º%v⸬ Ç…³›.Z±XU±í‰×?ZhÝÞܰ°¹xwyöÌ…Ó´n# %•® ¸‹¥²pZá”/þß`P “Æ.œ4ntá žø°’R¢®¹³}oEõ±%Õ½‡jã:­Žqª”i§Ú# owØÍSÓÓÓ/loo·Ÿêãi…Ã|D™]ÓÙX™©} …žãUõΖú`»ò5ͺúŽ€ktêbÓÒÓ¯|f³ñÛÎ;äu:¦;ÎÑë™ví=¹V­œ»å¿o¾x¹"DÄí"ùKV,ÝYRõiñÑúå–ÅŽ'~k»üñ75ï" RUUk}Õ¤@kyÙ£%àŽ}_B ô¨ôüÑ£Ò¯8ëߟΛ;­Í{Êkv<ê,­nŒk·ØÇªªš5€´å^¯{Ó˜1cÎjjjrà~DDDDDDD49ðä•ãv½;ˆöíÒWh"""":óQD’%"Àí~Oﯳۗ˜¬y Ê‘Õ÷7hmÚÅ7T¢¢Æhœ4"Ä&§¦\ýüfÓÚï]¸§§³m~°ã³áã}§oÞSQúØ/oŒ<.ûk;c†»Ç~ñí…Ëoyð¸_•_»¤nwcÍõ°ÈÙcÀwO9„ˆˆˆˆˆˆˆú$׬ñxRï""""¢p—›$"(!Ô‚,)ÇÞ~¶XÓ4gK]½³©&àÎ{B14ϺþÎZ7$ÆØøø+žýhÞ„3WíÐço5Ú{z§_Ïãã~øð ¹=>‡F‰š0E£¿}ñòê`ËU­/×4ˆ(„~ç•ö Kþ‹—Ï º³ç©HOK̼î¼%‹_øÍmËv>ûë´ûn½ô³ä„¸ƒ'¹Û-))æË†²ƒˆˆˆˆˆˆˆˆˆˆˆˆh¨q˜ˆ"’ÕÚS 5ÐZË¶× ¤ßïÓ8‰áð3÷4"Èn²S.¸¶ÒeŠÑ8iÄŠ"–Þõ»EgÜó×2”¶“µ«´êÌ·þÖºñ“ýŸi¨‘[.Y¹@O µc»>òhœC2ŽïÛ–èvs|ly’96e¸Îk0(Æ‹—ÏYðÁ㿘ùƒ+ÎÞ.„è z°¿Bpgr"""""""""""" Yæ#¢ˆ$¥”ØpÍïÏ­ßð·Z7Ñ©q¯®u6WŸpÑ ´Î½ñÇÜ•OCãž9óªg>”Ñ I'ÝáÒïWs|zý‚óîüãîêÆ¶*-ú†[”Ñ`ÊHM ¸Ÿµ¡Æ¬uQ(èílo÷ôöL´¶bΤàÃuCH!n¼dù’~sk—¢3Èa’“Í×kÑCDDDDDDDDDDDDt*8ÌGDK…òD°µÆÍ¯Îu6­Ñ²‡NÍ‘g~Õ ÀhmÊù×6š¢¹+ŸÆ¤Táó¸ãû{|‡Å¾ðê_þ=÷‡¿ð¡Óå ûKïdê t»×Õ§u Q(8òáGäûŠ+¿±0GË–)ù9ã¼ãÊz2È!·hÙCDDDDDDDDDDDD4æ#¢ˆeµZ÷Càí Ëñþø(·¥5à¥x)48Uõ¶ÖÜ•O(Æ–y7Þ½P릑ÎÚP]÷ú-çzü×øÞ5zWiÕY§ßö?޾µ}‡”2Ø MÈ3˜ŒÛ¥ä×U4"Õl~3àÀµAǧæç ôsÅ {ZÑœ±Y£v\”X”–—«qQ¿ð‡ÎDÑ„¿à ´&ý¾Ü½÷_îíi¬<¢qõÓágîí@çªi—|«’»òi«nÛû%ëï¸$IõûNyFUÕ¬G^Ý´xå­¿;ôÉÆ²O+ 
ÍÑn7˜¢îØGÉTŸÏkk®ŸhmjaŽn—×þÕwV%Y~¿ñ\Mcˆˆˆˆˆˆˆˆˆˆˆˆˆú‰Ã|DѺºìex غTý9þð팮²­ÅZvÑÉu×–vu4ÜyO1šæ\wåÓŠ”rÛßîûdË~wCkפ@k‰Ù¹ÝZ÷é­~÷G%2)ÐÚågÌ×íÒÓs§Œ›b4(Ö„3µî!""""""""""""êóQijØìØô)S*žøyQùc?ùPõyÜÚ•Q_ŽmZÀû´’›=ª>à‚*ghœBDDDDDDDDDDDDÔ/æ#¢ˆ'¥ôA.‡Dy‡,;ÏÚõ³sŽÙ*Ò,޲Ù[îîj›hM1ͺúûÜ•Oeo<»ýµ›Ï48»:f÷uœb0Ö-¾ó1%.-+2¦.™}ú/_›³àÜOÑŸÝê ‡êŽ/»ì§d_óËG·­o©’7`ˆU7´Ô¾úáîYA–=“¿qEÀûˆ"–”²ãpñø@K™iIeÑQ&Ývæ€I¹YÞ€ Ò4N!""""""""""""êóш`±X¬Øu²ãT¯kBé#ß›Tüû›¶zlíZ´Ñ×U®¾ßlmÖµ?hPŒÆ(-{FK}Õ±W®_~`ÏóZlº/DÅÄ—,¹ûéÄØ”ÌÑ_¾]1£&_pûŠ¥?}ÆmÎÊß@íÇ©U­K®»ç±¼Kïþß];<˜·c(5wXßpß“B¹œhZáÔÏbÓÒÓµî"ÒSËáâêß?:ÐÚ9§Mïպ竒ÍñwË”2Aã""""""""""""¢~á0III|»Ÿ‡+öÆÃË>ûÕE1•ÿ|ðcÕëvgý'Kùöww×Ü@kŠ)ªfÆå7qW¾aÒké°lºçæíëpIV¯µ³ÏÝø !#g벟>?1:>9èÀ_Œ95}á÷þ²tþ-¨3%¤ìgŠÒØÖuÚw|zòY·ÿ¾øÙŸnwy¼º}î­¨+¿ôîG„ÇçÏ rˆºì®ßh"Šd‡Ö¯n ²$¯øÆÂ šÆ ˆzÉp"""""""""""""=õ "Ò‚ê€àç^ËÆHbkªkØñ·{k[Ê÷ÍÄÊþÜG1kçÜü[OrÎäS^Ëš¹bVæô¥¾ÚO_ÝVûÉã¤êÏíÏý¤„¹¼¦qÙ÷z&“±zÁ´Â†«¿±0mÞÔ‚©EÒ]¶Ú-ö¶G^z·âýÝå3¤”}¾¦¸ø²3~ùÈâ¡Wï´¶òÓÚÊà³§„hIJ©M+œf3{qô˜ÙKF'ç@ˆˆþ·²Ôi¨Ú¼ñXý–ƉRÊ¢>†¢3wáEåã¿qÃ"Å0|»Ï)£iÂÙ7,ÖõjóÁÍŸU}ðB”§Ç6ûTOUÕ,«Ã™eu8QÛØ6”©€„Œœ?òÚäè„Ä”!p¢0à¶Û,½¶®éÖæM-¨§mÑ×=úòÝA–Ô(/‡ùˆˆˆˆˆˆˆˆˆˆˆˆ(4q˜ˆ"Vrr™ÊÃ}¿Ë5¥ê¥ß£úµ?Õäž}Ãñœo|{bŒ2 ÷y#Mëö žay`)3]̦֮}[Ñ´oë‰Û„°›âÍuæÌkjÁ$5còì˜Q§g$çä„ã%S=={ÇÑ’ú¦’]–¦=Ÿ(–c5ã!Õ\§x©Ké5qþgEßüÉ\ctÜò!íƒB=ûÌ£gŸ [Ñ£û_¸ßï÷¸¦huþ“BX§_vSÉÜïÖìï„(}ïµ K­]{ÞâTs¾¦±­«éhCK°áð­mG‹¦ADDDDDDDDDDDDDýÄa>"Š`ʃZžMz½ÇÞy¦àػϵf.¾èpÁ¥?˜cˆI0kÙÎǫƺ](†&@¤ªf Ùɤ4{ÝÓ;謮@åë¾XñMQµ¦„䮨”tgbv®?iìxcJna\RN~rBæèŒ¨Ø„„!ë¿ÇÕkomnw´5Z­Ç*-eû<]µGâz»ÚG«~_€ï¼÷5BØÒ'-80å¢ïM7%$¯ôã ‚Çaq†Ð Ÿ'cêœ]gÞó×¢s ùhÄ;òÁë†@·+Ѷ ¨pšÖ=_õÇ_8 às ^ж†ˆˆˆˆˆˆˆˆˆˆˆˆ¨ÿ8ÌGD)5)é|!°H—“K5³uû†ÌÖí퉅Ó7^ó‹±ñYùãui {W§ôzó­;í‚êñçܼÜÞRSÕ´ïý¦öûãÝÝ]S7 )Q>7ß×ÕžïìjGgu€M_=Æ ¡XQ&›!:ÎcNì5%&{câÍÒ'Ñq"61Y(1Ñ"ÆœúϳÆèÅ—ð¯ÿœ–·¿×©z\=j¯­KU{{¥Çiî^‡ð9{WwW¬ÛaOR½žQRÊdc?ÿ3¤ƒ±>{ÖéuιiŽ1:nåP?þ@Y)~å…¢ƒâ  mÜÂ3*N»õ¿§Ä¥erˆ€ßëqÛ[OÉsD"Cë¦/{íÃ]»¨Ó IDATZ:>ÿK Åfµ¿¤uQq˜ˆ"’Tä¯ u¯0wW—œqà·×ª¦äQ»óV݉Œ¹g-€Bï²PÓYüI%€Ó­eN_‘欂ñ“/¸}üä n‡ßëquTîÝßRòi·µ¾<Åë´O«Qn4¤šéó¸2}Üö.à¸FgZ¬gß`Êšºh6„§w8;[ö>ýßÉ€ º«åœëôÙ´K®ŸÕ°ç“²ê-ov·VìOõôØ å`wMT ¦˜šôIEME—|;)gþŠéB1¬äcE”ºm›J åü@kWŸ½H«Ïío©yxõ;}íèù')¥K³ """"""""""""¢â0EœÔÔÄ¥X¬wÇ—(kÇ£Ï߇ʬÉ<í†qçwz”9%Uï°Pѱÿ#w 
Û…欂¯Þnˆ2ÅdN]<'sê‰fÕïóXkËJ›Ë¶uuÕŒswwBJþý~/:1åà˜¹ç¸ÆžvátcLBÀJ½xº;Zw=ú?¤šì˜Üg|<óÊ[W@þÒsçä/=÷Ä‚”ÒÚX[ß\¼«¹³¦Âm?^/mÍ1žžîxÕ¯F©>o,@Â`ðMÑ=¦øDg¨ wÚ„"dNkΚ±° :Þ<wÒ$ ¢|Ãê Ÿ¯a?sáÔéZ÷|¡¦©­î[÷=– ¥L rH•Ífÿ»¦QDDDDDDDDDDDDDÄa>"Š8ÒŸ#D÷¾“^oAËÖu-[×¹c3ÇnÏ¿ø{1©3VÌé»õ9ŽJt{lrF±ðd÷W FSêøYÓSÇÏú×mÎΖ†Îê}M•\ÝÇ«==¶|H™2„ÙaADEÕ$eoÌš±Â”9mÉä¨8ó<½›éµ´ßù·;¼ªß—ì˜Äì±;Ϻç¯/w+„HÎ-—œ[; E"ÕçóvÕžhm\Ö¨2S”Q—ËÛo=x¸ôî?½”­J9*È!RB¹ƒ»òQ¨ã0E”´´¸\ãzwôCtoë±%ÿ÷K(c]Úì3êÆ]x넘QcÆè¦5Õãêñ»œ­¥Ì:åÁ‹¸´¬Ü¸´ rsüûÝÁÙÕÜh?v¤¹£¾Ìé8^cèµ¶¤ø\½¹€L<Õó„¯0F×'¤nÉ,Z"²f­“8ª@Þa}ééh:¶ëÑ; RõÄ3ÅÆ•¯úûºBѲˆþ­nçÅRÊ€Á—ž1Oó‹Ûû|~ïO­ÛþîŽâ%¢‚(å#V›õ}íʈˆˆˆˆˆˆˆˆˆˆˆˆN ‡ùˆ(¢H¿ñza5ì£ú}yí{ßÏkßû¾4Æ&ìÏ\v™%÷ìoÍ6ÆšGÄeb-‡v0'ÐZö´%iCy®¸Ô윸ÔìœÌY+ÿãv·­­¹»¥¶ÕÞRïèi«÷9Ú£<ö®ŸË™.U&ÃPv ¿P ­¦„ÄãæÌ‚ž”ü"còØ©éæÑ…yŠÁV—‰u´ÔÖìzâîø¾.­«£ê/}ü­QSL¼–mDôŸ*Ö®î ²ä¹xùÜ"-[ÞÙvpÏïžÙâöúVöy Ä6k·ãçÚT ‡ùˆ(ÒÜ wÀ _¯cNÓû/ ñý\±£ÆìÊ^þM™µäâ†è¸ˆbjÛýnwàéLÌ›pǾ¡”‘ž”‘>éëWô•ªê÷:¬mnG§µ§³Åî²´ºz­í>—¥^Wâs9>3Úïó˜TŸ/ª?NJ5îÄeSßgv!à†¢ô*бW ^Å`r+QÑÞ¨è8oLZ¶/>y”›>6&.%;!6-3-:!u”P”ÑFý߆v:Žì).~ùr!eÐÁUÅ`¬¿üÉ· qi™A‡ýˆhø©~Ÿ¿£º|j 5!„ý±×?,[>kRêì)y…1¦¨¸áhpy¼ÎWßÛuà¹7?Nvôzæ÷ã.{…Áx¡”Ò;=DDDDDDDDDDDDDCÃ|D1ÒÒÌÓ$Ä$½;†‚b\M§Õ®}µkéN½sôŠË}ÙKWÍR¢ãÌz÷ %Ë¡Ý/-—tT1giÝóUBQ ¦ÄÔ Sbj†yô„A=–Ïå°IUJCT´I‰2}1ìbþüψR½ùŸ[k?y}!úvT Æc—?õ¶!!cLކiDÀ±ÝK©ÜEUJ™öÚ»W¼öÁnðÅšLGÇf§µM-ãŸ7¹ aƤÜÑÙiÉÙ§rÞæNkó»[V¿»£D©knŸ.%–ô󮟩Rœgëê²Êy‰ˆˆˆˆˆˆˆˆˆˆˆˆôÀa>"ŠÒåz7 “Xw×ñEµëþ†šus™’FíÍœwŽ3kÅåãcR³Ãzg6gs]êuœË˜47â0Œ1 Iz7èMªªÿÀ ¿ÞÕU[²¬¯ãƒ¡ñÒÇß 9ä# ek_pôóPc¯Ç3ñH}óÄ#õÍX·y/@t›ŒÆ6s|Œ-ÕœÐ;:3Õ—-¾zg‡Ó-·vÛ­Žx«£gœ”2ÀÀ…|ÑjuÜ"¥t è~DDDDDDDDDDDDD:ã0Ee õŽVˆñÚ:æ5~ô"?zQL1‡S¦-nÉZr騤‰s¦E1èÝ8µ­h-gñ¥ã4¡açvX;>{â®z·ÝÚçÎZŠÑxì²ÇßRÌY¹¹ZµQpRõû;*K¦ ê1$Ý^_¢Ûê@‡Õ£ -C•÷e!å/º¬öG‡ãÁ‰ˆˆˆˆˆˆˆˆˆˆˆˆ†‡ùˆ(bH¨‹€¯mòÉ„ßãšÜq`ó䎛›)9½,}ÖŠžŒÅ«FÇ.(Ò;°/Òï÷Y*vL ´&ŒÑU挱ãµn¢áÓZºí@é©Îíë8S‚¹ô²ÇÞÊŽM5J«6"ê[ÃÞOJ¤ªÎÖ»£/xGÃ÷»¬–z½[ˆˆˆˆˆˆˆˆˆˆˆˆˆN‡ùˆ("!ŒÉIæ1zwèK&y¬mKš>~ M¿¡šb³ÇÕ¥M[¦¦Ï9cL\ÎĽ ¿¬u÷ÛûàWZËœ¶¨‡ù"€êózK^yhoGåþÓp’iÛÄìq;Vý}Ýlƒ):V£<"ꇲ7žíÖ»¡ŸBÈ{­û§z‡ ‡ùˆ("¤¦Æf©~„Õ%f‡›TýcœM5cœM5hx5 DgLjVUòÔÓ\iEKR’&Ì-TL1ñzõÕ¿û´dÉ?þÌk'hCÃÂ~¼òèÞgïQý^÷¢“;fö’¾ñÀ“§Cˆ`ïD¤©ªjë¡âÉA–Ë$Ä6u! 
¦C»ï-ì^U¤x¦ÓfÛ¥Ñ9‰ˆˆˆˆˆˆˆˆˆˆˆˆ†‡ùˆ("øý¦LUïŒÐ&eš«³9­eë:´l]~]—9®%eÒ\5yêâ4sÞÔn±u?¹¹"Öi3ÏV…˜ !g˜ Èi€ˆªïJÈ÷L¦ØMmmmŽ!z\""""""""""""¢Áa>"ŠBxUHnÌ7@éuïi<:¾§ñ(?zTa0‹NNoŽÏÔ›4avtℹYñYùc…¢ Ù_pÃûÿ8 `y µ¼¥—»†ê<¤½ãû?Ü{èÍÇGKÕ¿ødÇ£¢jÏ}èyOú¤YË´h#¢+}ãiK%¿WÅë_üOCƒì°ãó?!„!11±@j Æ2O@Id$â  €'vÜë‘@£hdµ0ˆb!|;; ÃöF )ô. PÀ÷""""""""|d8ÌGDAñ)n•³|CA‘~ßXWgóXWg3:~üùÍÒiˆ‰¯Ëב?CMž8;)©pÖ8cBRÊ©œ¤uï»ï'DGÎüóæžb;éÈÑR[sðÅßu¸º;ôçø”qã·_øÇ—fcãu»Ô3õMªªÚZ~pbàU±ÅnïnëóþRúT~þ‡(2H‘¤wé+kæ9ñ±ˆÒ»ƒÂB¢Þ’õ ¢¾¹MHЮwQ¸€1kæ9ñ-Å›zôn!I~ïC}âkªÉ×H„Ã|DüC//³;œDœßåœb?vöc‡ÐüÉ«'nUŒÍ&sâqÓ¨œžø1dBÎĘ„Ü i†s´1>Ѭ(ŠÁ“`ö96O×ñNËÑý­–òí=à®|iùEåŠÁ¸BÓ7¥×ÚÚ\úÊÃ5ÝÍÕ§(8ù=¤sú•·í˜wýÎö8"”¦ý[K¥ê›hMB®Ñº‡(H`‰Þ ¤¯_Ür©w…‡¹cæ]×´w£SïÒǸ¦(0éÝAD}óz=K¬Ó»ƒˆˆ(œÅúc—ؤwéG ,_.¡à2 §\6±úÐÚ£z‡>„x@É›Æ×Ö‚Ã|Dl6[cr’Ù~^Ó”T}Ùn[W¶ÛÖ{uI_‡&}þ§Ïa¯IçÝ)%oùˆÂÀÁWŸ²Yòª*ÖjCaK x#lîieáÔUgUW¬ÿPïÒÞ¼y·E !ï×»#RI©z…PôÎJf“Kù/|Ÿ¡5ê~@FÔNžFé÷êÝ@!!¢ÞÄ=&œÿNeå;n½[ˆˆhäö| àþyónÛ¼wï“‘övQ?ä­:@ÀM<èÔ¸Ò¼ÆnÞCJUü¸JïÒG^Ñþ[!Å8½;†’†aý ¨WI‰hä’RzÔèÝAƒ¢¶Ú“ÐkiÕ»ƒúÐÓÑܰõ½;>yøÆè®ê⥈9é„°Nÿæw?ºöÅmÓSò&ä % –ê÷ù:Ž”L ´&º»»;µn¢ð¤HDÜEU¯æO¿t¥Þ¤­¢¢+:\­/X¨wK¤’"ò>_¸§ è’!„Þ!¤!Pò‹VýwêÝ2ÔÜFµWï .½†””s|Ѧ5&œÏË£‘–"ëùXÔén{1kæ9ñz‡¶òŠ.9ÀËzwD𯝹€ÛëPˆ+ó‹.ù˼y·EÔ/½ÑÉå]z ¤ø_½;†œ_Öçrî`ED‘Câ&êA§L©Þüª7¿sÎ_ÆÄyUcæ S|êd½Ã°V—”U¼ù˜Ãii™ ·¿÷KÌÊÝzÞC«'ÄÊlw$7`X¯i …5U 7Â.³û! 
„Dæ›GÿÂ`íxÕ^#óõ Ì0WòjDˆÐç=€Œ÷qg>‚œz7 “ @þ’OÿÞÅpó…88#ôù&[Bþ„ϧ‘Oðyø9qÃ|€Æø_#„ˆèI_Ÿ"‡õ5óQÄB}RyXïRG{ãø£›žÇÑMÏ#>=Ç•Y´4&}Ò˜³ò?ÿn†””²ãèÞ’ê^²Û[kgÐoKÅØ2ýŠïsí÷ ÅX8L•D4Œ|·ËR_9#à¢Äû«Åªq…1!Ѧw…‡Œ„¬öNW« @Ñ»…ˆ²7íÝ©C\4hã݈ˆˆGB¶q33":‰6)zGQ`&¡´çãs˜ˆ"†ÕÚS’”dnáå¯"WO{cLÍ–WP³åD'$#¥`R f"µp&bG龤„娡£Çvnhî<º/_õûfø1„è.XzÎÎ%?úíctìòa¨$"T~¸î¤\hM@ð»4 ¢…¿eIDý±wï“Þ¼¢U]àöD!‰—‰¦TˆfÁ¯ïˆˆˆG*ͽ¥3 …f“ôŽ ¢€d´ß0¬p˜ˆ"†”R¦$™_wéÝBÃÏí°¢¥äS´”| ˆO”‚™H;Éc§€—äí›ßç‘]•ûÖï~»ËV(_ª¾‰&ø„°åÌ[~`é~;#6)õœ¡/%"­•¯>د÷L¦šÆPøj $ÓœˆúGœx¡šÃ|D¡HÈf½(DÙÂY>""¢Á-|:%¢“à/T…,ÙQV¶Æ3œgà0EEÅÿ©p˜/œaÏú¾¦ýŸ¨>ߨS}˜žÎãèé<ŽÆ=ï¢ÍiH7I¹“‘”; ™¹PŒÑC–n¤ª¢»©ÒÝZ¾­±ãÈ^OWËXyʿѣ Ç'žsEå¼<7*6a妑Ž<=»½¹!Øîœï´··Û5 ¢°'!ŽéÝ@DaD ÓõÎ ¢$õN Ð @ãÎËDDDƒ#¾^BD'S¯w!”†á>‡ùˆ(¢tÚí)I‰Û °TïꟜY‹÷ŸuïßV@sپڲמj=^¼3_õû3ó¸n{'Zʶ¡¥l@( âÒF#!+æÌ<˜³ò‘•‡hsêà߈äî©ÖšÒŽÎÚR—³óxšTý± àT÷H2FÇ™uõ÷Û§­ºa¡b4Žºb" ‡Þ~©À’€‹B¾¤m E!•êÞD&„‡%äùzwÑ×I ½(4ø*Îò Š*øtJD}“‡Ný§yD4œä¡á>‡ùˆ(âˆßKÈ·ôî ~ñ,þáÿÚ.»hn~vÑ“ùRUÕÆ=,~ý™îŽ£eã¥êôИTUô´7¢§½­¥[ÿu»). YùHȇØäLĦd"6%1É0˜b{Úaçíu §½޶c'þÛ\ç²·ÖJŸ«ç‹øÁ_¦L{zaÑÁÙ7Ü™¢¤æ#"¨.›ímîÎR §|Ÿžýµ'¡(JîÂ3få.<Ðv´¸æð›/55îùx”»Ç1 Cø‚±ÇÙ®šbtÕm-*ÎŒ˜¤tÄ&g &9±ÉˆNL…16Q1 0Æ$ *6Ƙø¡ÊùÕç…··Þžn¸{ºáí±Áí°ÀekCoW\¶v¸¬mð¹_½kÌ%Èèxó¾i«n´]vã|ƒ)fÙ=.…¨^[WWog[°Kì¾!¥tiD¡¶v‹+¿hU>ß–ˆ¨/ª@…àöD!ÊÀa>4íÝèÌ/º¤z·…«ÆCk;ó‹VµÈÒ»…ˆBS,bõÀÃ_¢! 
MÃ~õóQDü·Tñ©ÞÔ'¹ì‡f÷çÀŒ‰3 2îžYÎÎö¶#ï­9Z»ý=Åv¼®~uP—ãí‹×i‡×i‡½¹æ¤ÇFÅ&Àsb°/*Ö Åh‚!ÊtÒûù\ø½¨>¼½¨>/üž^øÜ½Cñ& ”4ÅÅWŒ[|vûô+n+L=vžD¤ŠuÏ•XhM‚—Ø¥AÙ óQ?øÍ¦ÆnÀÉ¿&"-YëËfÑ;‚BÊnæ#""±—è]AD¡©¬l#¿hU9€éz·ÑŠR {†ûæ#¢ˆÔÕÕ½5%)ñyܨw ›œv µpÊœÞ/.-=cöuwd̾æpMå– Çvo19Z'J)S‡<¶¼½x{zœzP$àŠOU^¸â¢ž)]71>={šÞMD¤#›ÞH ²Ôd³9>Ö²…"‹¶ àZ½;ˆ(ô5ìXÓ›_´ê€…z·Ñ—ÉRÞÏëÀÑ¿ e+¤¼Fï ""¢p&%¶ óQPØ*9ÌGjê*‹ßhî“p˜ˆ"—bø1¤ÿ,9z§Ð×ͼêvÏP@&EI+˜:>­`êøÙלøÚÖçîíi)Ù]^·k³µíÈAcOËñtŸ»7#äÒ]BkLꨪÌ)scœŸ3wÙøèÄä)¦èÝFD¡ãÀ ­Báw …—Ø¥A©-ŸYš7m; Òõn!¢Ð'`øXB½[ï"ú7)åÇz7Ph©>4³"oÚYz·…«Útëþ¼ö$ €½[ˆ(4 -òÄoI+z·Ñç$>Öâ4æ#¢ˆf³Ù6¥$™ïƒêÝBÿ–·è¬CŠÑ¸R‹s£cãs毜™3ÿß§Sý>Ÿµ®²º¥|o[Kéwgí‘X—¥#Åçqd¼]Ã@UŒQ ñi™Í©…“]YÓæEgÍX˜•:nB„˜§w…0)eíÎò.‡¬Vë~‹(ÂHy¿š_´êm7êÝBD¡Ïc6~dìöôׯˉ"Œl?65z—ÞZ¤¼_ÍŸ¾j$nÓ»…ˆˆ(\É-[|yE«ÞÀõz·Qhª)ÝÐZP´j§–èÝBDUÂû–'â0E)^Ñ4†"^}ùÆp·G"êÊ3z7tŠ*žÓ»B[Mé† ìÕ»ƒˆˆ(œÕ—mØ \ï" ]^sÔë,zwdBÓÍ 8ÌGD#Ž”Òm±uBÜ@Õ»g$™rþ5üû&" .»¥ËÑÞ<+Èò›ÝÝÝšQÄ“RJüUï" Õek·q@„HOâƒªŠµezWPèÖ»ˆˆ(ÜI!þ¢w…®†kz…Äãzw` ©1™´«ÿG÷ IDATâ½ÀQ½;ˆˆˆÂ™Ïý¸;õÉó{v½+ˆF!äZŠ7i>çÀa>"¢ÏY­ö×ü*¦Bb5UïžH3å¢k'èÝ@DD'ìýç_]Täj Shª/yËà·zwQx¨-›õ¢öéÝA4rÈOjÊÖóê4 eek<ògzw…³†kzù ½;ˆ(tÕ–½Ý"~¯wÑH!F¯9úïzœ›Ã|DD_ÒÝÝÝa±uß¡ÎÀÛz÷D Å]Ÿ>:[ï""¤ê÷7íÛ>)Èr‰Åâ8¨iHñ0=**½;ˆ(ôIy¿*…ø©ÞD#„TøñF§¤¦lÃF@| wQ8«+ßøª„ܦw….á°ý@ÞD#"ä/O Ûëpn=NJDê,G±ÅÒ}¡"Å"ëÝî2&MoлˆˆN¨ÝúÞ©úX pW>ÒDYÙ*‡ˆ¨_jK×} ˆµzw««K7ìÕ;‚—*ý?àÖ»ƒˆˆ(\I)%Tù# øôn!¢ÐT[»ÅUüXï¢`gmÙœ—ô:9‡ùˆˆúÐi³í²X»O*.Àn½{ÂUáÇêÝ@DD'ìñoÁ^ ôy½ê‹šÆÐˆV[¾qƒÞлƒˆÂƒT•°èÝA©$Ða4™~®w…·úòåâwzw…³ºŠûðÿôî ¢ÐU[±n=¤\£wQó¨Rý®”÷«zp˜ˆ¨ºº»ßê²vŸ&¡ÌƒÀ?Àߊ_Þ¢oLÖ;‚ˆˆ—ÝÒeon˜pQà=‡ÃѪqp^çQ¿ÔU¼Ñ ‰_êÝA©ä+÷¯i×»ƒÂ_mºõ!±_ï""¢pft{~  Bï" ]&Ÿ¼@›ÞDIˆ‡êË7–ë™Àa>"¢°Z­û,–î  ñg]z7…:SlüQS¼9Qï""üó異­I)_Ð8‡µeo·’—… ¢~©«Øð >лƒ(ÒàÚ² ¯èÝA‘AnÙâRÜ À¥w Q¸ª¬|Ç ‰›xõn!¢ÐtäÈÆÜ¡wQ:`t¹Ò;‚Ã|DD§Àb±ÔYlÝw[möÑò*¹ €nÛ¬†²ÔÂ)z7Ñ •®Ï ²Ôi³96jCô¹Ú² «…/ñLD'%¥”ƒø€½[ˆ"…=~å6½;(²Ô”­-–ÀOõî "" gµåëw ÿÖ»ƒˆBWMÙú×!ñ„ÞD¤Šzmeå;n½C8ÌGD4RJ·Õj_ÓeµŸ«|yò>Õzw…’± NÒ»ˆˆ€ãw•ù=®Â€‹RþSJ©û7'4rÅIÓíŽêÝAD¡¯ªxm€HSˆ" È­íÔ;„"O]ÙúG±Vï""¢pV[¶áO€Ø w…0§íÇêA ¤wÖ–l<¬wÀa>"¢!ÓÙél°ØìZ¬Ýãƒ,‚H $>ÙëiìigÕ»ˆˆ€½«ÿb º¨Èç4L!úš²²5@\ Gï" }µeë7AˆßèÝAî¤?¯+Û°EïŠ\Ñ>õF)P¦wQ¸’RJUñÞþ$Q[»Å% êe€l×»…(¼‰Õu¥žÕ»â æ#"ör‹Åök«µ{ŠPÔ™Ÿÿ 
©Tï.Í”6svν3ˆˆF:Ÿ»·§³º|VåÏ,G±¦ADÔ–­;(…¼Üm‹ˆú¡®lý^Õ»ƒ(\Iàuåëþ¤wE¶Ã‡7Ø¡¨ó‹DDD§®¾ä-‹ªÏ“@‡Þ-DšêŠ7Ö ).àÑ»…(Lí0ºÝ·éñeæ#"f]]Ž‹Åv¿ÅÚ=Ãç—™òJ< Iï¶áŸ2ª^ï""J×>wRš­Ià­{ˆ‚©+ݰR> w…>)¥ì5ô~Bì×»…( íˆr{¾«w uÅk!”+Á,²ú’×kõ*ðù”ˆ‚¨)_¿U@üPï¢0Tg0(—VV¾ãÖ;äË8ÌGD#ƒßKp³H×3Ãn··Y­ö׺¬Ý·Ymö\ÿüÀZH׳m8¤æOqêÝ@DD@ÅÆ$Yê1¢^Ñ4†è$jË7<ˆGôî ¢Ð×R¼©Çu.€Cz·…‘rL…Ú‹ÔÙjK×} U\%ŸÞ-DDD᪮lãf>ŸQ_jÊÖ=% ïÕ»ƒ(|Èv(êyUÅkÛô.ù*óQd»EŒÃíâ¿p+Ž@Å[ˆ‚Uï¤/H)¥ÅÒsÀbéþ“ÅÚ}¹ÅÖ=F1øÆJÈ« åÿB`;›Þƒ‘=}~”Þ DD#]W]eÇÑ==à¢ÄkÝ'T]ù¬Ÿà )Tåþ5íFƒáêôn! Ç ç5”­éÒ;„FžÚŠuë!n êÝBDD®ø|JD'SS¶á·þ wQè–ª<·¶dãa½Cá0Ež+…·ŠËq›Ø ê ñ{¼…'¥Wï¼¾tv:¬Vû‹Íþc‹¥{©ÅÚ¬|cU)ΕÀÏ<'€ÝhÑ»õs>Á3§ÏÏа…ˆˆØ÷Üÿ« ­ /±K¡IÊûÕx˜¾ `Þ-Dú*‹ßh4@ž  ^ï¢V§*ÆÓ«J×7èB#WMéº@âûàÑ)«)]÷ÜÀ¯w …¦ºò ¿ä•Oˆúd“çÕUlܯwH0F½ˆˆ†Ì"~\‹üÀ¤¯­K¬×>jð:; lúòíBˆ˜4³9ϧ(cuœ"WJ¤+@š„LˆT@¦ˆÀ)}ì,Ÿÿ·KB6 ‰6!D«â8 ÖÆ:«ÕÚ”’˜xµòÇ›š7yÜ©½ÕDD4TŸÏÛX¼cj 5±XìÛµn"꯲²5qå•WäUxžp£Þ=DÚªÊ6TåM½| þ÷…D‘Þ=D!æˆAàìÚ’×9ÈGº«-_ÿd~Ñ%6@¼€Wt "":5eëŸÎŸvq;„ò €½{ˆ(´H)%€»ò‹V5ø½Þ=D!¦ çו­ß­wH_8ÌGDáïXüÀÝ’‚Õ w5¬vRJ€ÃŸÿé—´´´D¿ßoPUUQ% ü~¿Ý`0øÀl¶¹dï@:R’Ìù6|RŒQMŠÑ˜7Ç""¢¡UõáÚýð« ­©À3ŸSO²äš5~!øNÞ´ý½€øžÞ=DÚê*Þhž0çÊ3¼Ï;˜§wQHb¿AçU¯mÓ;…è µe^)(ZÕ#WÄéÝCDDŽjË7nÈ/Zµ ÀkÌz÷Qè©-[ÿ‡ü¢KÝ€üxÕN"hB9§®lí!½CN†°D¾FÜ.nE,*üÁùà<)•…¬ÎÎÎn«Õjéîîî´Z­5V«µÆn··[­V‹Õjµ tïQèÖè¤ä —ß%""m|åñ`_ïû¢¼þjCtФ¼_­-Ûð}q›<±ƒ0QP•û×´‹Û2)%ŸçˆÞ1ºÜ§sBQMÙú7”ÅŽéÝBDD®þ?{wUyþüóÌ$a'," ˆàZw[+î_k«m]ÚZ´ûϪk­ÕÖV[ûUj·Z—º€Zk]‹ —º‚ˆ²›°É’ û’ÌLÈ6sîß _—™å䜙äýº.®«žû9Ïói¯’ÄäÎý”•¿õ"#%•…@v*+x‡y:KRUØY€-ˆ:_^’ý|Í|rU‘û†Ök‘Lã$í¹ËõNÿj÷L•Óðt{ÒŠÆ@€_¶m^»a[Åæ£2”Ÿß²mÛ†@m´ªtâx9;W|ã À.”•M©[½xòevƒ$¦Ð¢³º¹¼ôȳ–/‰Ï›ÈZ«JŸ[˜JÚq’›vrÕŠÅÏ•4z_’45ì,²Sùââ—œ"'‰_¤Age6!Ù»`䊒â5aGi.šùä–‹]?¹KzAÒÍ|k‘î·wÚ1Ug·oº‡}‡ïçðf?|û2IÑt5gîÁ€ã¾(/™ô¢RîpIï†@v33+[4i¬gîtI4°£3©’Óe¥Åט]Ï¿—#ë}¸tÒú¼úúeºK4`Ð*Ë–MÞZ> qºœ+)vÙgUés S‘£LöBØY€ ˜””scËõ½53'äÔ "šùäŽ"÷ Eõ¾¤µè=§¿µO 8ç $ JWë·÷]ŽØÉ<Ï[ýÎëû§¯º5±ªªW‚Mø§lÉÄò¼ú†“ù/€æX½hâÑhäÉþv ý¹Ù^$ïȲ’âÿ„h‰åË_ª/[T|…Ìû–¤Ê°ó‹lÊ”dYÉÄäÜW$­ ;€ì³vÉs«M>[Î~)©!ì<@;[í<ïIJ’‰7äâ/;ÒÌ û]躪ÈýS;¦ñ náÛëU 'Ú!$õë×mweø\R8hïžÇìT6õÅy^*•þs¦Ó?ÍŒßÐENûø¾Ò™âz»°bás›ËMþ†“+Wu£cj”ts埰úýgV…h­²E“'™!ibØYÈUe%§*ßl¼ø%HŸafVV2éÎHDÇHnvØy€v`’ï’´CËOžv˜Ö¢™@v»È R¦Jº°Uï›þ®»¬Þ×Lø˜ç¹™j½öÜ«OYÿgÞÿHf(yRäá@Ãí¨¬´ø•†®ÞA’n–”s¿] 
8ff«J'Ž7/úÑ$‚ŽÄi–'QVZ|Mié&+ ç•/~vCYiñ·Ù(“¶†€\T6b¼¬tR‘óìI„@öYù~qIù€øÈSú¶…ðI™ÉûJY餢¥K'U‡¦-hæ½F»‘ÊÓ\IǶr‡Ê×ý~F§y^Þî™jÝû èdÀµ[¶ToZ{tºšI/Çb±ÕAgÚÓº9“·—•_qv¬¤éaçÝ>j‰x:]RIØy€6ØàäŠÊ*8auiñâ°Ã~[µhÒÓ)¨¿´A£*­°jñ¤·z¨àÐÍ:‰°óÈ.;®çžtgÔé “Ó<‘»¶É¹±ÉÞ‡”—N~3ì0~ ™@v*rß•ÓI{´a—tÕø Ÿ•vK[p®*Zе[Àq’æ·¹¬tÒÏœ³¡Ú1ù–¯'|ÊÊÒçf”•9¢È‰&ãg Èfõ’ÈÛ¯¬¤x̪’I›Âä7šùd—"w®LO)ÃD¡fsºR×[ÒŸPhB¿t#yù¦ërIù;¯Î÷’{g(ÿËÌ „ÄìzoUiñóe‹&}ÑÌ}]NÅ’øÿ?€´JK'4¬*-þwù€Ä!Îl”¤W%yaç$IÎÍ3é²*VV2ñ†5¥*ÃŽ„eUɤ÷ËK‹ì¢vˆdwHª;¹fUɤMe¥Å×ä —t-·ø¬•¥ÏÍ(/t–<ï8É="i{Ø™I2i…9ý> -+T´²tò‡agj/ya€¹oHúÚÚÈgú·ÆÙ˾dBÓ<ë.÷ùÇÑü‚úàÃæ=rW¦¿&E¸bŽ™™¤ÿJúïðßØÃ¹¼yæF9éh)ÝW1:3›2%)éiIOïsðwö6—üqç9Óˆ°³¡ÓY#ÓD“ý«¼´x~Øa€l³já¤e’®Üÿ¯_“ì’ÿ-“~àä¾"©kØÙÈËçMØ"é&çÜÍ{rö)N‘ŸHv–2 qÐù”-ž› OIEÜ9Nvª¤ƒÅ/CÞÊ9’î“tßþGÐØØx¶3}]òN”Ü€°ó¡ÃII*‘4EÎM.?(ºM˜ ;TÐhæ¾"·›¤I’z´q'“t©4®œ ˆ³ô“ùòºuç;ØÜGn/ÑŽZ}ŽIÈjË>»VÒ½’îuι¡ŸupÔ¹‘’;\ÒlÇ7«÷ 7%€lñaÉÄ•’n“t›;õÔ¼}¶ö>“i¦éà3ú†¹£ZrK%[äœ-–çfÙöÄì²²)uarYié„휮*Iûzî^IÏN’ÜÎÙ!’;HÒÞ¢!€ŒvN+}çxàÙ»5äéx§È—,¢ƒeîÉöS¹€NkçTχvþѰC¿}d#e:"";ȤC$íjHä’:IK%-u²E&7;¯¾áåË_ª ;XØhæ>ÓƒrÚÛ‡}þ®ñVìC"4“s®›¥›ÌWÐÅ !t^f¶rÊ ™>—nN$j&šÈ!;}ÑÎ?1bTÏÚTrE¼’í!ékκÊÔMRIÝÃÈ‹¬pµ¤HØ!Ž?Üš³óÏÇö1ª_Äêö”¢{D#Þž²HwÏÔË9ËsR/ã{p†É5:Y™KFœªåT+y#©èåkóŠ…Ïm;cG朷ړ»1ì~‰X¤4ì ¹jEIñIïü#IÚkä¨n‘oOy{:EöˆHýeêbκ‹¯ïÐAxJ½vDze“·jÇ0Ž¿·¸ÿþ_ï’ÊﶇM –40bnÀÇŸOͺʹn¡F˜¾$éÔ°C xå%Ï-‘´ä“ÏFŒÕs[¤nˆKºÝ-¼óû"Ý%ûèkî.¡„Eðœód–œ9)nRƒy¶9‰¬ó<·yõ’Ã6™]O_A|#@¸F»KätŽ;Í”Ó5>ìƒØñ/gŸ‰æq¨læk ¼dã‘éjNzØÌ¸þh¡Ó]>ØùøØðç^%šùðkJ'TJªÔgƒkeéä%ýoØ9ÖÌœP+iÕÎ? 
v^s¿zç@’´Ï¡çþÚŒf>ì°óûªKwþÐ |ó@x.s{ÈévZ/OçkœqµkÀœYú¦p¥™4ÿ±;3]Ëf)s@«ÐÌ ½@Pj••‰µåG¥¯º7‰Äò`h º-„£È*éûmÜ¥ANßÕý¶ÐHh9ç”öûˆ : t^sÿ}G‰¤.éj&ï€ãh%šù„åÏjÛÇ z9Òýöš_Ð žÒNæóŒ[v f¶êÍÉ{g¨nM$j&š@«ÑÌ x£Ý!’¾Ù†¶K:[÷ aË0™/Y[ÇçÀêYo,H%‡¥«9é!3«8€VÊ ;€N(¢+djíE¬ÛÑ×tŸ½ík&´VÚÉ|ÉÆºhÐA 3šóï;k3”,eî¡@Ãh&'Ö]™¾×†’ÚCïø–§3pΩÈ]Ö.{›¶¥{ì54ðùÚYm¢²²jíª£ÓWÝ›‰Dby°‰´Í‚ÕUçJêÙ† µV§ø”¦s­¯KºEEn7ß÷ŽX"Ýc/ÙÀäWhgó½³DR—t5“÷@Àq´Í|‚ÑwÚ¼‡Ó5Öññ«ù~!©»œ.÷ëHUº§ÉºÚ®þŸø¤Sžß+Cik"QShmF3 €àŒu™Nnó>NÇkƒnö!QÇw‰;HÒé’$S‘Æ:_'æ™Yúf¾†ú¶L_ìÂêw^Ÿï5Ô椇ͬ>èLÚ†f>ÁY¯#$õóe/Ó¯5Æ×U®›/ûuT].Éíü§Z·³±Ï7é›ù,•,ô÷À'½÷à- J–ôô` aø‚f>A:Õ×ÝL—h›k´û©Š\÷&׎u]âÒhw‰Æ¸v¸n6 ]èúHúñ§ž9ýÐçSâéšY¡—J&}> ©zãÚõ5›×“¡<¥ªªêƒ@ð…¯×-À.øÛÌ·Ã09=$é.v3å´JÒfIu’úÉ©¿¤/Èt¨"ê±ó7%ÝÝY²K.’>þïü‘³5Êh‚ešèÔ"‘ˆ[o^Ú’««ŠÅ»÷°›çþÏ»üõIƒÒÕLvoÀqø„f>A:¬÷î!—æ YK³ÒS—vÌ‘F¹¨úê²4•žê§/KzËcæX'hæsºAÉtåk:öd¾”~)©©f½B]ìã•Ǧòt«6®éØM“°Yüµ*SÍ“»'È,üG3€`8õ ;Â'tÜf¾"·›œ¾¿ËuN‡øu¤sZ–îùöØ–~ݶ-ë7$Ö¬úbºšI³‰Äì 3ðÍ|‚aêv„OèÈãFKê¶ËU ñë@ËÐÌ—j¨ê%“~ÙÛwþï’òÓÕ"æþpí€f>Á0õ ;Â'tÌf¾±.OÒ¥Í\½›_ÇFRZš¡”_½b_ç@gÕ°½ºfÝû³ŽL[tÚ«ªz&àHÚÍ|‚ɪf¾rѰCønƒ¾#5{âžoÍ|Iç25óiÓÒù›ý::«ÙÿúÛ\™¥pëLw›Y}Йøf>Á°f\ý¤¾p:Ÿ§Ë[°¶§_Ç&‰JIi›ö6¾ÿ^_ç@g䥒ɯ>·_†ò¶”¹qÐnhæ §dØ>%šÍ|~N¼Ä-§ã›½>"Ï·³%9Ù‚tÏ7-[à[Ó tFK&?þž—J N[töÐΆjÍ|‚Òv€Oq!7óv#ÕW“|Û/¢+Z´ÞSÊ·³%yŠÌN÷¼®rË>23?Ï€ÎdÁS÷f(¥Ì¢w@»¢™@P²«™¯!Äk/uÇËéeIßÐ÷õ6ïw‰(iT‹ÞñyR¢s–¶™ÏÌú%Ö¯^ëçYÐY|8ëõ Û·’¡<1¯ 4€vE3€`X–]³ i2ßwœ<½$©—$ÉtS›¯Ûê ©…ÿ}œ6µéÌÏÈkH½—©¶vö´5~žÅÌŒÍ8Ù4bîoAfÐþhæ”íaøS×ÀÏ,rÇÊô_I½?ñôPõÑOZ½çE®—LcZüžic«ÏLc˶mdZŸ®V>ëõìj䀰~þÌÒÚDå‘i‹¦‰Ä¬€#hg4óJeØ>Å<™o´û¢¤W$¦Ér›.qCZµoT¿‘Ô·Åï¹ôwmâônºÇË öý,èà¦ßqm}ÆbD·@@hæ”а|J“ù.qGËéU¥käÛ¡¯"z\¿p-k0,rûÊé×­ÊdZÖª÷šÜÓMM÷8ÕP¿ÏöŠ-›}?:¨ ß]´½rëÑÊ ãñêç  4óJçœÌw©;R½&©Ï.Vž¤z=¡Q.Ú¬}‹\¾¤J­jJ¬Ð8û°ï5ÉESS2•ʦ¿´Âï󠣚~Ç5¯¦7ÙffAæ šùÃeY3_“ùF»#äé55ÿÜo«^×¥®ékiw4ò=*é¤V&›ßÊ÷š‹m+•”vßÊi/¥ÚãLèh6.š·dÛÖÍÇd(—&5 qÕŸÁ IDAT 04ó†§aGø”öžÌw‰;LN¯Iêߢ÷œN‘§…*rÐ¥n÷4û~YÒÛ’Îou6Ó¬V¿ÛÔ¶fæä¦¥«U”->PL’€]š~û5U’\ºšÉþdf^À‘È̬ý‡¨@'’vDT«•Míí9™ïR7B½.i·VîÐ_Òåé:¹9r*“)OÒaŠè 6çsz¥Í{dàÉ^sÒw?WHy»o\4wñ#Ž9¸½Î€\·iñü¥5›×})]ͤ%‰DÍÓAg )&wHØ #a2€`ܯ-’jÃŽñ íÓÌ7Ú"OoHàÃn’FÊôí˜Ä×öF>)¦Aí3™O’òS/HJ;oñ󦽂°Ã”¿^±]¦òEÌý™©|€l²ß¡çî%霰søÉó·Í|‚±ãŠÕ5mءί(’ÚçšÝKÜArzCÒç¯ÇÍźޒíµù–mÛ6Hš®¶nîÛƒÚë\Èu«¦¿<·6^qT†òŠXUÕ 
@†ù­>)i‚¤nagñSÄyõagйqÍ.€ }(逬¯—4Y¦ñ’‘Ó¾%ñûšÝÑî€Wëôu_¿™žhÿ3ìy9÷¹k"“õµÄ×”}Øg¯áCÛ=äó<ïí»þÐ#ó÷G³ökÄ€áÃOíêõêÑ¡š²Ð>ò½hÿ”¹Óät­¤½ÃÎã?çïh!šùÇ´:ý埳LNãeú·ÆÙVIÒ×'ýå­­äçd¾Ñî9M‘”í“çÖ*®)í~JÄž—¹Ó•Þê¾U']} Í|ð ‹&þëdÝöã3”ßWU=h ŽëÑçêˆg ;²_JR3æ››œÑÌ T4óŽkòšÝI“d¯ôÆÎkyÿ§õ>QèO3ßÅn?Eõ¦²¿‘O’îÑKµ÷!±X;}z 4SËf½:ô$ÝÒÞ g¤æ>vç^™êž¹ß˜™d&:+OÚv[$ì:Óê4O?éjIƒ5ÎFi¼½þ¹F>IŠj½ÏiÚ~Ín‘ÛWQM‘4¸íqÚ]­¤;͹§Ò=ö÷Ù²dÁ²Àr@–{ï¡›ßñ’ÉLK§&‰W @§æ6†@çF3€à˜>ÜùŸ$=-Óé¯/h¼Ýöñuº™Th½äëE»m›ÌWä†JzMÒ_Ò´¿vù¿±¢){2Smþ“÷ò0HjØV]µì¥§ÍP6Säê@ÐÉyr›ÂΠs£™@pò´LÒoÔ¨!MNáKg‚5Hªð-‹S·V¿{©&é-IÃ}JÓÞj•ÔÍA¸µªj©dóÒÕÖ/|çàTccCy M½ùWóͬ_ºšÉý'Ï :Yo˧™@¨hæœûlÆÙ­ú§miåþ]µëµr2_‘*OoJÚÛ·,íÍé=d~_SÜŒsÝé›ç øà• iý ³¨X±xźùoÌPnÜï b¥¥j s£™@.ñ¯!͵¢™ï7Dқʉ|’´Aº1ŒƒS)="©>]máÆ²‡™½ò¿E5’ò3Ôï‹Çã«‚ @§·4ì@3€\²Á·L][´þ2·—"š"i_ß2Át¹î²ª0Ž®ªªÚêä&§«ÕÆ+ŽŒ¯]Up$È ï?óÀÛõÕ•Gd(ÇRæþh IËÂ4óÈ%­½ž÷óZ2™ïR7LIM‘´ŸoçÁé·gÃŒà)5>CɽwÿŸW²@ý¶êļÇî>0SÝÉ~WUUµ5ÈL@“ùdšùäÿšš;™ïgî@yš®\›È'Õ*¢Ÿ…"‘Øö†¤åéjëÞ÷èúmÕ‰€#@¨^ÿ㥠Íód(ω%j4$™¼Ùagšùäçc3_¤Í|cÜiJi†¤!¾§ßê^[v 33“Ò7¦˜õœóðm Ž¡Ù¸hޒ͋矡œ2EƘY*ÐP@’RùõÉ9a‡šùäÏÇkvm×ìŽq£eú¯¤Ý|;3(N/hœþvŒ8}@RuºÚò׊÷÷’ÉÆ€#@àÌK¥Þ;Æ$EÓ.pîþx<>7ØT`+]¾ü¥ª°SÍ|r‡ó±™O&ó]â†h´+–iœ¤|Ï J¹êõ#™YØA>‹Åâ2ûgºšyÉA‹&þë½ 3@ÐfÜù‡é µÛÎPÞ$E® 4ø˜)2=ì ÑÌ ·øwÍîg›ù.qCTänVDKåtŽç©AÒzØâaùœHÞ’’éJóžøÇ`/•L[€Ž`Ë’ËV¼9idÆ殎ÅbÙ÷±€NÂÉ{1ì ÑÌ —tñu2__¹_ªÈݦÑî]ET.é7’zøxF°L¿Ñ8{7ìéÄb±rIÅéj^²qØâ‰0@‡”jl¨ùº‹"’ 2,™¯ªz,ÈLàS¶i[ÕÔ°C€D3€\r—UiÇô9?ô‘ôwI¿’Ó—$E}Ú7,÷h¼Ývˆ¦˜"•”öúß¹ß=È<¦óèxÞ¸ñçï$êöÏP®zcYt5:{½¬lJ]Ø)@¢™@îI„ ë˜^Ò ý2ì»ÇçÉô|ºš—löþ3ÿÌÊ©‚ÐZkæN{Ýü·OʸÀÜ5[«ª– |†3=vøÍ|rMUØ|ð¾LS}Úkžòu¾®·œ˜jg.2V¦ó-xâž½“ õüÆ €!Y_»íÍ?]Q¨Œ_o»7âUUw |V¢±0?í@Í|rM®Oæ[¡ˆN—Ó&öZ-雺Çj|Ø+;§ó½®æ¥’Cfû3Óùt/]óÿæyÉÆ½3”ã‘hã…\¯ @Øì?kfN¨ ;|„f>¹¦:ìm°EN_×}¶Y¦ÍmÜk“¤34Î6ø,Hæ"7(Ãt¾å¯=wD]U¬"ØDà¯EÿõNÅŠÒ3.0÷‹ŠŠíkŒÒ0Óýag€O¢™@®ÉÕÉ|U2}U÷Ûòÿ¼² {m–tŠÆÙRr.Ï3Ù“i‹f…SnúeiÀ‘À7•唽÷ðm‡6±¤8–H<X –9½^¾hÒü°sÀ'ÑÌ ×T… ätžÆÛ‚ŸD5§•{Å%™«|q.ï÷’êÓÕ6–ÌY¹rI[š µ55Ïÿj”ɬgººISž. 
:ø¹Åi[ØZàZÝo÷g¬zz´™û¬•t²î³ù¾¤Êñx<æd×eªÏ쮃êªci›ý ›¬Ÿ?³téKOÌT7éX¢ú÷Af™Ø´U%“^Ûõ:Í|r‹©!ìÍô³›š\ñ€^”´dû”I:Eãl©_Á²I,Q3NÒœt53ë÷êÿŽYp$h‘ÚŠ-[^¹¾h€¤‚tu“6æ7¦¾cfGŸ—2Ï® ;dB3€\Sv€]rºWãlטÌL¦šX±X¨q¶Ò·lYÆÌ§WÓùºÞ’-~÷[+éÇ’’rzA=t†î²*ß3f±x<SD—fªW®\r¯?û^™ “E“y§lÆ+'g\`ZÉËÿ‘µ´¹´»³¬tÒËa§€]¡™@®É¾f>§7eú¶&XC«÷g/ÊÓ×dú¶n·ZÓåŒX¬j²œ=ž©>ó®ë¨Þ¸6ýô>ÈÆEó–¼÷à͇Kr–TzrçVTTtª¦l²Ø¢dï.¿ ;4Í|rK¶]³kz[Q£q¶½Í{=`ojœ5ú*g¥Rî “6¦«™YŸÉWŽÚj^*t.¤m[6l|ùwö•\÷ KMöÝD"±"Ð` “¸‹ÚwÖÌœÐ)ªÈ=4óÈ-N­Ÿ~ç¿Ê׺ÇjÂÒQTUUU8§"Ii¯¦l¨I6íÖ«§ Ô°­ºê¹1߬2/¹G¦5Nö‹x¼úÍ s€Œü0ÈL:'/•L=wéY¥õ5‰Ã3­q²Wâñª_™ 4ÁtWYiñßÃŽ-E3€\öÇ­ùJétÝk±sthkÖXm$jHªMW7³þÅ—».U_·=àh:3{áW¼S³yý›X57¯ Ûyf– ,h‚{¤|ñ$~ @N »)ZÆBÌ7[Q¦­2Ä FEEõ"“]™©ž¬¯=`Ò•£¸n@»yåO¯X¹ä„&–”å'½onÞ¼™+×Èÿ.?8ÿ"3³°ƒ@kÐÌ ·¸Ð>nÍVT_c"_°âñêq’þ™©žX³òøé·_;5¸D:‹·þöÛ©ëÎ:©‰%)Ogn®©ÙX(‘“î-_tä…6aB*ì,ÐZ4óÈ-^(“ùf(©Óhä GÏ^Õ?—l^¦úŠ)“O.øÏ™AfбÍ~äöé+§¾præ¶=bî›UUUË‚K2097vUiñef×{a‡€¶ ™@n ~2ß4åéL=dÕŸ‹Ö¬±Z¹¼ïHªÈ°ÄÍþçߎ^ýî›þ ¹æ?yÏŒÒg:^ÊØ<ž”sß«H$f™ ¤Uï¤ÿWV2ñ†°ƒ€hæ[,ÐÉ|/«‡ÎÔ=Và™H#‹•›"ß“”̰¤Ë”?ÿbŸÊ•KV™ @Ç2§/xâÞ‘Êü5²'s?Ū&™ ¤µ>"ïÄU¥Åÿ;ø…f>¹Å©[ ç˜^Rƒ¾¥Û­6ó°Kñxü5“Ædª›YŸIWŽê™XW¾&È\:†ùß3cá“÷¯Ì_›“ý<–H<d.Ö«æEYY:yvØAÀO4óÈ-A4󙞑ӹzØêÚý,´H<^õ“n˸À¼Å—­Þ¸v}€±ä¸yOÜ=}ÁSMNä“I×TÆ«ï 0ø¼:I×”/:òÌòÅÏn; øf>¹ÅÔ£OxTƒõ=³Æv>­KTÿVÒs™ê^*5hâ¥ßl¨­Ø²%ÀXrÔ¼Çïš¾‹‰|’ÙõñxÕ-Á¥Ÿe²æ"G••ßlv½vh4óÈ5ÝÛmgÓݯŸèzK¶Ûh33óº÷èõ#™fdZ“J6{¦èÌÊÚØÖŠ ³È-³¹}ú§Ơ&¾&vÒ­±DõŒ>­Òœ»xõ¢É'•—<·$ì0ОhækÚ§™Ït³ÆÛ/dfí²?|µnݺíŠDÏ’´0Óšd}íO_ôÕØöÊÍŒ GÌøÇ o—>óÐ ’\ÆE¦¿ïœ ‚× Ùø¼‚‚/”—L|ÈøY.€N€f>¹Æïf>“t¥ÆÛ5>ï‹v‹ÅâÉ”î¤e™Ö¤ë÷{梯ÖmÛ²~CÙd13{í†1Ó–¿òôñj¢‘ÏI·ÄUWñ×àdEö++T´|Þ„-a€ ÐÌ ×øÙÌ—”t¡ÆÙ>î‰UWWo1=CÒÚLkRÉÆaÏŒ>³1±®|M€Ñd!/™l,¾üÜ™kçN?¹É…έŒW1‘€`UËtWÔi¿U¥“.^QRÌÏ÷t:4óÈ5~5óm—Ó9gø´B‹ÅÊå¢'J*Ï´ÆK&‡NüÙY]7.š·$¸d²Iª¾nû3Ÿ¾ ¶zÅñM.4»>KÜL* i¦“ΫoR¶¨ø šøtfya€*ôaJEt¶î³·}Ø Y ‹•÷íÛ÷TYjФaéÖ˜ç øï5?ê~ÒU7ÍÞ÷Ô³¾lBaª¯NÄŸ}Æš†šª¦þûU,^õ÷À‚Ð9™“æ™ÙóŠÚÊÞŸ¼4ì@-hækvkãûåŠê ÝkË|Iƒ¬‹ÅÊûôésšSjŠä†fXÖã­Û¯9bÛÖõ3ûnÑ Míç%“KÞ[f’í1☢ù]Ú!6€vV½aíºç~~v½×PhËR&»,«X0:“V¸øf4æy/­Yðüº°3@6¢™@®iK3_‰":S÷_vPñx|UŸ>}Nqò^•´_†eùsÿ}× KæN;ý†ûNp‘hô“ÅdC}Ý‚Çî~¯tÒ£˜—!I‘¼üÕßüÛý÷98Óž²Ðºù3K^»¾hO3op¦5&Õ9g?ˆÇªŸ 2I5NZ'ÙFIˤÈç´Xò–•LÚv>ÈÎÌÂÎÍW䶨5 }Noª@ßÒ]Vå(d›ž={ÌÏs/IÖuë?`ηîy~ÿ.=zz©dòý§œµð©û†y©äÏ®ÍëÒíƒ=3ç€öK ÀO 
'Œ{Þ£w#©©©š 9;'«žT.@î>âœ{%Ew¹0'¸me¥ÅW…¾agOËó…¹%â)åE¢U‘Æ”×Ð#µyÝœÉÛÃιŽf>¹c¬‹h½Ôòœ=¦˜.ÒkhXÈNýû÷ïí¥‹%ÚÔºh^~ùÁgýðÃ%/>¶W²¡qxSkÏÿ×››º÷8Ð× |ežç½ñÇŸM_3wúÉM.tÚ ygÆb5 ŠM¢™@î¸ØõST-zÇt³еâƒ]§äœëÒ·°÷£&û®Û5þê‰ÿ.‰õz¨I·ýø,Y»mÛÄËÎYT³e×v±ôSäŒx<^H0h†¼°@³EZt½nR¦+4ÞîÕøvK„,gfõιóûôé½Lf¿WšðÙgȬ R³O´øìM["}?XØí€î+ònr|.²@õ†µëŠ~ζdC]Ó|¦)Ó·ªªâ[ŠÍÂd>¹c´)§·›±²ZNè~{©Ý3!gô+,ü¾çì!'umá«©Cö2óþß_xtׂü8UnŽô]4¿ë=Vå>ŒÆ> e3^ž7í–«÷6óú7¹ÐôïxUõh3«(4Í|rÇwŽLÅ»XU.Ó75Þ’ 9¥°°ð؈³%5Ýð³ÓýfÝvå÷Ýká»ZkRåæhßEó»س,Ðáæ\¤Í4ɼTjê­WO/ŸñÊI’šú;ç9éºÊxÕ_ƒÊ-Å!¹dè.êïÊÓ9zÀ6’9Ç9;\Rá®ÖEœÛú«¹bÔé_þr³÷–ú LÅN={ôóáèdu´Çœ)]êµ6÷C|Ø€¤µ³§.|íÆË÷yw±t)rN<ŸH0hšùäŽ"÷¤¤ >óÔ$ýEãõñ iôïß¿·—J>+ÙWšZ礚Ë/8cá¾1òøöÈQÉ/™Ùõ°ºeÃŽ1ɵÇ@‡gf3î;sù+O«]O˜žÖ˜ôί©©aZ+€œ@3€ÜQäfJ:îOêd­ñöhX‘Ýz÷îÝ?ÑË’Žij]ׂü%ÿ;:oß!÷oïLu®Ë‚{ì¾)Òï€ö> èH¶WlÚ4ùW篩­ØÒäßgíhò¾=ž¨¾ÆÌ’Ad?ÐÌ w¹5’†ìü§õ2«ñ6;ÌHÈ^ô”lˆ¾*§&¯·=xØ éãÿpñ»äu *›¤Ôæhß™/öyH­ëêÇu¾@‡¶øùÇßyï¿hf»úûRc²ŸÆãÕO |D3€ÜPäò%ÕJŠÊi–LßÖ8Ûv,d§~ýú 5/9EÒ>M,K}ç+_œzÍOÎ:-¨\Ÿeæ*Þè~̪º ýbX€lV¿­ºê¿×þxI¬ìƒcwµÖ¤¥y)ûöÖêê%Ad¿ÑÌ 7\ê†ÉS™LOª§.ÒíVv$d§Ý{öÜ£1/ò–¤ŒWæ:§ª±Eç->óøÃ¾`´LlS´ïŒç{|tƒ¢ÝÃd‹²·_™?í–ß 4/9¨Y/8ý4«z¸c@»¡™@n㎑éd³¿…Ù«wïÞý£M•4"ÓšHÔm~øúщƒ‡ÎØì†F—¿è±Þ_Ý“kwÑÙ¥êj_½¡è½%sN’äšûžIuÎÙ×b±ê·Ú1´šù‚s® Oa¯—%šiM4âÖ?ñ—Ëö¼û°à’5_ʹÿÓãtÅóz ; †5ï¾¹`ê-W&‡·r‹„‹x'UVÖ¼ïk0Í|€œçœs……=srßÏ´¦ ?ZöÌ-WtÙs·>Í»²3$ž¹ OžV[)Üç“Ï7,x§tÏ#ŽË8qÈeµ[¶ü÷?]žX³jd3–›šžØ·:/?5rË–më}Šˆ„€¶ê[Øëš¦ù¢‘ÈÚgoýeÖ7òIRÄÙž£ªß,èîÕnýèÙ»ãošöÁkÏV„™ hæy6÷Ñ»Þ}êÂS šÓÈFÖÝô‹ó|ñàáM]¥»w²1:i¯½\7£@»£™Óúöíu²IÌTwÎU>v㥠{ô/ÌúF¾8ó†\PýÚú¨¼úusg”,~þÑãz Êçlt(+J–?ñƒãKߟ0îX™îjý^{ôçÕ{Ûã+_qä½×^xâð!»¿ÝÄòcjjz>àc\hw4rV¯^½HîIIy–l»ïwnÜoèÀ}2Ô³VWk<ì˜ef¼6öÒ=%ôÚ}h×°3~¨ßVxõº‹§M¾ò‚a 5U‡îj}$Ýpã˜ófM¼í—ÇöèÞGÚqµöã7^úÅ>=»¿ŸñEs?èÓ§÷Õ>F€vE3 gE£‘»dÚ3C¹áÏ?µôè/ ;8ÐP>ihl¬½êº›™y»IÒ Ãv& -¼dcröC·¾ýä÷ŽK­[øÎÉ’òwõÎ{í1ãõû®é~æ ‡ù³µ‚ü¼‚gÿvÅЂ‚¼•™ÞwÒ_ûžÑÆègfag Åúöí}ŽLÅ™ê?nä“Ô³ßÀÕaæZkó WýçÇ'/˜þ÷ßÞÌF¾ú“>è©ã¿Gsù$iï=wrË/¿9’9 IDAT·NRc†%ý“‘9ç\³ƒ@˜ÌÈ9}úôþ­“nJW‹F#ëÞ¸ïÚÂݺô :—¦ÎY2ÿê;Ÿô¼‹Žætó)Ö†…ï.uÿŸªâkW¥fNⓤžÝ ÿé²ó“'qÀaíO| xÚäisON[tÚŸßõ€Í›7×´gh)šù9£oß^ÉÜÒ”’o»rÓ}ªÞ˜]:ÿš»&ª4D5Å9·uÏ#G.ù³ë¿Ðkààt 
Vñ’IoéË,xüînõ5UµäÝ‚‚¼UW~ÿk›Ï;íØcs®½2~Ä3óNó×EÕÛjMW7¹?Çã‰ëÚ;´Í|€œ°óŠÝ-’ú~¶¶{¿^s^¼óêcBˆÕ&n¬\wÞoîìffý>[3©ÎÌuvžIc$õjáö©^‡Ì>êû?Ï~Ê×t‘hÔŸÔèl¶UWÍ~ø¶ù+^Ÿ¸Ÿ—Jµ¨a6‰¬ûé¹§,¿øÜ“OˆF"¾4¬6×üe«—]rãƒûKŠ|¶fR]$’w`eeå‡Af€¦ÐÌÈ }úô9ÆÉ›®våΘùý3FŽ :S[Ô54n?ýÒ›×Ö54véöX¢êW’TXXØ/"ïr9÷ IŸküÛ¥hdÓàÃŽ[zÔ¯¼Û~‡ìצàè46”¼·hî#wVlù`áQ2ëÙÂ×ë/øêqï^ñý¯}9?/ZÐ.›á‡×Ý;ciù†Ò?ñw ²Í|€œÐ·oá2»#M)õÆ}×VõîÙísû²•™ÙwûwVoØ’©±2å逪ªªŠO>Ü}÷Ý{66ÖÉì’Úš³óºtY¼ßÿœ»áÈï_vX×>ý´ft\Ûã•ñ’§Ç/ùàå§$êÚÒøiÅ·_µnÈî}C½êyk¼zë™—ßÒÅ,ídËD$š?´¢¢¢*ð`Æç® ™éÀtÏ»w-ø —ù$éOM~«‰F>™ô›Ï6òIÒæÍ›kb±ª¿Å5ûšì|“fµôìd}ýÁKÿûŸÓžüÑI}&üô+ï½ÿô¸õUñXK÷AÇáýÿöî<¸êòÞãøçù“BΖ@A@DÜPP ¶ÕÚÖÚzk{mÝÚ[[»ÙÚÚö¶ÞzÕ;c«V«½ÚQÜÛ^[EÜ­VpaS©‚È¢\ÀÙC’“s~ßû‡Ö±á„%Ë9I}¿f2Î÷9¿ç“É@2Ãgž§Ï¯Ÿûà‹³¿ñé%÷œ5£zõÃ<¶‹E>Ir·=ôÌÆn Ø£¡3§¶´ƒqÄ÷ó_+i Ø Næô ±hdžd'´}ìð!‹þïŠï¿F³š»dÕË¿¸áÞI’,y6™Ê|Òöòt]$2Õœ.4Ù$Ut2V[ÿhÝʃN85{ð)g¯®«ÔÉç 0¿PØò÷gV®|à®ô¶×–WÁïì÷üEI#$ n?è_U±vá­—ŒïRÐnjÚ™<á[¿ª4 Èxi"™ž\êLPL°ÜØ;6ªØ«£œ/u’ÎZ¶vóš‹oœ=Jùr¯ím‘O’v¤R/Húò šš!ù ÷u“Γ4r£U4'w¹bÎíZ1çv¿²&²âÀãg&>ù«£¢ÃGuê:_ô>~>ß¶éùy+VßçÎí_`æÞÉGµÊt/wc*•ú{m4|•I?i¿¨¹¥müÊ ë3bl£wId@uôˆñ,xyíæ‰DƤR© %íp2 OˆEÃoJÖþõ3g·àg|¦XI§WY¸|í+?¾öîÍê`‰Éé´D"ý`WöqÎy±pøÓ&;_N3ÕÅâ¾ 6Ô™¸iì §ö9cæ„Ê5åG/”kʤ7-||ÍšÇþҖؼî3ë•ÔÖàänj+Ø­™L¦ñ¯Æb5‡É¼åÅÞ1qÌð…w^úÍß³{¬Üаîk—ÞR´T褋ãÉô¥Îíq2 opòT¤ ]éÃ웿.|åÅÿž5g¢¤þ­qÒUñ.ù$ÉÌ|IKz|PMͶ ï+2#iRgžççó#×.Ѹv¹ßxY¾²&²bø‘Óã}æ‹õƒ'9Î ù]¢ñ ùüÛËŸ_½þÉûão¯xn`.›/iJi’ž‘³“ÉìCfVh¿ ‘Ⱦ‹†WIšØ~öêÆ-‡ïlÉ5UWU»â¶d3bleepc.—Ý~fÒ‰’(ó(;þЧ5µ´úåΰ;·=4ÑM÷=5EREGkœÜìD*ý‹îÞ{[6»UÒµ’®Åj—ïΑsgHÜÉGsÙÔ¤ ÓÆIrMý£µ+Oœœ9í¤ðþGM[Ñ¿¦¦Û¾ì‘_Èç·¯_µáOlkxa^UvûÖ 2ëTq³m2û³çkÖŽLfõW››%g¿ÛåeSèïyrþÏÎùìñÝ©Kf6nËS/¾ºK™OÒTç\¥™åJ >„kv}B,Z*¹#Û¿~Ì!£çßøŸç”½(Ô^.×Öò­+îX¾jã›Sw»Ð´0™Îœdf-¥Èåœ D"5Ç{r_2é4IõÝøøB°_ÿ µc~w䔃Ã&O6j„ó<¯÷øHËlÝòfÃ’§·¼õâ‚\ãÆÕ±\65FrÕÝôøVI:Ow%™'̬moßXWWö moIÚ¥Ìðo/¾ã—ƒ@e7åì”K×,ÿÑoï>¼ØÌó5mG:ý\©3À‡Qæô ±hxŽÞ+Ÿý“Aµá»þ¢£Ë©CÞ}ýÜKgZÛÚÚíBÓŠ~U3·mÛ–-Q´ò^±oÀ'¾ùµê;¶2óvóVÉž6¹{ÍÜý©T*ÞÙE£‘ß;Ù·‹Í¾õÅ÷ùOï|Ì®k͵µLûúåNR¿ö3“ŸLfn.C,øe>@ŸP _mÒEí_wN™ç︴ à•ýêx3³Y>³øÖûŸ=BÒ€=,º¢²êÔrùÚsÎyápøhçìsN:EÒ¡=¸]>PÙoSuÝ í‘ýlxÐÄ`ýØCõŽZ«ïî²Z¯e¾ïïÜñî¶DÆmÛ^[žÚ¾v¥oX? 
5âòÃ$¹Ú:/i“Ín+hN&“i쎇†ÃáqO¯J ´Ÿ^ÃÂÛ.\ ìR¤+¥é_ÿŸu-¹ÜØ"£kÉô.ÿ¾@)Qæô ±XøT™,6ûÍ…g,?þȃ‹^ŸY*k6½µá»¿þÃÎLSó¤=­5¹¿¤RésͬµÙ:#Žò¬0ÓäN’Ó'$…J²±s©Êªê·úœ ï72Úo/6bL¿ðÐB¡!Ãêªk t^`—²Xo”omnjÚÑϾÓL4¬Ë$ÞXŸKmy#Ýþö€ÖlºÖÏ· •Tª«g“&÷„gzØwî‰d2™è‰Mj£‘{Lö¥b³?éØù=³¬WbŸvÑuÏ7lÝqì.ÓÉTúÔ2D€Pæô µµµóóÛ%írßÈ¡uÏ;òÇ•!–²Í­Í¿¼ñÞå‹^Y?YRÅ–›I¿J¥2ÿe}è°s®" M5s'9³O™t”Š|J$ï¼àö`UÕö~B;+„rýB‘BU´Îª¢Ý€èÀ`U¬¾²ºv`Uÿh]u ª¥œ§ªP$$IÁªêê@EÅ^èò­ÍMmÙìÎ\KSKs2žÍïLçr™L[KSº­9Ù˜Ïe3~s¢Ñ²ïZS;úåš3Õùæ–ˆ_È×jϧ3ö¤‚¤e&ÍuÎæ&“ÙÅfÖÖÓ›Æb5‡É¼—UäTAçÜŽgg]\9 ¿ÒC‹øÞ•wÍaå†] …NÏ%éieˆ Ìè3b‘ð9Í(2ÊÝûë Þ5là¥ÊRðýÂ].záæ9Oö}È^¼%#gç&™û{<\«¯¯år¹ãœÓ gö1s:ÚIUåεòι®8öͪúà×ða¾¤Urn¡™?ßÌ{*•JÅË$ Ï‘tZ±ÙäƒGοéâÿ(Ûé|ßpϳ{aÕÇw¸å‰dꈒ€)ש:ì3'7ËdÅÊ|•ß»ú®m]÷ã/ó|¿ð‡G¿pëOËå {{’×K__M§3ëz4\‰466f$ýíý9çªb±Ðd3MqæŽ6Ù1’F•5äžÍ,ú?ìrŒ\ï—ì%“{Ñ3·XÀâx<ž*w(I*øúyÀÓ)*rRåKk6M_¹aËk‡Ž>® ÑT]]Uô[ídýKÚãd>@Ÿáœ«ŒDB›Tô$¼Ó?5eþOÏ>¹GNý*ø¾÷ãÏ¿tÓ}óåò…‘{ù¶¼“®I¤2—˜Y®'rõV¡P¨¾ÂóŽ6ùGɹI’“4Z’Wæh}MAÒ2­vNk|Ù2)°4™Ln,w°Ý‰Å"7Èì»Åf5ýû½úÔÍ¿ð¼@©su|2ŸV&’éI¥ÎÆÉ|€>ÃÌrÑhøZIW›Ïž»dúÐÚèsg~vÚqݵçΖ\Ó-<½ôÞ'—ŒÈå ÇìuVéyçüïÄÙåÝ•¥/Éd2’þúþ‡$iÈ!š››qÎsf̹qN'éI%/võ&&µ8é5'·ÖÌ_cNk<ÏÖ&M¯™Yk¹óí«|Þ¿,pgHªm?Ë6·rù¬ûŸ½ìü/~¼Ô¹RM͘(i(‚“ù}Šs®" -“4±ƒ%6íˆqó¯úþ—¦TVTtúêÌ•^»áÞy[_^»ùp3‹ìÃ[w˜ôóT*s›™ùÝÿ£Ä9ׯ¶¶fŒïkœäFI6Rr#e©÷Š~¡ræë&ÍNj0¹-2{SžÛlf[$o‹¤©Tj“™Ê²;E£áóœ4«ƒqîöKÎÛ0iìˆ ¥Ìô¹ ¯YòvcrÊ.ÓC‰Túó¥ÌíQæô9µáð4ó´@»¹²Õó¼­gΜ¶þÜSfª® ïé™ffË×5¬=ï…w.[?´%—»±LÒ_?K§ÓÛ÷ñ½ØH$R úû ¶Ÿ³À~òl| “gƒÌ\'Åì½àbzïs)®òm5)áœâ’òßÿ,‹KJ8yqó¬A*lÉç½7ß?©ð#Å9çE£¡2M+6¯¬ ¾>ï÷?\]U9 T™¦ž{éæ|¾p@û×tu<™þi©r@1”ù}R,ù¡Ì~»KÛb¡šU‡‘³}`Äàúþ•¯1™nmjÎV¬k(¬kx§&žj:À7؉(&éçù—ÅãÙx?ºYmmmÄóZjòùÊ~’¢ßïWðüÎy5fVñužyAßYH’œS«™í”$Ïw;ç·HRÀÙ‚sm’dfé`°­¥¢"ߺukS¾´>§.šP¸¥Nª*6?`Hýss~sA·]‹½;Û©wg~ÿ7ƒ$írÕ®3wV<•úS)r@G(óú¬X,t½Ì]P¦íMÒCrþe‰Dvy™2½^,ùÌ®ëh~ú§¦ÌÿÙ9Ÿ=¾§sÜpÏÜ…w>²`F±™°‰;vd^íé °;¥¸†€‘Lf/4é ½W¬+“éa¹ÂQ‰dú ù€ÝK&Ó¿“ܼ޿³ç.™þÈü¥Kz:Çœ§–ÔŸXC<ž]ÝÓûÀžPæôYfæ'“é‹ÙW$e{x·NºÙ+ØÄD*}j"ÑôrÏîük03 VäÏ1ikK—Ýòà!ËÖnZÕSæ.yuYfgë„¢Cç=bU   Ìèóâ©Ì_‚…q2Ý)ÉïæÇ/•¹ |ó†Ç“éówd2œàì£ÆÆ¦·³/KÊw°¤æü_Ý>hÝæ­¯w÷ÞmùBîò›ï¯éhîìîîÞ:Ãq à_ImmÍ$3÷ ™;CÒÀN<—ô’œ{Ü9ÿ¾x<Óc§…5±XøG2]ÓÑÜ ¸mºìüÌØ‘CGwמ?¼úó½²þø¢CÓ¢D*=£»ö€® Ìø—䜫 
‡ÃŸðäO‘玖éIu’ÂZ–•ô®¤•rn•™ÿŠï»gÓéôö²„>bÑð-’¾ÑÑÜsnû]—3yð¨ýÇtu¯[|vÑÍ÷=5M’+6w¾N‰§Óvuè”ù)ι`$ ¥R©”™u÷•¼öཿƒ¡Gœô™Ž×(óÓsN~õô§Níì>·?4ñïgÏ›*)ÐÁ’§“©Ì‰Æ/Ãz Ê|(©úúúP¾­õiI“w³ÌŽ9dô‚ërÖ±Á@åÞ>»-_È}÷Ê;Ÿ_¶fSñ«uß“5y“’Éä{ze>”\49æIîÈÝ­ ¼-ß9ý¤†3O>îXÏ9¯£uff÷Î]²äú»Ÿ’ËFîî™Nöíx2sS'£@ Ì€²ˆD"µžÓS’¾§µ•7Ž™p`ÿxLôÑÇW÷«ì—Ì6eVm|óǽ’ZüʺÑù‚?|›š®M¤Ò?î–/ºe>”M4y*üÙä>ÝÓ{™ìÏ©TöL3ó{z/ØW^Uô´d2™H¤²';éÊžÜÇI³R©ìÙùôVœÌ€^!‰œ%g7J u×3MjñdÆ“™›ºë™Ð8™½B"•ú£ó‚ez¤;žçdOš¹C)òè 8™½Nmmx†ùî¿%û¤$·oFN×%é‡{"ôÊ|èµ"‘ÈÏÓY’(ÓdI•E–åå´DrÏ8W˜gW”:'te>ô Ç»þ;“5£ü J.êœâù¼#‘LÖ-Ö\î|Дù(3¯Üø¨£Ì@™Qæ Ì(óPf”ù(3Ê|”e>ÊŒ2eF™€2£Ì@™Qæ Ì(óPf”ù(3Ê|”e>ÊŒ2eF™€2£Ì@™Qæ Ì(óPf”ù(3Ê|”Ùÿ¥C‡VðëŠIEND®B`‚quic-go-0.59.0/buffer_pool.go000066400000000000000000000043241513066070600160110ustar00rootroot00000000000000package quic import ( "sync" "github.com/quic-go/quic-go/internal/protocol" ) type packetBuffer struct { Data []byte // refCount counts how many packets Data is used in. // It doesn't support concurrent use. // It is > 1 when used for coalesced packet. refCount int } // Split increases the refCount. // It must be called when a packet buffer is used for more than one packet, // e.g. when splitting coalesced packets. func (b *packetBuffer) Split() { b.refCount++ } // Decrement decrements the reference counter. // It doesn't put the buffer back into the pool. func (b *packetBuffer) Decrement() { b.refCount-- if b.refCount < 0 { panic("negative packetBuffer refCount") } } // MaybeRelease puts the packet buffer back into the pool, // if the reference counter already reached 0. func (b *packetBuffer) MaybeRelease() { // only put the packetBuffer back if it's not used any more if b.refCount == 0 { b.putBack() } } // Release puts back the packet buffer into the pool. // It should be called when processing is definitely finished. 
func (b *packetBuffer) Release() { b.Decrement() if b.refCount != 0 { panic("packetBuffer refCount not zero") } b.putBack() } // Len returns the length of Data func (b *packetBuffer) Len() protocol.ByteCount { return protocol.ByteCount(len(b.Data)) } func (b *packetBuffer) Cap() protocol.ByteCount { return protocol.ByteCount(cap(b.Data)) } func (b *packetBuffer) putBack() { if cap(b.Data) == protocol.MaxPacketBufferSize { bufferPool.Put(b) return } if cap(b.Data) == protocol.MaxLargePacketBufferSize { largeBufferPool.Put(b) return } panic("putPacketBuffer called with packet of wrong size!") } var bufferPool, largeBufferPool sync.Pool func getPacketBuffer() *packetBuffer { buf := bufferPool.Get().(*packetBuffer) buf.refCount = 1 buf.Data = buf.Data[:0] return buf } func getLargePacketBuffer() *packetBuffer { buf := largeBufferPool.Get().(*packetBuffer) buf.refCount = 1 buf.Data = buf.Data[:0] return buf } func init() { bufferPool.New = func() any { return &packetBuffer{Data: make([]byte, 0, protocol.MaxPacketBufferSize)} } largeBufferPool.New = func() any { return &packetBuffer{Data: make([]byte, 0, protocol.MaxLargePacketBufferSize)} } } quic-go-0.59.0/buffer_pool_test.go000066400000000000000000000021111513066070600170400ustar00rootroot00000000000000package quic import ( "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestBufferPoolSizes(t *testing.T) { buf1 := getPacketBuffer() require.Equal(t, protocol.MaxPacketBufferSize, cap(buf1.Data)) require.Zero(t, buf1.Len()) buf1.Data = append(buf1.Data, []byte("foobar")...) 
require.Equal(t, protocol.ByteCount(6), buf1.Len()) buf2 := getLargePacketBuffer() require.Equal(t, protocol.MaxLargePacketBufferSize, cap(buf2.Data)) require.Zero(t, buf2.Len()) } func TestBufferPoolRelease(t *testing.T) { buf1 := getPacketBuffer() buf1.Release() // panics if released twice require.Panics(t, func() { buf1.Release() }) // panics if wrong-sized buffers are passed buf2 := getLargePacketBuffer() buf2.Data = make([]byte, 10) // replace the underlying slice require.Panics(t, func() { buf2.Release() }) } func TestBufferPoolSplitting(t *testing.T) { buf := getPacketBuffer() buf.Split() buf.Split() // now we have 3 parts buf.Decrement() buf.Decrement() buf.Decrement() require.Panics(t, func() { buf.Decrement() }) } quic-go-0.59.0/client.go000066400000000000000000000065661513066070600147770ustar00rootroot00000000000000package quic import ( "context" "crypto/tls" "errors" "net" "github.com/quic-go/quic-go/internal/protocol" ) // make it possible to mock connection ID for initial generation in the tests var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial // DialAddr establishes a new QUIC connection to a server. // It resolves the address, and then creates a new UDP connection to dial the QUIC server. // When the QUIC connection is closed, this UDP connection is closed. // See [Dial] for more details. func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (*Conn, error) { udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { return nil, err } udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err } tr, err := setupTransport(udpConn, tlsConf, true) if err != nil { return nil, err } conn, err := tr.dial(ctx, udpAddr, addr, tlsConf, conf, false) if err != nil { tr.Close() return nil, err } return conn, nil } // DialAddrEarly establishes a new 0-RTT QUIC connection to a server. // See [DialAddr] for more details. 
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (*Conn, error) { udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { return nil, err } udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err } tr, err := setupTransport(udpConn, tlsConf, true) if err != nil { return nil, err } conn, err := tr.dial(ctx, udpAddr, addr, tlsConf, conf, true) if err != nil { tr.Close() return nil, err } return conn, nil } // DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn. // See [Dial] for more details. func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (*Conn, error) { dl, err := setupTransport(c, tlsConf, false) if err != nil { return nil, err } conn, err := dl.DialEarly(ctx, addr, tlsConf, conf) if err != nil { dl.Close() return nil, err } return conn, nil } // Dial establishes a new QUIC connection to a server using a net.PacketConn. // If the PacketConn satisfies the [OOBCapablePacketConn] interface (as a [net.UDPConn] does), // ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP // will be used instead of ReadFrom and WriteTo to read/write packets. // The [tls.Config] must define an application protocol (using tls.Config.NextProtos). // // This is a convenience function. More advanced use cases should instantiate a [Transport], // which offers configuration options for a more fine-grained control of the connection establishment, // including reusing the underlying UDP socket for multiple QUIC connections. 
func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (*Conn, error) { dl, err := setupTransport(c, tlsConf, false) if err != nil { return nil, err } conn, err := dl.Dial(ctx, addr, tlsConf, conf) if err != nil { dl.Close() return nil, err } return conn, nil } func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn bool) (*Transport, error) { if tlsConf == nil { return nil, errors.New("quic: tls.Config not set") } return &Transport{ Conn: c, createdConn: createdPacketConn, isSingleUse: true, }, nil } quic-go-0.59.0/client_test.go000066400000000000000000000047671513066070600160370ustar00rootroot00000000000000package quic import ( "context" "crypto/tls" "net" "runtime" "testing" "time" "github.com/stretchr/testify/require" ) func TestDial(t *testing.T) { t.Run("Dial", func(t *testing.T) { testDial(t, func(ctx context.Context, addr net.Addr) error { conn := newUDPConnLocalhost(t) _, err := Dial(ctx, conn, addr, &tls.Config{}, nil) return err }, false, ) }) t.Run("DialEarly", func(t *testing.T) { testDial(t, func(ctx context.Context, addr net.Addr) error { conn := newUDPConnLocalhost(t) _, err := DialEarly(ctx, conn, addr, &tls.Config{}, nil) return err }, false, ) }) t.Run("DialAddr", func(t *testing.T) { testDial(t, func(ctx context.Context, addr net.Addr) error { _, err := DialAddr(ctx, addr.String(), &tls.Config{}, nil) return err }, true, ) }) t.Run("DialAddrEarly", func(t *testing.T) { testDial(t, func(ctx context.Context, addr net.Addr) error { _, err := DialAddrEarly(ctx, addr.String(), &tls.Config{}, nil) return err }, true, ) }) } func testDial(t *testing.T, dialFn func(context.Context, net.Addr) error, shouldCloseConn bool, ) { server := newUDPConnLocalhost(t) ctx, cancel := context.WithCancel(context.Background()) errChan := make(chan error, 1) go func() { errChan <- dialFn(ctx, server.LocalAddr()) }() server.SetReadDeadline(time.Now().Add(time.Second)) _, addr, err := 
server.ReadFrom(make([]byte, 1500)) require.NoError(t, err) cancel() select { case err := <-errChan: require.ErrorIs(t, err, context.Canceled) case <-time.After(time.Second): t.Fatal("timeout") } if shouldCloseConn { // The socket that the client used for dialing should be closed now. // Binding to the same address would error if the address was still in use. require.Eventually(t, func() bool { conn, err := net.ListenUDP("udp", addr.(*net.UDPAddr)) if err != nil { return false } conn.Close() return true }, scaleDuration(200*time.Millisecond), scaleDuration(10*time.Millisecond)) require.False(t, areTransportsRunning()) return } // The socket that the client used for dialing should not be closed now. // Binding to the same address will error if the address was still in use. _, err = net.ListenUDP("udp", addr.(*net.UDPAddr)) require.Error(t, err) if runtime.GOOS == "windows" { require.ErrorContains(t, err, "bind: Only one usage of each socket address") } else { require.ErrorContains(t, err, "address already in use") } require.False(t, areTransportsRunning()) } quic-go-0.59.0/closed_conn.go000066400000000000000000000034301513066070600157720ustar00rootroot00000000000000package quic import ( "math/bits" "net" "sync/atomic" "github.com/quic-go/quic-go/internal/utils" ) // A closedLocalConn is a connection that we closed locally. // When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame, // with an exponential backoff. type closedLocalConn struct { counter atomic.Uint32 logger utils.Logger sendPacket func(net.Addr, packetInfo) } var _ packetHandler = &closedLocalConn{} // newClosedLocalConn creates a new closedLocalConn and runs it. 
func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), logger utils.Logger) packetHandler { return &closedLocalConn{ sendPacket: sendPacket, logger: logger, } } func (c *closedLocalConn) handlePacket(p receivedPacket) { n := c.counter.Add(1) // exponential backoff // only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving if bits.OnesCount32(n) != 1 { return } c.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", n) c.sendPacket(p.remoteAddr, p.info) } func (c *closedLocalConn) destroy(error) {} func (c *closedLocalConn) closeWithTransportError(TransportErrorCode) {} // A closedRemoteConn is a connection that was closed remotely. // For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE. // We can just ignore those packets. type closedRemoteConn struct{} var _ packetHandler = &closedRemoteConn{} func newClosedRemoteConn() packetHandler { return &closedRemoteConn{} } func (c *closedRemoteConn) handlePacket(receivedPacket) {} func (c *closedRemoteConn) destroy(error) {} func (c *closedRemoteConn) closeWithTransportError(TransportErrorCode) {} quic-go-0.59.0/closed_conn_test.go000066400000000000000000000015401513066070600170310ustar00rootroot00000000000000package quic import ( "net" "testing" "github.com/quic-go/quic-go/internal/utils" "github.com/stretchr/testify/require" ) func TestClosedLocalConnection(t *testing.T) { written := make(chan net.Addr, 1) conn := newClosedLocalConn(func(addr net.Addr, _ packetInfo) { written <- addr }, utils.DefaultLogger) addr := &net.UDPAddr{IP: net.IPv4(127, 1, 2, 3), Port: 1337} for i := 1; i <= 20; i++ { conn.handlePacket(receivedPacket{remoteAddr: addr}) if i == 1 || i == 2 || i == 4 || i == 8 || i == 16 { select { case gotAddr := <-written: require.Equal(t, addr, gotAddr) // receive the CONNECTION_CLOSE default: t.Fatal("expected to receive address") } } else { select { case gotAddr := <-written: t.Fatalf("unexpected 
address received: %v", gotAddr) default: // Nothing received, which is expected } } } } quic-go-0.59.0/codecov.yml000066400000000000000000000005701513066070600153240ustar00rootroot00000000000000coverage: round: nearest ignore: - http3/gzip_reader.go - example/ - interop/ - internal/handshake/cipher_suite.go - internal/mocks/ - internal/utils/linkedlist/linkedlist.go - internal/testdata - internal/synctest - testutils/ - fuzzing/ - metrics/ status: project: default: threshold: 0.5 patch: false quic-go-0.59.0/config.go000066400000000000000000000104721513066070600147550ustar00rootroot00000000000000package quic import ( "fmt" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" ) // Clone clones a Config. func (c *Config) Clone() *Config { copy := *c return © } func (c *Config) handshakeTimeout() time.Duration { return 2 * c.HandshakeIdleTimeout } func (c *Config) maxRetryTokenAge() time.Duration { return c.handshakeTimeout() } func validateConfig(config *Config) error { if config == nil { return nil } const maxStreams = 1 << 60 if config.MaxIncomingStreams > maxStreams { config.MaxIncomingStreams = maxStreams } if config.MaxIncomingUniStreams > maxStreams { config.MaxIncomingUniStreams = maxStreams } if config.MaxStreamReceiveWindow > quicvarint.Max { config.MaxStreamReceiveWindow = quicvarint.Max } if config.MaxConnectionReceiveWindow > quicvarint.Max { config.MaxConnectionReceiveWindow = quicvarint.Max } if config.InitialPacketSize > 0 && config.InitialPacketSize < protocol.MinInitialPacketSize { config.InitialPacketSize = protocol.MinInitialPacketSize } if config.InitialPacketSize > protocol.MaxPacketBufferSize { config.InitialPacketSize = protocol.MaxPacketBufferSize } // check that all QUIC versions are actually supported for _, v := range config.Versions { if !protocol.IsValidVersion(v) { return fmt.Errorf("invalid QUIC version: %s", v) } } return nil } // populateConfig populates fields in the quic.Config with their 
default values, if none are set // it may be called with nil func populateConfig(config *Config) *Config { if config == nil { config = &Config{} } versions := config.Versions if len(versions) == 0 { versions = protocol.SupportedVersions } handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout if config.HandshakeIdleTimeout != 0 { handshakeIdleTimeout = config.HandshakeIdleTimeout } idleTimeout := protocol.DefaultIdleTimeout if config.MaxIdleTimeout != 0 { idleTimeout = config.MaxIdleTimeout } initialStreamReceiveWindow := config.InitialStreamReceiveWindow if initialStreamReceiveWindow == 0 { initialStreamReceiveWindow = protocol.DefaultInitialMaxStreamData } maxStreamReceiveWindow := config.MaxStreamReceiveWindow if maxStreamReceiveWindow == 0 { maxStreamReceiveWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow } initialConnectionReceiveWindow := config.InitialConnectionReceiveWindow if initialConnectionReceiveWindow == 0 { initialConnectionReceiveWindow = protocol.DefaultInitialMaxData } maxConnectionReceiveWindow := config.MaxConnectionReceiveWindow if maxConnectionReceiveWindow == 0 { maxConnectionReceiveWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow } maxIncomingStreams := config.MaxIncomingStreams if maxIncomingStreams == 0 { maxIncomingStreams = protocol.DefaultMaxIncomingStreams } else if maxIncomingStreams < 0 { maxIncomingStreams = 0 } maxIncomingUniStreams := config.MaxIncomingUniStreams if maxIncomingUniStreams == 0 { maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams } else if maxIncomingUniStreams < 0 { maxIncomingUniStreams = 0 } initialPacketSize := config.InitialPacketSize if initialPacketSize == 0 { initialPacketSize = protocol.InitialPacketSize } return &Config{ GetConfigForClient: config.GetConfigForClient, Versions: versions, HandshakeIdleTimeout: handshakeIdleTimeout, MaxIdleTimeout: idleTimeout, KeepAlivePeriod: config.KeepAlivePeriod, InitialStreamReceiveWindow: initialStreamReceiveWindow, 
MaxStreamReceiveWindow: maxStreamReceiveWindow, InitialConnectionReceiveWindow: initialConnectionReceiveWindow, MaxConnectionReceiveWindow: maxConnectionReceiveWindow, AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, MaxIncomingStreams: maxIncomingStreams, MaxIncomingUniStreams: maxIncomingUniStreams, TokenStore: config.TokenStore, EnableDatagrams: config.EnableDatagrams, InitialPacketSize: initialPacketSize, DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, EnableStreamResetPartialDelivery: config.EnableStreamResetPartialDelivery, Allow0RTT: config.Allow0RTT, Tracer: config.Tracer, } } quic-go-0.59.0/config_test.go000066400000000000000000000147661513066070600160260ustar00rootroot00000000000000package quic import ( "context" "reflect" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/quic-go/quicvarint" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestConfigValidation(t *testing.T) { t.Run("nil config", func(t *testing.T) { require.NoError(t, validateConfig(nil)) }) t.Run("config with a few values set", func(t *testing.T) { conf := populateConfig(&Config{ MaxIncomingStreams: 5, MaxStreamReceiveWindow: 10, }) require.NoError(t, validateConfig(conf)) require.Equal(t, int64(5), conf.MaxIncomingStreams) require.Equal(t, uint64(10), conf.MaxStreamReceiveWindow) }) t.Run("stream limits", func(t *testing.T) { conf := &Config{ MaxIncomingStreams: 1<<60 + 1, MaxIncomingUniStreams: 1<<60 + 2, } require.NoError(t, validateConfig(conf)) require.Equal(t, int64(1<<60), conf.MaxIncomingStreams) require.Equal(t, int64(1<<60), conf.MaxIncomingUniStreams) }) t.Run("flow control windows", func(t *testing.T) { conf := &Config{ MaxStreamReceiveWindow: quicvarint.Max + 1, MaxConnectionReceiveWindow: quicvarint.Max + 2, } require.NoError(t, validateConfig(conf)) require.Equal(t, uint64(quicvarint.Max), conf.MaxStreamReceiveWindow) require.Equal(t, 
uint64(quicvarint.Max), conf.MaxConnectionReceiveWindow) }) t.Run("initial packet size", func(t *testing.T) { // not set conf := &Config{InitialPacketSize: 0} require.NoError(t, validateConfig(conf)) require.Zero(t, conf.InitialPacketSize) // too small conf = &Config{InitialPacketSize: 10} require.NoError(t, validateConfig(conf)) require.Equal(t, uint16(1200), conf.InitialPacketSize) // too large conf = &Config{InitialPacketSize: protocol.MaxPacketBufferSize + 1} require.NoError(t, validateConfig(conf)) require.Equal(t, uint16(protocol.MaxPacketBufferSize), conf.InitialPacketSize) }) } func TestConfigHandshakeIdleTimeout(t *testing.T) { c := &Config{HandshakeIdleTimeout: time.Second * 11 / 2} require.Equal(t, 11*time.Second, c.handshakeTimeout()) } func configWithNonZeroNonFunctionFields(t *testing.T) *Config { t.Helper() c := &Config{} v := reflect.ValueOf(c).Elem() typ := v.Type() for i := 0; i < typ.NumField(); i++ { f := v.Field(i) if !f.CanSet() { // unexported field; not cloned. continue } switch fn := typ.Field(i).Name; fn { case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Tracer": // Can't compare functions. 
case "Versions": f.Set(reflect.ValueOf([]Version{1, 2, 3})) case "ConnectionIDLength": f.Set(reflect.ValueOf(8)) case "ConnectionIDGenerator": f.Set(reflect.ValueOf(&protocol.DefaultConnectionIDGenerator{ConnLen: protocol.DefaultConnectionIDLength})) case "HandshakeIdleTimeout": f.Set(reflect.ValueOf(time.Second)) case "MaxIdleTimeout": f.Set(reflect.ValueOf(time.Hour)) case "TokenStore": f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3))) case "InitialStreamReceiveWindow": f.Set(reflect.ValueOf(uint64(1234))) case "MaxStreamReceiveWindow": f.Set(reflect.ValueOf(uint64(9))) case "InitialConnectionReceiveWindow": f.Set(reflect.ValueOf(uint64(4321))) case "MaxConnectionReceiveWindow": f.Set(reflect.ValueOf(uint64(10))) case "MaxIncomingStreams": f.Set(reflect.ValueOf(int64(11))) case "MaxIncomingUniStreams": f.Set(reflect.ValueOf(int64(12))) case "StatelessResetKey": f.Set(reflect.ValueOf(&StatelessResetKey{1, 2, 3, 4})) case "KeepAlivePeriod": f.Set(reflect.ValueOf(time.Second)) case "EnableDatagrams": f.Set(reflect.ValueOf(true)) case "DisableVersionNegotiationPackets": f.Set(reflect.ValueOf(true)) case "InitialPacketSize": f.Set(reflect.ValueOf(uint16(1350))) case "DisablePathMTUDiscovery": f.Set(reflect.ValueOf(true)) case "Allow0RTT": f.Set(reflect.ValueOf(true)) case "EnableStreamResetPartialDelivery": f.Set(reflect.ValueOf(true)) default: t.Fatalf("all fields must be accounted for, but saw unknown field %q", fn) } } return c } func TestConfigClone(t *testing.T) { t.Run("function fields", func(t *testing.T) { var calledAllowConnectionWindowIncrease, calledTracer bool c1 := &Config{ GetConfigForClient: func(info *ClientInfo) (*Config, error) { return nil, assert.AnError }, AllowConnectionWindowIncrease: func(*Conn, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, Tracer: func(context.Context, bool, ConnectionID) qlogwriter.Trace { calledTracer = true return nil }, } c2 := c1.Clone() c2.AllowConnectionWindowIncrease(nil, 1234) require.True(t, 
calledAllowConnectionWindowIncrease) _, err := c2.GetConfigForClient(&ClientInfo{}) require.ErrorIs(t, err, assert.AnError) c2.Tracer(context.Background(), true, protocol.ConnectionID{}) require.True(t, calledTracer) }) t.Run("non-function fields", func(t *testing.T) { c := configWithNonZeroNonFunctionFields(t) require.Equal(t, c, c.Clone()) }) t.Run("returns a copy", func(t *testing.T) { c1 := &Config{MaxIncomingStreams: 100} c2 := c1.Clone() c2.MaxIncomingStreams = 200 require.EqualValues(t, 100, c1.MaxIncomingStreams) }) } func TestConfigDefaultValues(t *testing.T) { // if set, the values should be copied c := configWithNonZeroNonFunctionFields(t) require.Equal(t, c, populateConfig(c)) // if not set, some fields use default values c = populateConfig(&Config{}) require.Equal(t, protocol.SupportedVersions, c.Versions) require.Equal(t, protocol.DefaultHandshakeIdleTimeout, c.HandshakeIdleTimeout) require.Equal(t, protocol.DefaultIdleTimeout, c.MaxIdleTimeout) require.EqualValues(t, protocol.DefaultInitialMaxStreamData, c.InitialStreamReceiveWindow) require.EqualValues(t, protocol.DefaultMaxReceiveStreamFlowControlWindow, c.MaxStreamReceiveWindow) require.EqualValues(t, protocol.DefaultInitialMaxData, c.InitialConnectionReceiveWindow) require.EqualValues(t, protocol.DefaultMaxReceiveConnectionFlowControlWindow, c.MaxConnectionReceiveWindow) require.EqualValues(t, protocol.DefaultMaxIncomingStreams, c.MaxIncomingStreams) require.EqualValues(t, protocol.DefaultMaxIncomingUniStreams, c.MaxIncomingUniStreams) require.False(t, c.DisablePathMTUDiscovery) require.Nil(t, c.GetConfigForClient) } func TestConfigZeroLimits(t *testing.T) { config := &Config{ MaxIncomingStreams: -1, MaxIncomingUniStreams: -1, } c := populateConfig(config) require.Zero(t, c.MaxIncomingStreams) require.Zero(t, c.MaxIncomingUniStreams) } quic-go-0.59.0/conn_id_generator.go000066400000000000000000000143131513066070600171650ustar00rootroot00000000000000package quic import ( "fmt" "slices" "time" 
"github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/wire" ) type connRunnerCallbacks struct { AddConnectionID func(protocol.ConnectionID) RemoveConnectionID func(protocol.ConnectionID) ReplaceWithClosed func([]protocol.ConnectionID, []byte, time.Duration) } // The memory address of the Transport is used as the key. type connRunners map[connRunner]connRunnerCallbacks func (cr connRunners) AddConnectionID(id protocol.ConnectionID) { for _, c := range cr { c.AddConnectionID(id) } } func (cr connRunners) RemoveConnectionID(id protocol.ConnectionID) { for _, c := range cr { c.RemoveConnectionID(id) } } func (cr connRunners) ReplaceWithClosed(ids []protocol.ConnectionID, b []byte, expiry time.Duration) { for _, c := range cr { c.ReplaceWithClosed(ids, b, expiry) } } type connIDToRetire struct { t monotime.Time connID protocol.ConnectionID } type connIDGenerator struct { generator ConnectionIDGenerator highestSeq uint64 connRunners connRunners activeSrcConnIDs map[uint64]protocol.ConnectionID connIDsToRetire []connIDToRetire // sorted by t initialClientDestConnID *protocol.ConnectionID // nil for the client statelessResetter *statelessResetter queueControlFrame func(wire.Frame) } func newConnIDGenerator( runner connRunner, initialConnectionID protocol.ConnectionID, initialClientDestConnID *protocol.ConnectionID, // nil for the client statelessResetter *statelessResetter, callbacks connRunnerCallbacks, queueControlFrame func(wire.Frame), generator ConnectionIDGenerator, ) *connIDGenerator { m := &connIDGenerator{ generator: generator, activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), statelessResetter: statelessResetter, connRunners: map[connRunner]connRunnerCallbacks{runner: callbacks}, queueControlFrame: queueControlFrame, } m.activeSrcConnIDs[0] = initialConnectionID m.initialClientDestConnID = initialClientDestConnID return m } func (m 
*connIDGenerator) SetMaxActiveConnIDs(limit uint64) error { if m.generator.ConnectionIDLen() == 0 { return nil } // The active_connection_id_limit transport parameter is the number of // connection IDs the peer will store. This limit includes the connection ID // used during the handshake, and the one sent in the preferred_address // transport parameter. // We currently don't send the preferred_address transport parameter, // so we can issue (limit - 1) connection IDs. for i := uint64(len(m.activeSrcConnIDs)); i < min(limit, protocol.MaxIssuedConnectionIDs); i++ { if err := m.issueNewConnID(); err != nil { return err } } return nil } func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID, expiry monotime.Time) error { if seq > m.highestSeq { return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: fmt.Sprintf("retired connection ID %d (highest issued: %d)", seq, m.highestSeq), } } connID, ok := m.activeSrcConnIDs[seq] // We might already have deleted this connection ID, if this is a duplicate frame. if !ok { return nil } if connID == sentWithDestConnID { return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID), } } m.queueConnIDForRetiring(connID, expiry) delete(m.activeSrcConnIDs, seq) // Don't issue a replacement for the initial connection ID. 
if seq == 0 { return nil } return m.issueNewConnID() } func (m *connIDGenerator) queueConnIDForRetiring(connID protocol.ConnectionID, expiry monotime.Time) { idx := slices.IndexFunc(m.connIDsToRetire, func(c connIDToRetire) bool { return c.t.After(expiry) }) if idx == -1 { idx = len(m.connIDsToRetire) } m.connIDsToRetire = slices.Insert(m.connIDsToRetire, idx, connIDToRetire{t: expiry, connID: connID}) } func (m *connIDGenerator) issueNewConnID() error { connID, err := m.generator.GenerateConnectionID() if err != nil { return err } m.activeSrcConnIDs[m.highestSeq+1] = connID m.connRunners.AddConnectionID(connID) m.queueControlFrame(&wire.NewConnectionIDFrame{ SequenceNumber: m.highestSeq + 1, ConnectionID: connID, StatelessResetToken: m.statelessResetter.GetStatelessResetToken(connID), }) m.highestSeq++ return nil } func (m *connIDGenerator) SetHandshakeComplete(connIDExpiry monotime.Time) { if m.initialClientDestConnID != nil { m.queueConnIDForRetiring(*m.initialClientDestConnID, connIDExpiry) m.initialClientDestConnID = nil } } func (m *connIDGenerator) RemoveRetiredConnIDs(now monotime.Time) { if len(m.connIDsToRetire) == 0 { return } for _, c := range m.connIDsToRetire { if c.t.After(now) { break } m.connRunners.RemoveConnectionID(c.connID) m.connIDsToRetire = m.connIDsToRetire[1:] } } func (m *connIDGenerator) RemoveAll() { if m.initialClientDestConnID != nil { m.connRunners.RemoveConnectionID(*m.initialClientDestConnID) } for _, connID := range m.activeSrcConnIDs { m.connRunners.RemoveConnectionID(connID) } for _, c := range m.connIDsToRetire { m.connRunners.RemoveConnectionID(c.connID) } } func (m *connIDGenerator) ReplaceWithClosed(connClose []byte, expiry time.Duration) { connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+len(m.connIDsToRetire)+1) if m.initialClientDestConnID != nil { connIDs = append(connIDs, *m.initialClientDestConnID) } for _, connID := range m.activeSrcConnIDs { connIDs = append(connIDs, connID) } for _, c := range 
m.connIDsToRetire { connIDs = append(connIDs, c.connID) } m.connRunners.ReplaceWithClosed(connIDs, connClose, expiry) } func (m *connIDGenerator) AddConnRunner(runner connRunner, r connRunnerCallbacks) { // The transport might have already been added earlier. // This happens if the application migrates back to and old path. if _, ok := m.connRunners[runner]; ok { return } m.connRunners[runner] = r if m.initialClientDestConnID != nil { r.AddConnectionID(*m.initialClientDestConnID) } for _, connID := range m.activeSrcConnIDs { r.AddConnectionID(connID) } } quic-go-0.59.0/conn_id_generator_test.go000066400000000000000000000313251513066070600202260ustar00rootroot00000000000000package quic import ( "math/rand/v2" "testing" "time" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/require" ) func TestConnIDGeneratorIssueAndRetire(t *testing.T) { t.Run("with initial client destination connection ID", func(t *testing.T) { testConnIDGeneratorIssueAndRetire(t, true) }) t.Run("without initial client destination connection ID", func(t *testing.T) { testConnIDGeneratorIssueAndRetire(t, false) }) } func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID bool) { var ( added []protocol.ConnectionID removed []protocol.ConnectionID ) var queuedFrames []wire.Frame sr := newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}) var initialClientDestConnID *protocol.ConnectionID if hasInitialClientDestConnID { connID := protocol.ParseConnectionID([]byte{2, 2, 2, 2}) initialClientDestConnID = &connID } g := newConnIDGenerator( &packetHandlerMap{}, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), initialClientDestConnID, sr, connRunnerCallbacks{ AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) }, RemoveConnectionID: func(c protocol.ConnectionID) { removed = append(removed, c) }, 
ReplaceWithClosed: func([]protocol.ConnectionID, []byte, time.Duration) {}, }, func(f wire.Frame) { queuedFrames = append(queuedFrames, f) }, &protocol.DefaultConnectionIDGenerator{ConnLen: 5}, ) require.Empty(t, added) require.NoError(t, g.SetMaxActiveConnIDs(4)) require.Len(t, added, 3) require.Len(t, queuedFrames, 3) require.Empty(t, removed) connIDs := make(map[uint64]protocol.ConnectionID) // connection IDs 1, 2 and 3 were issued for i, f := range queuedFrames { ncid := f.(*wire.NewConnectionIDFrame) require.EqualValues(t, i+1, ncid.SequenceNumber) require.Equal(t, ncid.ConnectionID, added[i]) require.Equal(t, ncid.StatelessResetToken, sr.GetStatelessResetToken(ncid.ConnectionID)) connIDs[ncid.SequenceNumber] = ncid.ConnectionID } // completing the handshake retires the initial client destination connection ID added = added[:0] queuedFrames = queuedFrames[:0] now := monotime.Now() g.SetHandshakeComplete(now) require.Empty(t, added) require.Empty(t, queuedFrames) require.Empty(t, removed) g.RemoveRetiredConnIDs(now) if hasInitialClientDestConnID { require.Equal(t, []protocol.ConnectionID{*initialClientDestConnID}, removed) removed = removed[:0] } else { require.Empty(t, removed) } // it's invalid to retire a connection ID that hasn't been issued yet err := g.Retire(4, protocol.ParseConnectionID([]byte{3, 3, 3, 3}), monotime.Now()) require.ErrorIs(t, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation}, err) require.ErrorContains(t, err, "retired connection ID 4 (highest issued: 3)") // it's invalid to retire a connection ID in a packet that uses that connection ID err = g.Retire(3, connIDs[3], monotime.Now()) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation}) require.ErrorContains(t, err, "was used as the Destination Connection ID on this packet") // retiring a connection ID makes us issue a new one require.NoError(t, g.Retire(2, protocol.ParseConnectionID([]byte{3, 3, 3, 3}), monotime.Now())) 
g.RemoveRetiredConnIDs(monotime.Now()) require.Equal(t, []protocol.ConnectionID{connIDs[2]}, removed) require.Len(t, queuedFrames, 1) require.EqualValues(t, 4, queuedFrames[0].(*wire.NewConnectionIDFrame).SequenceNumber) queuedFrames = queuedFrames[:0] removed = removed[:0] // duplicate retirements don't do anything require.NoError(t, g.Retire(2, protocol.ParseConnectionID([]byte{3, 3, 3, 3}), monotime.Now())) g.RemoveRetiredConnIDs(monotime.Now()) require.Empty(t, queuedFrames) require.Empty(t, removed) } func TestConnIDGeneratorRetiring(t *testing.T) { initialConnID := protocol.ParseConnectionID([]byte{2, 2, 2, 2}) var added, removed []protocol.ConnectionID g := newConnIDGenerator( &packetHandlerMap{}, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), &initialConnID, newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}), connRunnerCallbacks{ AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) }, RemoveConnectionID: func(c protocol.ConnectionID) { removed = append(removed, c) }, ReplaceWithClosed: func([]protocol.ConnectionID, []byte, time.Duration) {}, }, func(f wire.Frame) {}, &protocol.DefaultConnectionIDGenerator{ConnLen: 5}, ) require.NoError(t, g.SetMaxActiveConnIDs(6)) require.Empty(t, removed) require.Len(t, added, 5) now := monotime.Now() retirements := map[protocol.ConnectionID]monotime.Time{} t1 := now.Add(time.Duration(rand.IntN(1000)) * time.Millisecond) retirements[initialConnID] = t1 g.SetHandshakeComplete(t1) for i := range 5 { t2 := now.Add(time.Duration(rand.IntN(1000)) * time.Millisecond) require.NoError(t, g.Retire(uint64(i+1), protocol.ParseConnectionID([]byte{9, 9, 9, 9}), t2)) retirements[added[i]] = t2 if rand.IntN(2) == 0 { now = now.Add(time.Duration(rand.IntN(500)) * time.Millisecond) g.RemoveRetiredConnIDs(now) for _, r := range removed { require.Contains(t, retirements, r) require.LessOrEqual(t, retirements[r], now) delete(retirements, r) } removed = removed[:0] for _, r := range retirements { require.Greater(t, r, 
now) } } } } func TestConnIDGeneratorRemoveAll(t *testing.T) { t.Run("with initial client destination connection ID", func(t *testing.T) { testConnIDGeneratorRemoveAll(t, true) }) t.Run("without initial client destination connection ID", func(t *testing.T) { testConnIDGeneratorRemoveAll(t, false) }) } func testConnIDGeneratorRemoveAll(t *testing.T, hasInitialClientDestConnID bool) { var initialClientDestConnID *protocol.ConnectionID if hasInitialClientDestConnID { connID := protocol.ParseConnectionID([]byte{2, 2, 2, 2}) initialClientDestConnID = &connID } var ( added []protocol.ConnectionID removed []protocol.ConnectionID ) g := newConnIDGenerator( &packetHandlerMap{}, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), initialClientDestConnID, newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}), connRunnerCallbacks{ AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) }, RemoveConnectionID: func(c protocol.ConnectionID) { removed = append(removed, c) }, ReplaceWithClosed: func([]protocol.ConnectionID, []byte, time.Duration) {}, }, func(f wire.Frame) {}, &protocol.DefaultConnectionIDGenerator{ConnLen: 5}, ) require.NoError(t, g.SetMaxActiveConnIDs(1000)) require.Len(t, added, protocol.MaxIssuedConnectionIDs-1) g.RemoveAll() if hasInitialClientDestConnID { require.Len(t, removed, protocol.MaxIssuedConnectionIDs+1) require.Contains(t, removed, *initialClientDestConnID) } else { require.Len(t, removed, protocol.MaxIssuedConnectionIDs) } for _, id := range added { require.Contains(t, removed, id) } require.Contains(t, removed, protocol.ParseConnectionID([]byte{1, 1, 1, 1})) } func TestConnIDGeneratorReplaceWithClosed(t *testing.T) { t.Run("with initial client destination connection ID", func(t *testing.T) { testConnIDGeneratorReplaceWithClosed(t, true) }) t.Run("without initial client destination connection ID", func(t *testing.T) { testConnIDGeneratorReplaceWithClosed(t, false) }) } func testConnIDGeneratorReplaceWithClosed(t *testing.T, 
hasInitialClientDestConnID bool) { var initialClientDestConnID *protocol.ConnectionID if hasInitialClientDestConnID { connID := protocol.ParseConnectionID([]byte{2, 2, 2, 2}) initialClientDestConnID = &connID } var ( added []protocol.ConnectionID replaced []protocol.ConnectionID replacedWith []byte ) g := newConnIDGenerator( &packetHandlerMap{}, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), initialClientDestConnID, newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}), connRunnerCallbacks{ AddConnectionID: func(c protocol.ConnectionID) { added = append(added, c) }, RemoveConnectionID: func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID removals") }, ReplaceWithClosed: func(connIDs []protocol.ConnectionID, b []byte, _ time.Duration) { replaced = connIDs replacedWith = b }, }, func(f wire.Frame) {}, &protocol.DefaultConnectionIDGenerator{ConnLen: 5}, ) require.NoError(t, g.SetMaxActiveConnIDs(1000)) require.Len(t, added, protocol.MaxIssuedConnectionIDs-1) // Retire two of these connection ID. // This makes us issue two more connection IDs. 
require.NoError(t, g.Retire(3, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), monotime.Now())) require.NoError(t, g.Retire(4, protocol.ParseConnectionID([]byte{1, 1, 1, 1}), monotime.Now())) require.Len(t, added, protocol.MaxIssuedConnectionIDs+1) g.ReplaceWithClosed([]byte("foobar"), time.Second) if hasInitialClientDestConnID { require.Len(t, replaced, protocol.MaxIssuedConnectionIDs+3) require.Contains(t, replaced, *initialClientDestConnID) } else { require.Len(t, replaced, protocol.MaxIssuedConnectionIDs+2) } for _, id := range added { require.Contains(t, replaced, id) } require.Contains(t, replaced, protocol.ParseConnectionID([]byte{1, 1, 1, 1})) require.Equal(t, []byte("foobar"), replacedWith) } func TestConnIDGeneratorAddConnRunner(t *testing.T) { initialConnID := protocol.ParseConnectionID([]byte{1, 1, 1, 1}) clientDestConnID := protocol.ParseConnectionID([]byte{2, 2, 2, 2}) type connIDTracker struct { added, removed, replaced []protocol.ConnectionID } var tracker1, tracker2, tracker3 connIDTracker runner1 := connRunnerCallbacks{ AddConnectionID: func(c protocol.ConnectionID) { tracker1.added = append(tracker1.added, c) }, RemoveConnectionID: func(c protocol.ConnectionID) { tracker1.removed = append(tracker1.removed, c) }, ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte, _ time.Duration) { tracker1.replaced = append(tracker1.replaced, connIDs...) }, } runner2 := connRunnerCallbacks{ AddConnectionID: func(c protocol.ConnectionID) { tracker2.added = append(tracker2.added, c) }, RemoveConnectionID: func(c protocol.ConnectionID) { tracker2.removed = append(tracker2.removed, c) }, ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte, _ time.Duration) { tracker2.replaced = append(tracker2.replaced, connIDs...) 
}, } runner3 := connRunnerCallbacks{ AddConnectionID: func(c protocol.ConnectionID) { tracker3.added = append(tracker3.added, c) }, RemoveConnectionID: func(c protocol.ConnectionID) { tracker3.removed = append(tracker3.removed, c) }, ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte, _ time.Duration) { tracker3.replaced = append(tracker3.replaced, connIDs...) }, } sr := newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}) var queuedFrames []wire.Frame tr := &packetHandlerMap{} g := newConnIDGenerator( tr, initialConnID, &clientDestConnID, sr, runner1, func(f wire.Frame) { queuedFrames = append(queuedFrames, f) }, &protocol.DefaultConnectionIDGenerator{ConnLen: 5}, ) require.NoError(t, g.SetMaxActiveConnIDs(3)) require.Len(t, tracker1.added, 2) // add the second runner - it should get all existing connection IDs g.AddConnRunner(&packetHandlerMap{}, runner2) require.Len(t, tracker1.added, 2) // unchanged require.Len(t, tracker2.added, 4) require.Contains(t, tracker2.added, initialConnID) require.Contains(t, tracker2.added, clientDestConnID) require.Contains(t, tracker2.added, tracker1.added[0]) require.Contains(t, tracker2.added, tracker1.added[1]) // adding the same transport again doesn't do anything trCopy := tr g.AddConnRunner(trCopy, runner3) require.Empty(t, tracker3.added) var connIDToRetire protocol.ConnectionID var seqToRetire uint64 ncid := queuedFrames[0].(*wire.NewConnectionIDFrame) connIDToRetire = ncid.ConnectionID seqToRetire = ncid.SequenceNumber require.NoError(t, g.Retire(seqToRetire, protocol.ParseConnectionID([]byte{3, 3, 3, 3}), monotime.Now())) g.RemoveRetiredConnIDs(monotime.Now()) require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker1.removed) require.Equal(t, []protocol.ConnectionID{connIDToRetire}, tracker2.removed) tracker1.removed = nil tracker2.removed = nil g.SetHandshakeComplete(monotime.Now()) g.RemoveRetiredConnIDs(monotime.Now()) require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker1.removed) 
require.Equal(t, []protocol.ConnectionID{clientDestConnID}, tracker2.removed) g.ReplaceWithClosed([]byte("connection closed"), time.Second) require.True(t, len(tracker1.replaced) > 0) require.Equal(t, tracker1.replaced, tracker2.replaced) tracker1.removed = nil tracker2.removed = nil g.RemoveAll() require.NotEmpty(t, tracker1.removed) require.Equal(t, tracker1.removed, tracker2.removed) } quic-go-0.59.0/conn_id_manager.go000066400000000000000000000234411513066070600166130ustar00rootroot00000000000000package quic import ( "fmt" "slices" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" ) type newConnID struct { SequenceNumber uint64 ConnectionID protocol.ConnectionID StatelessResetToken protocol.StatelessResetToken } type connIDManager struct { queue []newConnID highestProbingID uint64 pathProbing map[pathID]newConnID // initialized lazily handshakeComplete bool activeSequenceNumber uint64 highestRetired uint64 activeConnectionID protocol.ConnectionID activeStatelessResetToken *protocol.StatelessResetToken // We change the connection ID after sending on average // protocol.PacketsPerConnectionID packets. The actual value is randomized // hide the packet loss rate from on-path observers. 
rand utils.Rand packetsSinceLastChange uint32 packetsPerConnectionID uint32 addStatelessResetToken func(protocol.StatelessResetToken) removeStatelessResetToken func(protocol.StatelessResetToken) queueControlFrame func(wire.Frame) closed bool } func newConnIDManager( initialDestConnID protocol.ConnectionID, addStatelessResetToken func(protocol.StatelessResetToken), removeStatelessResetToken func(protocol.StatelessResetToken), queueControlFrame func(wire.Frame), ) *connIDManager { return &connIDManager{ activeConnectionID: initialDestConnID, addStatelessResetToken: addStatelessResetToken, removeStatelessResetToken: removeStatelessResetToken, queueControlFrame: queueControlFrame, queue: make([]newConnID, 0, protocol.MaxActiveConnectionIDs), } } func (h *connIDManager) AddFromPreferredAddress(connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error { return h.addConnectionID(1, connID, resetToken) } func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error { if err := h.add(f); err != nil { return err } if len(h.queue) >= protocol.MaxActiveConnectionIDs { return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError} } return nil } func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error { if h.activeConnectionID.Len() == 0 { return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: "received NEW_CONNECTION_ID frame but zero-length connection IDs are in use", } } // If the NEW_CONNECTION_ID frame is reordered, such that its sequence number is smaller than the currently active // connection ID or if it was already retired, send the RETIRE_CONNECTION_ID frame immediately. 
if f.SequenceNumber < max(h.activeSequenceNumber, h.highestProbingID) || f.SequenceNumber < h.highestRetired { h.queueControlFrame(&wire.RetireConnectionIDFrame{ SequenceNumber: f.SequenceNumber, }) return nil } if f.RetirePriorTo != 0 && h.pathProbing != nil { for id, entry := range h.pathProbing { if entry.SequenceNumber < f.RetirePriorTo { h.queueControlFrame(&wire.RetireConnectionIDFrame{ SequenceNumber: entry.SequenceNumber, }) h.removeStatelessResetToken(entry.StatelessResetToken) delete(h.pathProbing, id) } } } // Retire elements in the queue. // Doesn't retire the active connection ID. if f.RetirePriorTo > h.highestRetired { var newQueue []newConnID for _, entry := range h.queue { if entry.SequenceNumber >= f.RetirePriorTo { newQueue = append(newQueue, entry) } else { h.queueControlFrame(&wire.RetireConnectionIDFrame{SequenceNumber: entry.SequenceNumber}) } } h.queue = newQueue h.highestRetired = f.RetirePriorTo } if f.SequenceNumber == h.activeSequenceNumber { return nil } if err := h.addConnectionID(f.SequenceNumber, f.ConnectionID, f.StatelessResetToken); err != nil { return err } // Retire the active connection ID, if necessary. if h.activeSequenceNumber < f.RetirePriorTo { // The queue is guaranteed to have at least one element at this point. 
h.updateConnectionID() } return nil } func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error { // fast path: add to the end of the queue if len(h.queue) == 0 || h.queue[len(h.queue)-1].SequenceNumber < seq { h.queue = append(h.queue, newConnID{ SequenceNumber: seq, ConnectionID: connID, StatelessResetToken: resetToken, }) return nil } // slow path: insert in the middle for i, entry := range h.queue { if entry.SequenceNumber == seq { if entry.ConnectionID != connID { return fmt.Errorf("received conflicting connection IDs for sequence number %d", seq) } if entry.StatelessResetToken != resetToken { return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", seq) } return nil } // insert at the correct position to maintain sorted order if entry.SequenceNumber > seq { h.queue = slices.Insert(h.queue, i, newConnID{ SequenceNumber: seq, ConnectionID: connID, StatelessResetToken: resetToken, }) return nil } } return nil // unreachable } func (h *connIDManager) updateConnectionID() { h.assertNotClosed() h.queueControlFrame(&wire.RetireConnectionIDFrame{ SequenceNumber: h.activeSequenceNumber, }) h.highestRetired = max(h.highestRetired, h.activeSequenceNumber) if h.activeStatelessResetToken != nil { h.removeStatelessResetToken(*h.activeStatelessResetToken) } front := h.queue[0] h.queue = h.queue[1:] h.activeSequenceNumber = front.SequenceNumber h.activeConnectionID = front.ConnectionID h.activeStatelessResetToken = &front.StatelessResetToken h.packetsSinceLastChange = 0 h.packetsPerConnectionID = protocol.PacketsPerConnectionID/2 + uint32(h.rand.Int31n(protocol.PacketsPerConnectionID)) h.addStatelessResetToken(*h.activeStatelessResetToken) } func (h *connIDManager) Close() { h.closed = true if h.activeStatelessResetToken != nil { h.removeStatelessResetToken(*h.activeStatelessResetToken) } if h.pathProbing != nil { for _, entry := range h.pathProbing { 
h.removeStatelessResetToken(entry.StatelessResetToken) } } } // is called when the server performs a Retry // and when the server changes the connection ID in the first Initial sent func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) { if h.activeSequenceNumber != 0 { panic("expected first connection ID to have sequence number 0") } h.activeConnectionID = newConnID } // is called when the server provides a stateless reset token in the transport parameters func (h *connIDManager) SetStatelessResetToken(token protocol.StatelessResetToken) { h.assertNotClosed() if h.activeSequenceNumber != 0 { panic("expected first connection ID to have sequence number 0") } h.activeStatelessResetToken = &token h.addStatelessResetToken(token) } func (h *connIDManager) SentPacket() { h.packetsSinceLastChange++ } func (h *connIDManager) shouldUpdateConnID() bool { if !h.handshakeComplete { return false } // initiate the first change as early as possible (after handshake completion) if len(h.queue) > 0 && h.activeSequenceNumber == 0 { return true } // For later changes, only change if // 1. The queue of connection IDs is filled more than 50%. // 2. We sent at least PacketsPerConnectionID packets return 2*len(h.queue) >= protocol.MaxActiveConnectionIDs && h.packetsSinceLastChange >= h.packetsPerConnectionID } func (h *connIDManager) Get() protocol.ConnectionID { h.assertNotClosed() if h.shouldUpdateConnID() { h.updateConnectionID() } return h.activeConnectionID } func (h *connIDManager) SetHandshakeComplete() { h.handshakeComplete = true } // GetConnIDForPath retrieves a connection ID for a new path (i.e. not the active one). // Once a connection ID is allocated for a path, it cannot be used for a different path. // When called with the same pathID, it will return the same connection ID, // unless the peer requested that this connection ID be retired. 
func (h *connIDManager) GetConnIDForPath(id pathID) (protocol.ConnectionID, bool) { h.assertNotClosed() // if we're using zero-length connection IDs, we don't need to change the connection ID if h.activeConnectionID.Len() == 0 { return protocol.ConnectionID{}, true } if h.pathProbing == nil { h.pathProbing = make(map[pathID]newConnID) } entry, ok := h.pathProbing[id] if ok { return entry.ConnectionID, true } if len(h.queue) == 0 { return protocol.ConnectionID{}, false } front := h.queue[0] h.queue = h.queue[1:] h.pathProbing[id] = front h.highestProbingID = front.SequenceNumber h.addStatelessResetToken(front.StatelessResetToken) return front.ConnectionID, true } func (h *connIDManager) RetireConnIDForPath(pathID pathID) { h.assertNotClosed() // if we're using zero-length connection IDs, we don't need to change the connection ID if h.activeConnectionID.Len() == 0 { return } entry, ok := h.pathProbing[pathID] if !ok { return } h.queueControlFrame(&wire.RetireConnectionIDFrame{ SequenceNumber: entry.SequenceNumber, }) h.removeStatelessResetToken(entry.StatelessResetToken) delete(h.pathProbing, pathID) } func (h *connIDManager) IsActiveStatelessResetToken(token protocol.StatelessResetToken) bool { if h.activeStatelessResetToken != nil { if *h.activeStatelessResetToken == token { return true } } if h.pathProbing != nil { for _, entry := range h.pathProbing { if entry.StatelessResetToken == token { return true } } } return false } // Using the connIDManager after it has been closed can have disastrous effects: // If the connection ID is rotated, a new entry would be inserted into the packet handler map, // leading to a memory leak of the connection struct. // See https://github.com/quic-go/quic-go/pull/4852 for more details. 
func (h *connIDManager) assertNotClosed() { if h.closed { panic("connection ID manager is closed") } } quic-go-0.59.0/conn_id_manager_test.go000066400000000000000000000425151513066070600176550ustar00rootroot00000000000000package quic import ( "crypto/rand" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/require" ) func TestConnIDManagerInitialConnID(t *testing.T) { m := newConnIDManager(protocol.ParseConnectionID([]byte{1, 2, 3, 4}), nil, nil, nil) require.Equal(t, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), m.Get()) require.Equal(t, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), m.Get()) m.ChangeInitialConnID(protocol.ParseConnectionID([]byte{5, 6, 7, 8})) require.Equal(t, protocol.ParseConnectionID([]byte{5, 6, 7, 8}), m.Get()) } func TestConnIDManagerAddConnIDs(t *testing.T) { m := newConnIDManager( protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func(protocol.StatelessResetToken) {}, func(protocol.StatelessResetToken) {}, func(wire.Frame) {}, ) f1 := &wire.NewConnectionIDFrame{ SequenceNumber: 1, ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, } f2 := &wire.NewConnectionIDFrame{ SequenceNumber: 2, ConnectionID: protocol.ParseConnectionID([]byte{0xba, 0xad, 0xf0, 0x0d}), StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, } require.NoError(t, m.Add(f2)) require.NoError(t, m.Add(f1)) // receiving reordered frames is fine require.NoError(t, m.Add(f2)) // receiving a duplicate is fine require.Equal(t, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), m.Get()) m.updateConnectionID() require.Equal(t, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), m.Get()) m.updateConnectionID() require.Equal(t, protocol.ParseConnectionID([]byte{0xba, 0xad, 0xf0, 
0x0d}), m.Get()) require.NoError(t, m.Add(f2)) // receiving a duplicate for the current connection ID is fine as well require.Equal(t, protocol.ParseConnectionID([]byte{0xba, 0xad, 0xf0, 0x0d}), m.Get()) // receiving mismatching connection IDs is not fine require.NoError(t, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 3, ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), // mismatching connection ID StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, })) require.EqualError(t, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 3, ConnectionID: protocol.ParseConnectionID([]byte{2, 3, 4, 5}), // mismatching connection ID StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, }), "received conflicting connection IDs for sequence number 3") // receiving mismatching stateless reset tokens is not fine either require.EqualError(t, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 3, ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe, 0}, }), "received conflicting stateless reset tokens for sequence number 3") } func TestConnIDManagerLimit(t *testing.T) { m := newConnIDManager( protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func(protocol.StatelessResetToken) {}, func(protocol.StatelessResetToken) {}, func(f wire.Frame) {}, ) for i := uint8(1); i < protocol.MaxActiveConnectionIDs; i++ { require.NoError(t, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: uint64(i), ConnectionID: protocol.ParseConnectionID([]byte{i, i, i, i}), StatelessResetToken: protocol.StatelessResetToken{i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i}, })) } require.Equal(t, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: uint64(9999), ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), 
StatelessResetToken: protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, })) } func TestConnIDManagerRetiringConnectionIDs(t *testing.T) { var frameQueue []wire.Frame m := newConnIDManager( protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func(protocol.StatelessResetToken) {}, func(protocol.StatelessResetToken) {}, func(f wire.Frame) { frameQueue = append(frameQueue, f) }, ) require.NoError(t, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 10, ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), })) require.NoError(t, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 13, ConnectionID: protocol.ParseConnectionID([]byte{2, 3, 4, 5}), })) require.Empty(t, frameQueue) require.NoError(t, m.Add(&wire.NewConnectionIDFrame{ RetirePriorTo: 14, SequenceNumber: 17, ConnectionID: protocol.ParseConnectionID([]byte{3, 4, 5, 6}), })) require.Equal(t, []wire.Frame{ &wire.RetireConnectionIDFrame{SequenceNumber: 10}, &wire.RetireConnectionIDFrame{SequenceNumber: 13}, &wire.RetireConnectionIDFrame{SequenceNumber: 0}, }, frameQueue) require.Equal(t, protocol.ParseConnectionID([]byte{3, 4, 5, 6}), m.Get()) frameQueue = nil // a reordered connection ID is immediately retired require.NoError(t, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 12, ConnectionID: protocol.ParseConnectionID([]byte{5, 6, 7, 8}), })) require.Equal(t, []wire.Frame{&wire.RetireConnectionIDFrame{SequenceNumber: 12}}, frameQueue) require.Equal(t, protocol.ParseConnectionID([]byte{3, 4, 5, 6}), m.Get()) } func TestConnIDManagerHandshakeCompletion(t *testing.T) { var frameQueue []wire.Frame var addedTokens, removedTokens []protocol.StatelessResetToken m := newConnIDManager( protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func(token protocol.StatelessResetToken) { addedTokens = append(addedTokens, token) }, func(token protocol.StatelessResetToken) { removedTokens = append(removedTokens, token) }, func(f wire.Frame) { frameQueue = append(frameQueue, f) }, ) 
m.SetStatelessResetToken(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}) require.Equal(t, []protocol.StatelessResetToken{{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}}, addedTokens) require.Empty(t, removedTokens) require.NoError(t, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 1, ConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}), StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, })) require.Equal(t, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), m.Get()) m.SetHandshakeComplete() require.Equal(t, protocol.ParseConnectionID([]byte{4, 3, 2, 1}), m.Get()) require.Equal(t, []wire.Frame{&wire.RetireConnectionIDFrame{SequenceNumber: 0}}, frameQueue) require.Equal(t, []protocol.StatelessResetToken{{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}}, removedTokens) } func TestConnIDManagerConnIDRotation(t *testing.T) { toToken := func(connID protocol.ConnectionID) protocol.StatelessResetToken { var token protocol.StatelessResetToken copy(token[:], connID.Bytes()) copy(token[connID.Len():], connID.Bytes()) return token } var frameQueue []wire.Frame var addedTokens, removedTokens []protocol.StatelessResetToken m := newConnIDManager( protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func(token protocol.StatelessResetToken) { addedTokens = append(addedTokens, token) }, func(token protocol.StatelessResetToken) { removedTokens = append(removedTokens, token) }, func(f wire.Frame) { frameQueue = append(frameQueue, f) }, ) // the first connection ID is used as soon as the handshake is complete m.SetHandshakeComplete() firstConnID := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) require.NoError(t, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 1, ConnectionID: firstConnID, StatelessResetToken: toToken(protocol.ParseConnectionID([]byte{4, 3, 2, 1})), })) require.Equal(t, firstConnID, m.Get()) frameQueue = nil require.True(t, 
m.IsActiveStatelessResetToken(toToken(firstConnID))) require.Equal(t, addedTokens, []protocol.StatelessResetToken{toToken(firstConnID)}) addedTokens = addedTokens[:0] // Note that we're missing the connection ID with sequence number 2. // It will be received later. var queuedConnIDs []protocol.ConnectionID for i := 0; i < protocol.MaxActiveConnectionIDs-1; i++ { b := make([]byte, 4) rand.Read(b) connID := protocol.ParseConnectionID(b) queuedConnIDs = append(queuedConnIDs, connID) require.NoError(t, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: uint64(3 + i), ConnectionID: connID, StatelessResetToken: toToken(connID), })) require.False(t, m.IsActiveStatelessResetToken(toToken(connID))) } var counter int for { require.Empty(t, frameQueue) m.SentPacket() counter++ if connID := m.Get(); connID != firstConnID { require.Equal(t, queuedConnIDs[0], m.Get()) require.Equal(t, []wire.Frame{&wire.RetireConnectionIDFrame{SequenceNumber: 1}}, frameQueue) require.Equal(t, removedTokens, []protocol.StatelessResetToken{toToken(firstConnID)}) require.Equal(t, addedTokens, []protocol.StatelessResetToken{toToken(connID)}) addedTokens = addedTokens[:0] removedTokens = removedTokens[:0] require.True(t, m.IsActiveStatelessResetToken(toToken(connID))) require.False(t, m.IsActiveStatelessResetToken(toToken(firstConnID))) break } require.True(t, m.IsActiveStatelessResetToken(toToken(firstConnID))) require.Empty(t, addedTokens) } require.GreaterOrEqual(t, counter, protocol.PacketsPerConnectionID/2) require.LessOrEqual(t, counter, protocol.PacketsPerConnectionID*3/2) frameQueue = nil // now receive connection ID 2 require.NoError(t, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 2, ConnectionID: protocol.ParseConnectionID([]byte{2, 3, 4, 5}), })) require.Equal(t, []wire.Frame{&wire.RetireConnectionIDFrame{SequenceNumber: 2}}, frameQueue) } func TestConnIDManagerPathMigration(t *testing.T) { var frameQueue []wire.Frame var addedTokens, removedTokens []protocol.StatelessResetToken m := 
newConnIDManager( protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func(token protocol.StatelessResetToken) { addedTokens = append(addedTokens, token) }, func(token protocol.StatelessResetToken) { removedTokens = append(removedTokens, token) }, func(f wire.Frame) { frameQueue = append(frameQueue, f) }, ) // no connection ID available yet _, ok := m.GetConnIDForPath(1) require.False(t, ok) // add two connection IDs require.NoError(t, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 1, ConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}), StatelessResetToken: protocol.StatelessResetToken{4, 3, 2, 1, 4, 3, 2, 1}, })) require.NoError(t, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 2, ConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2}), StatelessResetToken: protocol.StatelessResetToken{5, 4, 3, 2, 5, 4, 3, 2}, })) connID, ok := m.GetConnIDForPath(1) require.True(t, ok) require.Equal(t, protocol.ParseConnectionID([]byte{4, 3, 2, 1}), connID) require.Equal(t, []protocol.StatelessResetToken{{4, 3, 2, 1, 4, 3, 2, 1}}, addedTokens) require.Empty(t, removedTokens) addedTokens = addedTokens[:0] require.False(t, m.IsActiveStatelessResetToken(protocol.StatelessResetToken{5, 4, 3, 2, 5, 4, 3, 2})) connID, ok = m.GetConnIDForPath(2) require.True(t, ok) require.Equal(t, protocol.ParseConnectionID([]byte{5, 4, 3, 2}), connID) require.Equal(t, []protocol.StatelessResetToken{{5, 4, 3, 2, 5, 4, 3, 2}}, addedTokens) require.Empty(t, removedTokens) require.True(t, m.IsActiveStatelessResetToken(protocol.StatelessResetToken{5, 4, 3, 2, 5, 4, 3, 2})) addedTokens = addedTokens[:0] // asking for the connection for path 1 again returns the same connection ID connID, ok = m.GetConnIDForPath(1) require.True(t, ok) require.Equal(t, protocol.ParseConnectionID([]byte{4, 3, 2, 1}), connID) require.Empty(t, addedTokens) // if the connection ID is retired, the path will use another connection ID require.NoError(t, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 3, 
RetirePriorTo: 2, ConnectionID: protocol.ParseConnectionID([]byte{6, 5, 4, 3}), StatelessResetToken: protocol.StatelessResetToken{6, 5, 4, 3, 6, 5, 4, 3}, })) require.Len(t, frameQueue, 2) require.Equal(t, []protocol.StatelessResetToken{{4, 3, 2, 1, 4, 3, 2, 1}}, removedTokens) frameQueue = nil removedTokens = removedTokens[:0] require.Equal(t, protocol.ParseConnectionID([]byte{6, 5, 4, 3}), m.Get()) require.Equal(t, []protocol.StatelessResetToken{{6, 5, 4, 3, 6, 5, 4, 3}}, addedTokens) require.Empty(t, removedTokens) addedTokens = addedTokens[:0] // the connection ID is not used for new paths _, ok = m.GetConnIDForPath(3) require.False(t, ok) // Manually retiring the connection ID does nothing. // Path 1 doesn't have a connection ID anymore. m.RetireConnIDForPath(1) require.Empty(t, frameQueue) _, ok = m.GetConnIDForPath(1) require.False(t, ok) require.Empty(t, removedTokens) // only after a new connection ID is added, it will be used for path 1 require.NoError(t, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 4, ConnectionID: protocol.ParseConnectionID([]byte{7, 6, 5, 4}), StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13}, })) connID, ok = m.GetConnIDForPath(1) require.True(t, ok) require.Equal(t, protocol.ParseConnectionID([]byte{7, 6, 5, 4}), connID) require.Equal(t, []protocol.StatelessResetToken{{16, 15, 14, 13}}, addedTokens) require.Empty(t, removedTokens) require.True(t, m.IsActiveStatelessResetToken(protocol.StatelessResetToken{16, 15, 14, 13})) // a RETIRE_CONNECTION_ID frame for path 1 is queued when retiring the connection ID m.RetireConnIDForPath(1) require.Equal(t, []wire.Frame{&wire.RetireConnectionIDFrame{SequenceNumber: 4}}, frameQueue) require.Equal(t, []protocol.StatelessResetToken{{16, 15, 14, 13}}, removedTokens) removedTokens = removedTokens[:0] require.False(t, m.IsActiveStatelessResetToken(protocol.StatelessResetToken{16, 15, 14, 13})) m.Close() require.Equal(t, []protocol.StatelessResetToken{ {6, 5, 4, 3, 6, 5, 4, 3}, // 
currently active connection ID {5, 4, 3, 2, 5, 4, 3, 2}, // path 2 }, removedTokens) } func TestConnIDManagerZeroLengthConnectionID(t *testing.T) { m := newConnIDManager( protocol.ConnectionID{}, func(protocol.StatelessResetToken) {}, func(protocol.StatelessResetToken) {}, func(f wire.Frame) {}, ) require.Equal(t, protocol.ConnectionID{}, m.Get()) for range 5 * protocol.PacketsPerConnectionID { m.SentPacket() require.Equal(t, protocol.ConnectionID{}, m.Get()) } // for path probing, we don't need to change the connection ID for id := pathID(1); id < 10; id++ { connID, ok := m.GetConnIDForPath(id) require.True(t, ok) require.Equal(t, protocol.ConnectionID{}, connID) } // retiring a connection ID for a path is also a no-op for id := pathID(1); id < 20; id++ { m.RetireConnIDForPath(id) } require.ErrorIs(t, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 1, ConnectionID: protocol.ConnectionID{}, StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, }), &qerr.TransportError{ErrorCode: qerr.ProtocolViolation}) } func TestConnIDManagerClose(t *testing.T) { var addedTokens, removedTokens []protocol.StatelessResetToken m := newConnIDManager( protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func(token protocol.StatelessResetToken) { addedTokens = append(addedTokens, token) }, func(token protocol.StatelessResetToken) { removedTokens = append(removedTokens, token) }, func(f wire.Frame) {}, ) m.SetStatelessResetToken(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}) require.Equal(t, []protocol.StatelessResetToken{{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}}, addedTokens) require.Empty(t, removedTokens) m.Close() require.Equal(t, []protocol.StatelessResetToken{{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}}, removedTokens) require.Panics(t, func() { m.Get() }) require.Panics(t, func() { m.SetStatelessResetToken(protocol.StatelessResetToken{}) }) } func 
BenchmarkConnIDManagerReordered(b *testing.B) { benchmarkConnIDManager(b, true) } func BenchmarkConnIDManagerInOrder(b *testing.B) { benchmarkConnIDManager(b, false) } func benchmarkConnIDManager(b *testing.B, reordered bool) { m := newConnIDManager( protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func(protocol.StatelessResetToken) {}, func(protocol.StatelessResetToken) {}, func(f wire.Frame) {}, ) connIDs := make([]protocol.ConnectionID, 0, protocol.MaxActiveConnectionIDs) statelessResetTokens := make([]protocol.StatelessResetToken, 0, protocol.MaxActiveConnectionIDs) for range protocol.MaxActiveConnectionIDs { b := make([]byte, 8) rand.Read(b) connIDs = append(connIDs, protocol.ParseConnectionID(b)) var statelessResetToken protocol.StatelessResetToken rand.Read(statelessResetToken[:]) statelessResetTokens = append(statelessResetTokens, statelessResetToken) } // 1 -> 3 // 2 -> 1 // 3 -> 2 // 4 -> 4 offsets := []int{2, -1, -1, 0} b.ResetTimer() for i := range b.N { seq := i if reordered { seq += offsets[i%len(offsets)] } m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: uint64(seq), ConnectionID: connIDs[i%len(connIDs)], StatelessResetToken: statelessResetTokens[i%len(statelessResetTokens)], }) if i > protocol.MaxActiveConnectionIDs-2 { m.updateConnectionID() } } } quic-go-0.59.0/conn_wrapped_test.go000066400000000000000000000026331513066070600172260ustar00rootroot00000000000000package quic import "context" func (c *wrappedConn) run() error { if c.testHooks == nil { return c.Conn.run() } if c.testHooks.run != nil { return c.testHooks.run() } return nil } func (c *wrappedConn) earlyConnReady() <-chan struct{} { if c.testHooks == nil { return c.Conn.earlyConnReady() } if c.testHooks.earlyConnReady != nil { return c.testHooks.earlyConnReady() } return nil } func (c *wrappedConn) Context() context.Context { if c.testHooks == nil { return c.Conn.Context() } if c.testHooks.context != nil { return c.testHooks.context() } return context.Background() } func (c *wrappedConn) 
HandshakeComplete() <-chan struct{} { if c.testHooks == nil { return c.Conn.HandshakeComplete() } if c.testHooks.handshakeComplete != nil { return c.testHooks.handshakeComplete() } return nil } func (c *wrappedConn) closeWithTransportError(code TransportErrorCode) { if c.testHooks == nil { c.Conn.closeWithTransportError(code) return } if c.testHooks.closeWithTransportError != nil { c.testHooks.closeWithTransportError(code) } } func (c *wrappedConn) destroy(e error) { if c.testHooks == nil { c.Conn.destroy(e) return } if c.testHooks.destroy != nil { c.testHooks.destroy(e) } } func (c *wrappedConn) handlePacket(p receivedPacket) { if c.testHooks == nil { c.Conn.handlePacket(p) return } if c.testHooks.handlePacket != nil { c.testHooks.handlePacket(p) } } quic-go-0.59.0/connection.go000066400000000000000000003153171513066070600156550ustar00rootroot00000000000000package quic import ( "bytes" "context" "crypto/tls" "errors" "fmt" "io" "net" "reflect" "slices" "sync" "sync/atomic" "time" "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/flowcontrol" "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/utils/ringbuffer" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/qlogwriter" ) type unpacker interface { UnpackLongHeader(hdr *wire.Header, data []byte) (*unpackedPacket, error) UnpackShortHeader(rcvTime monotime.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) } type cryptoStreamHandler interface { StartHandshake(context.Context) error ChangeConnectionID(protocol.ConnectionID) SetLargest1RTTAcked(protocol.PacketNumber) error SetHandshakeConfirmed() GetSessionTicket() ([]byte, error) NextEvent() handshake.Event 
DiscardInitialKeys() HandleMessage([]byte, protocol.EncryptionLevel) error io.Closer ConnectionState() handshake.ConnectionState } type receivedPacket struct { buffer *packetBuffer remoteAddr net.Addr rcvTime monotime.Time data []byte ecn protocol.ECN info packetInfo // only valid if the contained IP address is valid } type receivedPacketWithDatagramID struct { receivedPacket datagramID qlog.DatagramID } func (p *receivedPacket) Size() protocol.ByteCount { return protocol.ByteCount(len(p.data)) } func (p *receivedPacket) Clone() *receivedPacket { return &receivedPacket{ remoteAddr: p.remoteAddr, rcvTime: p.rcvTime, data: p.data, buffer: p.buffer, ecn: p.ecn, info: p.info, } } type connRunner interface { Add(protocol.ConnectionID, packetHandler) bool Remove(protocol.ConnectionID) ReplaceWithClosed([]protocol.ConnectionID, []byte, time.Duration) AddResetToken(protocol.StatelessResetToken, packetHandler) RemoveResetToken(protocol.StatelessResetToken) } type closeError struct { err error immediate bool } type errCloseForRecreating struct { nextPacketNumber protocol.PacketNumber nextVersion protocol.Version } func (e *errCloseForRecreating) Error() string { return "closing connection in order to recreate it" } var deadlineSendImmediately = monotime.Time(42 * time.Millisecond) // any value > time.Time{} and before time.Now() is fine type blockMode uint8 const ( // blockModeNone means that the connection is not blocked. blockModeNone blockMode = iota // blockModeCongestionLimited means that the connection is congestion limited. // In that case, we can still send acknowledgments and PTO probe packets. blockModeCongestionLimited // blockModeHardBlocked means that no packet can be sent, under no circumstances. This can happen when: // * the send queue is full // * the SentPacketHandler returns SendNone, e.g. when we are tracking the maximum number of packets // In that case, the timer will be set to the idle timeout. 
blockModeHardBlocked ) // A Conn is a QUIC connection between two peers. // Calls to the connection (and to streams) can return the following types of errors: // - [ApplicationError]: for errors triggered by the application running on top of QUIC // - [TransportError]: for errors triggered by the QUIC transport (in many cases a misbehaving peer) // - [IdleTimeoutError]: when the peer goes away unexpectedly (this is a [net.Error] timeout error) // - [HandshakeTimeoutError]: when the cryptographic handshake takes too long (this is a [net.Error] timeout error) // - [StatelessResetError]: when we receive a stateless reset // - [VersionNegotiationError]: returned by the client, when there's no version overlap between the peers type Conn struct { // Destination connection ID used during the handshake. // Used to check source connection ID on incoming packets. handshakeDestConnID protocol.ConnectionID // Set for the client. Destination connection ID used on the first Initial sent. origDestConnID protocol.ConnectionID retrySrcConnID *protocol.ConnectionID // only set for the client (and if a Retry was performed) srcConnIDLen int perspective protocol.Perspective version protocol.Version config *Config conn sendConn sendQueue sender // lazily initialzed: most connections never migrate pathManager *pathManager largestRcvdAppData protocol.PacketNumber pathManagerOutgoing atomic.Pointer[pathManagerOutgoing] streamsMap *streamsMap connIDManager *connIDManager connIDGenerator *connIDGenerator rttStats *utils.RTTStats connStats utils.ConnectionStats cryptoStreamManager *cryptoStreamManager sentPacketHandler ackhandler.SentPacketHandler receivedPacketHandler ackhandler.ReceivedPacketHandler retransmissionQueue *retransmissionQueue framer *framer connFlowController flowcontrol.ConnectionFlowController tokenStoreKey string // only set for the client tokenGenerator *handshake.TokenGenerator // only set for the server unpacker unpacker frameParser wire.FrameParser packer packer 
mtuDiscoverer mtuDiscoverer // initialized when the transport parameters are received currentMTUEstimate atomic.Uint32 initialStream *initialCryptoStream handshakeStream *cryptoStream oneRTTStream *cryptoStream // only set for the server cryptoStreamHandler cryptoStreamHandler notifyReceivedPacket chan struct{} sendingScheduled chan struct{} receivedPacketMx sync.Mutex receivedPackets ringbuffer.RingBuffer[receivedPacket] // closeChan is used to notify the run loop that it should terminate closeChan chan struct{} closeErr atomic.Pointer[closeError] ctx context.Context ctxCancel context.CancelCauseFunc handshakeCompleteChan chan struct{} undecryptablePackets []receivedPacketWithDatagramID // undecryptable packets, waiting for a change in encryption level undecryptablePacketsToProcess []receivedPacketWithDatagramID earlyConnReadyChan chan struct{} sentFirstPacket bool droppedInitialKeys bool handshakeComplete bool handshakeConfirmed bool receivedRetry bool versionNegotiated bool receivedFirstPacket bool blocked blockMode // the minimum of the max_idle_timeout values advertised by both endpoints idleTimeout time.Duration creationTime monotime.Time // The idle timeout is set based on the max of the time we received the last packet... lastPacketReceivedTime monotime.Time // ... and the time we sent a new ack-eliciting packet after receiving a packet. firstAckElicitingPacketAfterIdleSentTime monotime.Time // pacingDeadline is the time when the next packet should be sent pacingDeadline monotime.Time peerParams *wire.TransportParameters timer *time.Timer // keepAlivePingSent stores whether a keep alive PING is in flight. // It is reset as soon as we receive a packet from the peer. 
keepAlivePingSent bool keepAliveInterval time.Duration datagramQueue *datagramQueue connStateMutex sync.Mutex connState ConnectionState logID string qlogTrace qlogwriter.Trace qlogger qlogwriter.Recorder logger utils.Logger } var _ streamSender = &Conn{} type connTestHooks struct { run func() error earlyConnReady func() <-chan struct{} context func() context.Context handshakeComplete func() <-chan struct{} closeWithTransportError func(TransportErrorCode) destroy func(error) handlePacket func(receivedPacket) } type wrappedConn struct { testHooks *connTestHooks *Conn } var newConnection = func( ctx context.Context, ctxCancel context.CancelCauseFunc, conn sendConn, runner connRunner, origDestConnID protocol.ConnectionID, retrySrcConnID *protocol.ConnectionID, clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, connIDGenerator ConnectionIDGenerator, statelessResetter *statelessResetter, conf *Config, tlsConf *tls.Config, tokenGenerator *handshake.TokenGenerator, clientAddressValidated bool, rtt time.Duration, qlogTrace qlogwriter.Trace, logger utils.Logger, v protocol.Version, ) *wrappedConn { s := &Conn{ ctx: ctx, ctxCancel: ctxCancel, conn: conn, config: conf, handshakeDestConnID: destConnID, srcConnIDLen: srcConnID.Len(), tokenGenerator: tokenGenerator, oneRTTStream: newCryptoStream(), perspective: protocol.PerspectiveServer, qlogTrace: qlogTrace, logger: logger, version: v, } if qlogTrace != nil { s.qlogger = qlogTrace.AddProducer() } if origDestConnID.Len() > 0 { s.logID = origDestConnID.String() } else { s.logID = destConnID.String() } s.connIDManager = newConnIDManager( destConnID, func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) }, runner.RemoveResetToken, s.queueControlFrame, ) s.connIDGenerator = newConnIDGenerator( runner, srcConnID, &clientDestConnID, statelessResetter, connRunnerCallbacks{ AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) }, 
RemoveConnectionID: runner.Remove, ReplaceWithClosed: runner.ReplaceWithClosed, }, s.queueControlFrame, connIDGenerator, ) s.preSetup() s.rttStats.SetInitialRTT(rtt) s.sentPacketHandler = ackhandler.NewSentPacketHandler( 0, protocol.ByteCount(s.config.InitialPacketSize), s.rttStats, &s.connStats, clientAddressValidated, s.conn.capabilities().ECN, s.receivedPacketHandler.IgnorePacketsBelow, s.perspective, s.qlogger, s.logger, ) s.currentMTUEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize)))) statelessResetToken := statelessResetter.GetStatelessResetToken(srcConnID) params := &wire.TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow), MaxIdleTimeout: s.config.MaxIdleTimeout, MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams), MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams), MaxAckDelay: protocol.MaxAckDelayInclGranularity, AckDelayExponent: protocol.AckDelayExponent, MaxUDPPayloadSize: protocol.MaxPacketBufferSize, StatelessResetToken: &statelessResetToken, OriginalDestinationConnectionID: origDestConnID, // For interoperability with quic-go versions before May 2023, this value must be set to a value // different from protocol.DefaultActiveConnectionIDLimit. // If set to the default value, it will be omitted from the transport parameters, which will make // old quic-go versions interpret it as 0, instead of the default value of 2. // See https://github.com/quic-go/quic-go/pull/3806. 
ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs,
		InitialSourceConnectionID: srcConnID,
		RetrySourceConnectionID: retrySrcConnID,
		EnableResetStreamAt: conf.EnableStreamResetPartialDelivery,
	}
	if s.config.EnableDatagrams {
		params.MaxDatagramFrameSize = wire.MaxDatagramSize
	} else {
		params.MaxDatagramFrameSize = protocol.InvalidByteCount
	}
	if s.qlogger != nil {
		s.qlogTransportParameters(params, protocol.PerspectiveServer, false)
	}
	cs := handshake.NewCryptoSetupServer(
		clientDestConnID,
		conn.LocalAddr(),
		conn.RemoteAddr(),
		params,
		tlsConf,
		conf.Allow0RTT,
		s.rttStats,
		s.qlogger,
		logger,
		s.version,
	)
	s.cryptoStreamHandler = cs
	s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, &s.receivedPacketHandler, s.datagramQueue, s.perspective)
	s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
	s.cryptoStreamManager = newCryptoStreamManager(s.initialStream, s.handshakeStream, s.oneRTTStream)
	return &wrappedConn{Conn: s}
}

// newClientConnection creates the client side of a QUIC connection.
//
// It sets up connection ID management (registering connection IDs and stateless
// reset tokens with runner), the sent packet handler, the client's transport
// parameters, the client crypto setup and the packet packer/unpacker.
// If conf.TokenStore holds a token for this server (keyed by tlsConf.ServerName,
// or by the remote address if no server name is set), the token is used for the
// first Initial and its RTT seeds the RTT estimate.
//
// It is declared as a variable so that it can be mocked in the tests.
var newClientConnection = func(
	ctx context.Context,
	conn sendConn,
	runner connRunner,
	destConnID protocol.ConnectionID,
	srcConnID protocol.ConnectionID,
	connIDGenerator ConnectionIDGenerator,
	statelessResetter *statelessResetter,
	conf *Config,
	tlsConf *tls.Config,
	initialPacketNumber protocol.PacketNumber,
	enable0RTT bool,
	hasNegotiatedVersion bool,
	qlogTrace qlogwriter.Trace,
	logger utils.Logger,
	v protocol.Version,
) *wrappedConn {
	s := &Conn{
		conn: conn,
		config: conf,
		origDestConnID: destConnID,
		handshakeDestConnID: destConnID,
		srcConnIDLen: srcConnID.Len(),
		perspective: protocol.PerspectiveClient,
		logID: destConnID.String(),
		logger: logger,
		qlogTrace: qlogTrace,
		versionNegotiated: hasNegotiatedVersion,
		version: v,
	}
	if qlogTrace != nil {
		s.qlogger = qlogTrace.AddProducer()
	}
	if s.qlogger != nil {
		// Addresses that are not UDP addresses are logged as nil.
		var srcAddr, destAddr *net.UDPAddr
		if addr, ok := conn.LocalAddr().(*net.UDPAddr); ok {
			srcAddr = addr
		}
		if addr, ok := conn.RemoteAddr().(*net.UDPAddr); ok {
			destAddr = addr
		}
		s.qlogger.RecordEvent(startedConnectionEvent(srcAddr, destAddr))
	}
	s.connIDManager = newConnIDManager(
		destConnID,
		func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) },
		runner.RemoveResetToken,
		s.queueControlFrame,
	)
	s.connIDGenerator = newConnIDGenerator(
		runner,
		srcConnID,
		nil,
		statelessResetter,
		connRunnerCallbacks{
			AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) },
			RemoveConnectionID: runner.Remove,
			ReplaceWithClosed: runner.ReplaceWithClosed,
		},
		s.queueControlFrame,
		connIDGenerator,
	)
	// The context must exist before preSetup runs: preSetup passes c.ctx to the streams map.
	s.ctx, s.ctxCancel = context.WithCancelCause(ctx)
	s.preSetup()
	// s.rttStats and s.receivedPacketHandler are initialized by preSetup.
	s.sentPacketHandler = ackhandler.NewSentPacketHandler(
		initialPacketNumber,
		protocol.ByteCount(s.config.InitialPacketSize),
		s.rttStats,
		&s.connStats,
		false, // has no effect
		s.conn.capabilities().ECN,
		s.receivedPacketHandler.IgnorePacketsBelow,
		s.perspective,
		s.qlogger,
		s.logger,
	)
	s.currentMTUEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
	// The client's 1-RTT crypto stream is only handed to the crypto stream manager;
	// unlike on the server, it is not stored in s.oneRTTStream.
	oneRTTStream := newCryptoStream()
	params := &wire.TransportParameters{
		InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
		InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
		InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
		InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow),
		MaxIdleTimeout: s.config.MaxIdleTimeout,
		MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams),
		MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams),
		MaxAckDelay: protocol.MaxAckDelayInclGranularity,
		MaxUDPPayloadSize: protocol.MaxPacketBufferSize,
		AckDelayExponent: protocol.AckDelayExponent,
		// For interoperability with quic-go versions before May 2023, this value must be set to a value
		// different from protocol.DefaultActiveConnectionIDLimit.
		// If set to the default value, it will be omitted from the transport parameters, which will make
		// old quic-go versions interpret it as 0, instead of the default value of 2.
		// See https://github.com/quic-go/quic-go/pull/3806.
		ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs,
		InitialSourceConnectionID: srcConnID,
		EnableResetStreamAt: conf.EnableStreamResetPartialDelivery,
	}
	if s.config.EnableDatagrams {
		params.MaxDatagramFrameSize = wire.MaxDatagramSize
	} else {
		params.MaxDatagramFrameSize = protocol.InvalidByteCount
	}
	if s.qlogger != nil {
		s.qlogTransportParameters(params, protocol.PerspectiveClient, false)
	}
	cs := handshake.NewCryptoSetupClient(
		destConnID,
		params,
		tlsConf,
		enable0RTT,
		s.rttStats,
		s.qlogger,
		logger,
		s.version,
	)
	s.cryptoStreamHandler = cs
	s.cryptoStreamManager = newCryptoStreamManager(s.initialStream, s.handshakeStream, oneRTTStream)
	s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
	s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, &s.receivedPacketHandler, s.datagramQueue, s.perspective)
	// Tokens are stored per server: by SNI if available, otherwise by remote address.
	if len(tlsConf.ServerName) > 0 {
		s.tokenStoreKey = tlsConf.ServerName
	} else {
		s.tokenStoreKey = conn.RemoteAddr().String()
	}
	if s.config.TokenStore != nil {
		if token := s.config.TokenStore.Pop(s.tokenStoreKey); token != nil {
			s.packer.SetToken(token.data)
			s.rttStats.SetInitialRTT(token.rtt)
		}
	}
	return &wrappedConn{Conn: s}
}

// preSetup initializes the state that is common to client and server connections:
// crypto streams, send and retransmission queues, the frame parser, RTT stats,
// connection-level flow control, the streams map, the framer, notification channels,
// timestamps, the received packet handler and the datagram queue.
// It reads c.conn, c.config, c.ctx, c.perspective, c.logger and c.version, so these
// must be set before it is called.
func (c *Conn) preSetup() {
	c.largestRcvdAppData = protocol.InvalidPacketNumber
	c.initialStream = newInitialCryptoStream(c.perspective == protocol.PerspectiveClient)
	c.handshakeStream = newCryptoStream()
	c.sendQueue = newSendQueue(c.conn)
	c.retransmissionQueue = newRetransmissionQueue()
	c.frameParser = *wire.NewFrameParser(
		c.config.EnableDatagrams,
		c.config.EnableStreamResetPartialDelivery,
		false, // ACK_FREQUENCY is not supported yet
	)
	c.rttStats = utils.NewRTTStats()
	c.connFlowController =
flowcontrol.NewConnectionFlowController(
		protocol.ByteCount(c.config.InitialConnectionReceiveWindow),
		protocol.ByteCount(c.config.MaxConnectionReceiveWindow),
		func(size protocol.ByteCount) bool {
			if c.config.AllowConnectionWindowIncrease == nil {
				return true
			}
			return c.config.AllowConnectionWindowIncrease(c, uint64(size))
		},
		c.rttStats,
		c.logger,
	)
	c.earlyConnReadyChan = make(chan struct{})
	c.streamsMap = newStreamsMap(
		c.ctx,
		c,
		c.queueControlFrame,
		c.newFlowController,
		uint64(c.config.MaxIncomingStreams),
		uint64(c.config.MaxIncomingUniStreams),
		c.perspective,
	)
	c.framer = newFramer(c.connFlowController)
	c.receivedPackets.Init(8)
	c.notifyReceivedPacket = make(chan struct{}, 1)
	c.closeChan = make(chan struct{}, 1)
	c.sendingScheduled = make(chan struct{}, 1)
	c.handshakeCompleteChan = make(chan struct{})
	now := monotime.Now()
	c.lastPacketReceivedTime = now
	c.creationTime = now

	c.receivedPacketHandler = *ackhandler.NewReceivedPacketHandler(c.logger)
	c.datagramQueue = newDatagramQueue(c.scheduleSending, c.logger)
	c.connState.Version = c.version
}

// run runs the connection's main loop until the connection is closed.
//
// Each iteration handles undecryptable packets that became decryptable, processes
// received packets, waits for the next event (close, timer, scheduled sending, send
// queue availability, new packets), runs loss detection, keep-alive and timeout
// checks, and finally sends packets.
// It returns the error that closed the connection. The deferred functions cancel
// c.ctx with that error and release all packets still queued for processing.
func (c *Conn) run() (err error) {
	defer func() { c.ctxCancel(err) }()

	defer func() {
		// drain queued packets that will never be processed
		c.receivedPacketMx.Lock()
		defer c.receivedPacketMx.Unlock()

		for !c.receivedPackets.Empty() {
			p := c.receivedPackets.PopFront()
			p.buffer.Decrement()
			p.buffer.MaybeRelease()
		}
	}()

	c.timer = time.NewTimer(monotime.Until(c.idleTimeoutStartTime().Add(c.config.HandshakeIdleTimeout)))

	if err := c.cryptoStreamHandler.StartHandshake(c.ctx); err != nil {
		return err
	}
	if err := c.handleHandshakeEvents(monotime.Now()); err != nil {
		return err
	}
	// Run the queue that exists right now. Reading c.sendQueue from inside the
	// goroutine would race with a later replacement of the field (see switchToNewPath).
	sq := c.sendQueue
	go func() {
		if err := sq.Run(); err != nil {
			c.destroyImpl(err)
		}
	}()

	if c.perspective == protocol.PerspectiveClient {
		c.scheduleSending() // so the ClientHello actually gets sent
	}

	var sendQueueAvailable <-chan struct{}

runLoop:
	for {
		if c.framer.QueuedTooManyControlFrames() {
			c.setCloseError(&closeError{err: &qerr.TransportError{ErrorCode: InternalError}})
			break runLoop
		}
		// Close immediately if requested
		select {
		case <-c.closeChan:
			break runLoop
		default:
		}

		// no need to set a timer if we can send packets immediately
		if c.pacingDeadline != deadlineSendImmediately {
			c.maybeResetTimer()
		}

		// 1st: handle undecryptable packets, if any.
		// This can only occur before completion of the handshake.
		if len(c.undecryptablePacketsToProcess) > 0 {
			var processedUndecryptablePacket bool
			queue := c.undecryptablePacketsToProcess
			c.undecryptablePacketsToProcess = nil
			for _, p := range queue {
				processed, err := c.handleOnePacket(p.receivedPacket, p.datagramID)
				if err != nil {
					c.setCloseError(&closeError{err: err})
					break runLoop
				}
				if processed {
					processedUndecryptablePacket = true
				}
			}
			if processedUndecryptablePacket {
				// if we processed any undecryptable packets, jump to the resetting of the timers directly
				continue
			}
		}

		// 2nd: receive packets.
		processed, err := c.handlePackets() // don't check receivedPackets.Len() in the run loop to avoid locking the mutex
		if err != nil {
			c.setCloseError(&closeError{err: err})
			break runLoop
		}

		// We don't need to wait for new events if:
		// * we processed packets: we probably need to send an ACK, and potentially more data
		// * the pacer allows us to send more packets immediately
		shouldProceedImmediately := sendQueueAvailable == nil && (processed || c.pacingDeadline.Equal(deadlineSendImmediately))
		if !shouldProceedImmediately {
			// 3rd: wait for something to happen:
			// * closing of the connection
			// * timer firing
			// * sending scheduled
			// * send queue available
			// * received packets
			select {
			case <-c.closeChan:
				break runLoop
			case <-c.timer.C:
			case <-c.sendingScheduled:
			case <-sendQueueAvailable:
			case <-c.notifyReceivedPacket:
				wasProcessed, err := c.handlePackets()
				if err != nil {
					c.setCloseError(&closeError{err: err})
					break runLoop
				}
				// If the notification didn't result in any processed packet, restart the loop
				// (re-arming the timer) instead of running the timeout and sending logic below.
				if !wasProcessed {
					continue
				}
			}
		}

		// Check for loss detection timeout.
		// This could cause packets to be declared lost, and retransmissions to be enqueued.
		now := monotime.Now()
		if timeout := c.sentPacketHandler.GetLossDetectionTimeout(); !timeout.IsZero() && !timeout.After(now) {
			if err := c.sentPacketHandler.OnLossDetectionTimeout(now); err != nil {
				c.setCloseError(&closeError{err: err})
				break runLoop
			}
		}

		if keepAliveTime := c.nextKeepAliveTime(); !keepAliveTime.IsZero() && !now.Before(keepAliveTime) {
			// send a PING frame since there is no activity in the connection
			c.logger.Debugf("Sending a keep-alive PING to keep the connection alive.")
			c.framer.QueueControlFrame(&wire.PingFrame{})
			c.keepAlivePingSent = true
		} else if !c.handshakeComplete && now.Sub(c.creationTime) >= c.config.handshakeTimeout() {
			c.destroyImpl(qerr.ErrHandshakeTimeout)
			break runLoop
		} else {
			idleTimeoutStartTime := c.idleTimeoutStartTime()
			if (!c.handshakeComplete && now.Sub(idleTimeoutStartTime) >= c.config.HandshakeIdleTimeout) ||
				(c.handshakeComplete && !now.Before(c.nextIdleTimeoutTime())) {
				c.destroyImpl(qerr.ErrIdleTimeout)
				break runLoop
			}
		}

		c.connIDGenerator.RemoveRetiredConnIDs(now)

		if c.perspective == protocol.PerspectiveClient {
			pm := c.pathManagerOutgoing.Load()
			if pm != nil {
				tr, ok := pm.ShouldSwitchPath()
				if ok {
					c.switchToNewPath(tr, now)
				}
			}
		}

		if c.sendQueue.WouldBlock() {
			// The send queue is still busy sending out packets. Wait until there's space to enqueue new packets.
			sendQueueAvailable = c.sendQueue.Available()
			// Cancel the pacing timer, as we can't send any more packets until the send queue is available again.
			c.pacingDeadline = 0
			c.blocked = blockModeHardBlocked
			continue
		}

		if c.closeErr.Load() != nil {
			break runLoop
		}

		c.blocked = blockModeNone // sending might change this again, e.g. if we're congestion limited
		if err := c.triggerSending(now); err != nil {
			c.setCloseError(&closeError{err: err})
			break runLoop
		}
		if c.sendQueue.WouldBlock() {
			// The send queue is still busy sending out packets.
			// Wait until there's space to enqueue new packets.
			sendQueueAvailable = c.sendQueue.Available()
			// Cancel the pacing timer, as we can't send any more packets until the send queue is available again.
			c.pacingDeadline = 0
			c.blocked = blockModeHardBlocked
		} else {
			sendQueueAvailable = nil
		}
	}

	closeErr := c.closeErr.Load()
	c.cryptoStreamHandler.Close()
	c.sendQueue.Close() // close the send queue before sending the CONNECTION_CLOSE
	c.handleCloseError(closeErr)
	// The qlog producer stays open when the connection is only closed to be recreated.
	if c.qlogger != nil {
		if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) {
			c.qlogger.Close()
		}
	}
	c.logger.Infof("Connection %s closed.", c.logID)
	c.timer.Stop()
	return closeErr.err
}

// earlyConnReady returns a channel to wait on until the early connection can be used.
// The method itself does not block.
func (c *Conn) earlyConnReady() <-chan struct{} { return c.earlyConnReadyChan }

// Context returns a context that is cancelled when the connection is closed.
// The cancellation cause is set to the error that caused the connection to close.
func (c *Conn) Context() context.Context {
	return c.ctx
}

// supportsDatagrams reports whether the peer's transport parameters advertise a
// positive max_datagram_frame_size. It dereferences c.peerParams, which must be set.
func (c *Conn) supportsDatagrams() bool {
	return c.peerParams.MaxDatagramFrameSize > 0
}

// ConnectionState returns basic details about the QUIC connection.
func (c *Conn) ConnectionState() ConnectionState {
	c.connStateMutex.Lock()
	defer c.connStateMutex.Unlock()

	cs := c.cryptoStreamHandler.ConnectionState()
	c.connState.TLS = cs.ConnectionState
	c.connState.Used0RTT = cs.Used0RTT
	if c.peerParams != nil {
		c.connState.SupportsDatagrams.Remote = c.supportsDatagrams()
		c.connState.SupportsStreamResetPartialDelivery.Remote = c.peerParams.EnableResetStreamAt
	}
	c.connState.SupportsDatagrams.Local = c.config.EnableDatagrams
	c.connState.SupportsStreamResetPartialDelivery.Local = c.config.EnableStreamResetPartialDelivery
	c.connState.GSO = c.conn.capabilities().GSO
	return c.connState
}

// ConnectionStats contains statistics about the QUIC connection.
type ConnectionStats struct {
	// MinRTT is the estimate of the minimum RTT observed on the active network
	// path.
MinRTT time.Duration // LatestRTT is the last RTT sample observed on the active network path. LatestRTT time.Duration // SmoothedRTT is an exponentially weighted moving average of an endpoint's // RTT samples. See https://www.rfc-editor.org/rfc/rfc9002#section-5.3 SmoothedRTT time.Duration // MeanDeviation estimates the variation in the RTT samples using a mean // variation. See https://www.rfc-editor.org/rfc/rfc9002#section-5.3 MeanDeviation time.Duration // BytesSent is the number of bytes sent on the underlying connection, // including retransmissions. Does not include UDP or any other outer // framing. BytesSent uint64 // PacketsSent is the number of packets sent on the underlying connection, // including those that are determined to have been lost. PacketsSent uint64 // BytesReceived is the number of total bytes received on the underlying // connection, including duplicate data for streams. Does not include UDP or // any other outer framing. BytesReceived uint64 // PacketsReceived is the number of total packets received on the underlying // connection, including packets that were not processable. PacketsReceived uint64 // BytesLost is the number of bytes lost on the underlying connection (does // not monotonically increase, because packets that are declared lost can // subsequently be received). Does not include UDP or any other outer // framing. BytesLost uint64 // PacketsLost is the number of packets lost on the underlying connection // (does not monotonically increase, because packets that are declared lost // can subsequently be received). 
PacketsLost uint64 } func (c *Conn) ConnectionStats() ConnectionStats { return ConnectionStats{ MinRTT: c.rttStats.MinRTT(), LatestRTT: c.rttStats.LatestRTT(), SmoothedRTT: c.rttStats.SmoothedRTT(), MeanDeviation: c.rttStats.MeanDeviation(), BytesSent: c.connStats.BytesSent.Load(), PacketsSent: c.connStats.PacketsSent.Load(), BytesReceived: c.connStats.BytesReceived.Load(), PacketsReceived: c.connStats.PacketsReceived.Load(), BytesLost: c.connStats.BytesLost.Load(), PacketsLost: c.connStats.PacketsLost.Load(), } } // Time when the connection should time out func (c *Conn) nextIdleTimeoutTime() monotime.Time { idleTimeout := max(c.idleTimeout, c.rttStats.PTO(true)*3) return c.idleTimeoutStartTime().Add(idleTimeout) } // Time when the next keep-alive packet should be sent. // It returns a zero time if no keep-alive should be sent. func (c *Conn) nextKeepAliveTime() monotime.Time { if c.config.KeepAlivePeriod == 0 || c.keepAlivePingSent { return 0 } keepAliveInterval := max(c.keepAliveInterval, c.rttStats.PTO(true)*3/2) return c.lastPacketReceivedTime.Add(keepAliveInterval) } func (c *Conn) maybeResetTimer() { var deadline monotime.Time if !c.handshakeComplete { deadline = c.creationTime.Add(c.config.handshakeTimeout()) if t := c.idleTimeoutStartTime().Add(c.config.HandshakeIdleTimeout); t.Before(deadline) { deadline = t } } else { // A keep-alive packet is ack-eliciting, so it can only be sent if the connection is // neither congestion limited nor hard-blocked. if c.blocked != blockModeNone { deadline = c.nextIdleTimeoutTime() } else { if keepAliveTime := c.nextKeepAliveTime(); !keepAliveTime.IsZero() { deadline = keepAliveTime } else { deadline = c.nextIdleTimeoutTime() } } } // If the connection is hard-blocked, we can't even send acknowledgments, // nor can we send PTO probe packets. 
if c.blocked == blockModeHardBlocked { c.timer.Reset(monotime.Until(deadline)) return } if t := c.receivedPacketHandler.GetAlarmTimeout(); !t.IsZero() && t.Before(deadline) { deadline = t } if t := c.sentPacketHandler.GetLossDetectionTimeout(); !t.IsZero() && t.Before(deadline) { deadline = t } if c.blocked == blockModeCongestionLimited { c.timer.Reset(monotime.Until(deadline)) return } if !c.pacingDeadline.IsZero() && c.pacingDeadline.Before(deadline) { deadline = c.pacingDeadline } c.timer.Reset(monotime.Until(deadline)) } func (c *Conn) idleTimeoutStartTime() monotime.Time { startTime := c.lastPacketReceivedTime if t := c.firstAckElicitingPacketAfterIdleSentTime; !t.IsZero() && t.After(startTime) { startTime = t } return startTime } func (c *Conn) switchToNewPath(tr *Transport, now monotime.Time) { initialPacketSize := protocol.ByteCount(c.config.InitialPacketSize) c.sentPacketHandler.MigratedPath(now, initialPacketSize) maxPacketSize := protocol.ByteCount(protocol.MaxPacketBufferSize) if c.peerParams.MaxUDPPayloadSize > 0 && c.peerParams.MaxUDPPayloadSize < maxPacketSize { maxPacketSize = c.peerParams.MaxUDPPayloadSize } c.mtuDiscoverer.Reset(now, initialPacketSize, maxPacketSize) c.conn = newSendConn(tr.conn, c.conn.RemoteAddr(), packetInfo{}, utils.DefaultLogger) // TODO: find a better way c.sendQueue.Close() c.sendQueue = newSendQueue(c.conn) go func() { if err := c.sendQueue.Run(); err != nil { c.destroyImpl(err) } }() } func (c *Conn) handleHandshakeComplete(now monotime.Time) error { defer close(c.handshakeCompleteChan) // Once the handshake completes, we have derived 1-RTT keys. // There's no point in queueing undecryptable packets for later decryption anymore. 
c.undecryptablePackets = nil

	c.connIDManager.SetHandshakeComplete()
	c.connIDGenerator.SetHandshakeComplete(now.Add(3 * c.rttStats.PTO(false)))

	if c.qlogger != nil {
		c.qlogger.RecordEvent(qlog.ALPNInformation{
			ChosenALPN: c.cryptoStreamHandler.ConnectionState().NegotiatedProtocol,
		})
	}

	// The server applies transport parameters right away, but the client side has to wait for handshake completion.
	// During a 0-RTT connection, the client is only allowed to use the new transport parameters for 1-RTT packets.
	if c.perspective == protocol.PerspectiveClient {
		c.applyTransportParameters()
		return nil
	}

	// All these only apply to the server side.
	if err := c.handleHandshakeConfirmed(now); err != nil {
		return err
	}

	ticket, err := c.cryptoStreamHandler.GetSessionTicket()
	if err != nil {
		return err
	}
	if ticket != nil { // may be nil if session tickets are disabled via tls.Config.SessionTicketsDisabled
		// Split the ticket into CRYPTO frames of at most MaxPostHandshakeCryptoFrameSize
		// and queue them as control frames.
		c.oneRTTStream.Write(ticket)
		for c.oneRTTStream.HasData() {
			if cf := c.oneRTTStream.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize); cf != nil {
				c.queueControlFrame(cf)
			}
		}
	}
	token, err := c.tokenGenerator.NewToken(c.conn.RemoteAddr(), c.rttStats.SmoothedRTT())
	if err != nil {
		return err
	}
	c.queueControlFrame(&wire.NewTokenFrame{Token: token})
	c.queueControlFrame(&wire.HandshakeDoneFrame{})
	return nil
}

// handleHandshakeConfirmed drops the Initial and Handshake keys, marks the handshake
// as confirmed, and starts path MTU discovery unless it is disabled in the config or
// the connection doesn't support setting the DF bit.
func (c *Conn) handleHandshakeConfirmed(now monotime.Time) error {
	// Drop initial keys.
	// On the client side, this should have happened when sending the first Handshake packet,
	// but this is not guaranteed if the server misbehaves.
	// See CVE-2025-59530 for more details.
	if err := c.dropEncryptionLevel(protocol.EncryptionInitial, now); err != nil {
		return err
	}
	if err := c.dropEncryptionLevel(protocol.EncryptionHandshake, now); err != nil {
		return err
	}

	c.handshakeConfirmed = true
	c.cryptoStreamHandler.SetHandshakeConfirmed()

	if !c.config.DisablePathMTUDiscovery && c.conn.capabilities().DF {
		c.mtuDiscoverer.Start(now)
	}
	return nil
}

// maxPacketsToProcess is the maximum number of queued packets handlePackets
// processes in one call.
const maxPacketsToProcess = 32

// handlePackets processes up to maxPacketsToProcess packets from the receivedPackets
// queue and reports whether any of them was processed.
// It stops early while the handshake is incomplete and new CRYPTO data is pending.
// If packets remain queued, it signals notifyReceivedPacket (without blocking) so
// the run loop picks them up later. On error it returns immediately; the mutex is
// not held at that point.
func (c *Conn) handlePackets() (wasProcessed bool, _ error) {
	// Process packets from the receivedPackets queue.
	// Limit the number of packets to process to maxPacketsToProcess,
	// so we eventually get a chance to send out an ACK when receiving a lot of packets.
	c.receivedPacketMx.Lock()
	if c.receivedPackets.Empty() {
		c.receivedPacketMx.Unlock()
		return false, nil
	}

	var hasMorePackets bool
	for range maxPacketsToProcess {
		p := c.receivedPackets.PopFront()
		// Don't hold the mutex while handling the packet.
		c.receivedPacketMx.Unlock()

		var datagramID qlog.DatagramID
		if c.qlogger != nil && wire.IsLongHeaderPacket(p.data[0]) {
			datagramID = qlog.CalculateDatagramID(p.data)
		}
		processed, err := c.handleOnePacket(p, datagramID)
		if err != nil {
			return false, err
		}
		if processed {
			wasProcessed = true
		}
		c.receivedPacketMx.Lock()
		hasMorePackets = !c.receivedPackets.Empty()
		if !hasMorePackets {
			break
		}
		// Prioritize sending of new CRYPTO data.
		// This is especially relevant when processing 0-RTT packets.
		if !c.handshakeComplete && (c.initialStream.HasData() || c.handshakeStream.HasData()) {
			break
		}
	}
	c.receivedPacketMx.Unlock()

	if hasMorePackets {
		select {
		case c.notifyReceivedPacket <- struct{}{}:
		default:
		}
	}
	return wasProcessed, nil
}

// handleOnePacket handles a received datagram, which may contain multiple coalesced
// QUIC packets. Coalesced packets after the first are only handled if their destination
// connection ID matches that of the preceding long header packet; parsing stops at the
// first short header packet or at the first packet that fails these checks.
// It reports whether at least one packet was processed.
func (c *Conn) handleOnePacket(rp receivedPacket, datagramID qlog.DatagramID) (wasProcessed bool, _ error) {
	c.sentPacketHandler.ReceivedBytes(rp.Size(), rp.rcvTime)

	if wire.IsVersionNegotiationPacket(rp.data) {
		return false, c.handleVersionNegotiationPacket(rp)
	}

	var counter uint8
	var lastConnID protocol.ConnectionID
	data := rp.data
	p := rp
	for len(data) > 0 {
		if counter > 0 {
			p = *(p.Clone())
			p.data = data

			destConnID, err := wire.ParseConnectionID(p.data, c.srcConnIDLen)
			if err != nil {
				if c.qlogger != nil {
					c.qlogger.RecordEvent(qlog.PacketDropped{
						Raw: qlog.RawInfo{Length: len(data)},
						DatagramID: datagramID,
						Trigger: qlog.PacketDropHeaderParseError,
					})
				}
				c.logger.Debugf("error parsing packet, couldn't parse connection ID: %s", err)
				break
			}
			if destConnID != lastConnID {
				if c.qlogger != nil {
					c.qlogger.RecordEvent(qlog.PacketDropped{
						Header: qlog.PacketHeader{DestConnectionID: destConnID},
						Raw: qlog.RawInfo{Length: len(data)},
						DatagramID: datagramID,
						Trigger: qlog.PacketDropUnknownConnectionID,
					})
				}
				c.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", destConnID, lastConnID)
				break
			}
		}

		if wire.IsLongHeaderPacket(p.data[0]) {
			hdr, packetData, rest, err := wire.ParsePacket(p.data)
			if err != nil {
				if c.qlogger != nil {
					if err == wire.ErrUnsupportedVersion {
						// NOTE(review): relies on wire.ParsePacket returning a non-nil header
						// together with ErrUnsupportedVersion — confirm in package wire.
						c.qlogger.RecordEvent(qlog.PacketDropped{
							Header: qlog.PacketHeader{Version: hdr.Version},
							Raw: qlog.RawInfo{Length: len(data)},
							DatagramID: datagramID,
							Trigger: qlog.PacketDropUnsupportedVersion,
						})
					} else {
						c.qlogger.RecordEvent(qlog.PacketDropped{
							Raw: qlog.RawInfo{Length: len(data)},
							DatagramID: datagramID,
							Trigger: qlog.PacketDropHeaderParseError,
						})
					}
				}
				c.logger.Debugf("error parsing packet: %s", err)
				break
			}
			lastConnID =
hdr.DestConnectionID

			if hdr.Version != c.version {
				if c.qlogger != nil {
					c.qlogger.RecordEvent(qlog.PacketDropped{
						Raw: qlog.RawInfo{Length: len(data)},
						DatagramID: datagramID,
						Trigger: qlog.PacketDropUnexpectedVersion,
					})
				}
				c.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, c.version)
				break
			}

			if counter > 0 {
				p.buffer.Split()
			}
			counter++

			// only log if this is actually a coalesced packet
			if c.logger.Debug() && (counter > 1 || len(rest) > 0) {
				c.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packetData), len(rest))
			}

			p.data = packetData

			processed, err := c.handleLongHeaderPacket(p, hdr, datagramID)
			if err != nil {
				return false, err
			}
			if processed {
				wasProcessed = true
			}
			data = rest
		} else {
			if counter > 0 {
				p.buffer.Split()
			}
			processed, err := c.handleShortHeaderPacket(p, counter > 0, datagramID)
			if err != nil {
				return false, err
			}
			if processed {
				wasProcessed = true
			}
			break
		}
	}

	p.buffer.MaybeRelease()
	c.blocked = blockModeNone
	return wasProcessed, nil
}

// handleShortHeaderPacket unpacks and handles a 1-RTT (short header) packet.
//
// isCoalesced reports whether other packets preceded this one in the same datagram;
// such a packet can't be a stateless reset. On the server, a packet arriving from a
// new remote address is handed to the path manager, which may send a path probe and,
// for the highest-numbered packet received so far, switch to the new path.
// The packet buffer is decremented on return unless the packet was queued for later
// decryption.
func (c *Conn) handleShortHeaderPacket(
	p receivedPacket,
	isCoalesced bool,
	datagramID qlog.DatagramID, // only for logging
) (wasProcessed bool, _ error) {
	var wasQueued bool

	defer func() {
		// Put back the packet buffer if the packet wasn't queued for later decryption.
		if !wasQueued {
			p.buffer.Decrement()
		}
	}()

	destConnID, err := wire.ParseConnectionID(p.data, c.srcConnIDLen)
	if err != nil {
		// c.qlogger is nil when qlog is disabled; guard it like every other qlog call.
		if c.qlogger != nil {
			c.qlogger.RecordEvent(qlog.PacketDropped{
				Header: qlog.PacketHeader{
					PacketType: qlog.PacketType1RTT,
					PacketNumber: protocol.InvalidPacketNumber,
				},
				Raw: qlog.RawInfo{Length: len(p.data)},
				DatagramID: datagramID,
				Trigger: qlog.PacketDropHeaderParseError,
			})
		}
		c.logger.Debugf("error parsing packet, couldn't parse connection ID: %s", err)
		return false, nil
	}
	pn, pnLen, keyPhase, data, err := c.unpacker.UnpackShortHeader(p.rcvTime, p.data)
	if err != nil {
		// Stateless reset packets (see RFC 9000, section 10.3):
		// * fill the entire UDP datagram (i.e. they cannot be part of a coalesced packet)
		// * are short header packets (first bit is 0)
		// * have the QUIC bit set (second bit is 1)
		// * are at least 21 bytes long
		if !isCoalesced && len(p.data) >= protocol.MinReceivedStatelessResetSize && p.data[0]&0b11000000 == 0b01000000 {
			token := protocol.StatelessResetToken(p.data[len(p.data)-16:])
			if c.connIDManager.IsActiveStatelessResetToken(token) {
				return false, &StatelessResetError{}
			}
		}
		wasQueued, err = c.handleUnpackError(err, p, qlog.PacketType1RTT, datagramID)
		return false, err
	}
	c.largestRcvdAppData = max(c.largestRcvdAppData, pn)

	if c.logger.Debug() {
		c.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, 1-RTT", pn, p.Size(), destConnID)
		wire.LogShortHeader(c.logger, destConnID, pn, pnLen, keyPhase)
	}

	if c.receivedPacketHandler.IsPotentiallyDuplicate(pn, protocol.Encryption1RTT) {
		c.logger.Debugf("Dropping (potentially) duplicate packet.")
		if c.qlogger != nil {
			c.qlogger.RecordEvent(qlog.PacketDropped{
				Header: qlog.PacketHeader{
					PacketType: qlog.PacketType1RTT,
					PacketNumber: pn,
				},
				Raw: qlog.RawInfo{Length: int(p.Size())},
				DatagramID: datagramID,
				Trigger: qlog.PacketDropDuplicate,
			})
		}
		return false, nil
	}

	var log func([]qlog.Frame)
	if c.qlogger != nil {
		log = func(frames []qlog.Frame) {
			c.qlogger.RecordEvent(qlog.PacketReceived{
				Header: qlog.PacketHeader{
					PacketType: qlog.PacketType1RTT,
					DestConnectionID: destConnID,
					PacketNumber: pn,
					KeyPhaseBit: keyPhase,
				},
				Raw: qlog.RawInfo{
					Length: int(p.Size()),
					PayloadLength: int(p.Size() - wire.ShortHeaderLen(destConnID, pnLen)),
				},
				DatagramID: datagramID,
				Frames: frames,
				ECN: toQlogECN(p.ecn),
			})
		}
	}
	isNonProbing, pathChallenge, err := c.handleUnpackedShortHeaderPacket(destConnID, pn, data, p.ecn, p.rcvTime, log)
	if err != nil {
		return false, err
	}

	// In RFC 9000, only the client can migrate between paths.
	if c.perspective == protocol.PerspectiveClient {
		return true, nil
	}
	if addrsEqual(p.remoteAddr, c.RemoteAddr()) {
		return true, nil
	}

	if c.pathManager == nil {
		c.pathManager = newPathManager(
			c.connIDManager.GetConnIDForPath,
			c.connIDManager.RetireConnIDForPath,
			c.logger,
		)
	}
	// Keep the probe path's connection ID in its own variable instead of overwriting
	// destConnID, which is captured by the qlog closure above.
	probeConnID, frames, shouldSwitchPath := c.pathManager.HandlePacket(p.remoteAddr, p.rcvTime, pathChallenge, isNonProbing)
	if len(frames) > 0 {
		probe, buf, err := c.packer.PackPathProbePacket(probeConnID, frames, c.version)
		if err != nil {
			return true, err
		}
		c.logger.Debugf("sending path probe packet to %s", p.remoteAddr)
		c.logShortHeaderPacketWithDatagramID(probe, protocol.ECNNon, buf.Len(), false, datagramID)
		c.registerPackedShortHeaderPacket(probe, protocol.ECNNon, p.rcvTime)
		c.sendQueue.SendProbe(buf, p.remoteAddr)
	}
	// We only switch paths in response to the highest-numbered non-probing packet,
	// see section 9.3 of RFC 9000.
	if !shouldSwitchPath || pn != c.largestRcvdAppData {
		return true, nil
	}
	c.pathManager.SwitchToPath(p.remoteAddr)
	c.sentPacketHandler.MigratedPath(p.rcvTime, protocol.ByteCount(c.config.InitialPacketSize))
	maxPacketSize := protocol.ByteCount(protocol.MaxPacketBufferSize)
	if c.peerParams.MaxUDPPayloadSize > 0 && c.peerParams.MaxUDPPayloadSize < maxPacketSize {
		maxPacketSize = c.peerParams.MaxUDPPayloadSize
	}
	c.mtuDiscoverer.Reset(
		p.rcvTime,
		protocol.ByteCount(c.config.InitialPacketSize),
		maxPacketSize,
	)
	c.conn.ChangeRemoteAddr(p.remoteAddr, p.info)
	return true, nil
}

// handleLongHeaderPacket handles a single long header packet whose header has already
// been parsed. Retry packets are passed to handleRetryPacket; Initial packets with an
// unexpected source connection ID and (on the client) 0-RTT packets are dropped.
// The packet buffer is decremented on return unless the packet was queued for later
// decryption.
func (c *Conn) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header, datagramID qlog.DatagramID) (wasProcessed bool, _ error) {
	var wasQueued bool

	defer func() {
		// Put back the packet buffer if the packet wasn't queued for later decryption.
		if !wasQueued {
			p.buffer.Decrement()
		}
	}()

	if hdr.Type == protocol.PacketTypeRetry {
		return c.handleRetryPacket(hdr, p.data, p.rcvTime), nil
	}

	// The server can change the source connection ID with the first Handshake packet.
// After this, Initial packets with a different source connection ID have to be ignored.
	if c.receivedFirstPacket && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != c.handshakeDestConnID {
		if c.qlogger != nil {
			c.qlogger.RecordEvent(qlog.PacketDropped{
				Header: qlog.PacketHeader{
					PacketType: qlog.PacketTypeInitial,
					PacketNumber: protocol.InvalidPacketNumber,
				},
				Raw: qlog.RawInfo{Length: int(p.Size())},
				DatagramID: datagramID,
				Trigger: qlog.PacketDropUnknownConnectionID,
			})
		}
		c.logger.Debugf("Dropping Initial packet (%d bytes) with unexpected source connection ID: %s (expected %s)", p.Size(), hdr.SrcConnectionID, c.handshakeDestConnID)
		return false, nil
	}
	// drop 0-RTT packets, if we are a client
	if c.perspective == protocol.PerspectiveClient && hdr.Type == protocol.PacketType0RTT {
		if c.qlogger != nil {
			c.qlogger.RecordEvent(qlog.PacketDropped{
				Header: qlog.PacketHeader{
					PacketType: qlog.PacketType0RTT,
					PacketNumber: protocol.InvalidPacketNumber,
				},
				Raw: qlog.RawInfo{Length: int(p.Size())},
				DatagramID: datagramID,
				Trigger: qlog.PacketDropUnexpectedPacket,
			})
		}
		return false, nil
	}

	packet, err := c.unpacker.UnpackLongHeader(hdr, p.data)
	if err != nil {
		wasQueued, err = c.handleUnpackError(err, p, toQlogPacketType(hdr.Type), datagramID)
		return false, err
	}

	if c.logger.Debug() {
		c.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, %s", packet.hdr.PacketNumber, p.Size(), hdr.DestConnectionID, packet.encryptionLevel)
		packet.hdr.Log(c.logger)
	}

	if pn := packet.hdr.PacketNumber; c.receivedPacketHandler.IsPotentiallyDuplicate(pn, packet.encryptionLevel) {
		c.logger.Debugf("Dropping (potentially) duplicate packet.")
		if c.qlogger != nil {
			c.qlogger.RecordEvent(qlog.PacketDropped{
				Header: qlog.PacketHeader{
					PacketType: toQlogPacketType(packet.hdr.Type),
					DestConnectionID: hdr.DestConnectionID,
					SrcConnectionID: hdr.SrcConnectionID,
					PacketNumber: pn,
					Version: packet.hdr.Version,
				},
				Raw: qlog.RawInfo{Length: int(p.Size()), PayloadLength: int(packet.hdr.Length)},
				DatagramID: datagramID,
				Trigger: qlog.PacketDropDuplicate,
			})
		}
		return false, nil
	}

	if err := c.handleUnpackedLongHeaderPacket(packet, p.ecn, p.rcvTime, datagramID, p.Size()); err != nil {
		return false, err
	}
	return true, nil
}

// handleUnpackError decides what happens to a packet that could not be unpacked.
//   - keys already dropped, decryption failure, header parse error: the packet is dropped (nil error);
//   - keys not yet available: the packet is handed to tryQueueingUndecryptablePacket and
//     wasQueued is true, so the caller must not release the packet buffer;
//   - invalid reserved bits: a PROTOCOL_VIOLATION transport error is returned;
//   - anything else (e.g. other AEAD errors) is returned unchanged.
//
// Errors are compared with == against the sentinel values, so wrapped errors only
// reach the default branch.
func (c *Conn) handleUnpackError(err error, p receivedPacket, pt qlog.PacketType, datagramID qlog.DatagramID) (wasQueued bool, _ error) {
	switch err {
	case handshake.ErrKeysDropped:
		if c.qlogger != nil {
			connID, _ := wire.ParseConnectionID(p.data, c.srcConnIDLen)
			c.qlogger.RecordEvent(qlog.PacketDropped{
				Header: qlog.PacketHeader{
					PacketType: pt,
					DestConnectionID: connID,
					PacketNumber: protocol.InvalidPacketNumber,
				},
				Raw: qlog.RawInfo{Length: int(p.Size())},
				DatagramID: datagramID,
				Trigger: qlog.PacketDropKeyUnavailable,
			})
		}
		c.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", pt, p.Size())
		return false, nil
	case handshake.ErrKeysNotYetAvailable:
		// The keys for this encryption level are not yet available.
		// Try again later.
		c.tryQueueingUndecryptablePacket(p, pt, datagramID)
		return true, nil
	case wire.ErrInvalidReservedBits:
		return false, &qerr.TransportError{
			ErrorCode: qerr.ProtocolViolation,
			ErrorMessage: err.Error(),
		}
	case handshake.ErrDecryptionFailed:
		// This might be a packet injected by an attacker. Drop it.
		if c.qlogger != nil {
			connID, _ := wire.ParseConnectionID(p.data, c.srcConnIDLen)
			c.qlogger.RecordEvent(qlog.PacketDropped{
				Header: qlog.PacketHeader{
					PacketType: pt,
					DestConnectionID: connID,
					PacketNumber: protocol.InvalidPacketNumber,
				},
				Raw: qlog.RawInfo{Length: int(p.Size())},
				DatagramID: datagramID,
				Trigger: qlog.PacketDropPayloadDecryptError,
			})
		}
		c.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", pt, p.Size(), err)
		return false, nil
	default:
		var headerErr *headerParseError
		if errors.As(err, &headerErr) {
			// This might be a packet injected by an attacker. Drop it.
			if c.qlogger != nil {
				connID, _ := wire.ParseConnectionID(p.data, c.srcConnIDLen)
				c.qlogger.RecordEvent(qlog.PacketDropped{
					Header: qlog.PacketHeader{
						PacketType: pt,
						DestConnectionID: connID,
						PacketNumber: protocol.InvalidPacketNumber,
					},
					Raw: qlog.RawInfo{Length: int(p.Size())},
					DatagramID: datagramID,
					Trigger: qlog.PacketDropHeaderParseError,
				})
			}
			c.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", pt, p.Size(), err)
			return false, nil
		}
		// This is an error returned by the AEAD (other than ErrDecryptionFailed).
		// For example, a PROTOCOL_VIOLATION due to key updates.
		return false, err
	}
}

// handleRetryPacket processes a Retry packet and reports whether it was accepted.
// A Retry is ignored if we are the server, if any packet was already received, if the
// server didn't change the source connection ID, if a Retry was already accepted, or
// if the Retry Integrity Tag doesn't verify. An accepted Retry resets the sent packet
// handler, switches the handshake destination connection ID to the Retry's source
// connection ID, stores the Retry token for subsequent Initials, and schedules sending.
func (c *Conn) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime monotime.Time) bool /* was this a valid Retry */ {
	if c.perspective == protocol.PerspectiveServer {
		if c.qlogger != nil {
			c.qlogger.RecordEvent(qlog.PacketDropped{
				Header: qlog.PacketHeader{
					PacketType: qlog.PacketTypeRetry,
					SrcConnectionID: hdr.SrcConnectionID,
					DestConnectionID: hdr.DestConnectionID,
					Version: hdr.Version,
				},
				Raw: qlog.RawInfo{Length: len(data)},
				Trigger: qlog.PacketDropUnexpectedPacket,
			})
		}
		c.logger.Debugf("Ignoring Retry.")
		return false
	}
	if c.receivedFirstPacket {
		if c.qlogger != nil {
			c.qlogger.RecordEvent(qlog.PacketDropped{
				Header: qlog.PacketHeader{
					PacketType: qlog.PacketTypeRetry,
					SrcConnectionID: hdr.SrcConnectionID,
					DestConnectionID: hdr.DestConnectionID,
					Version: hdr.Version,
				},
				Raw: qlog.RawInfo{Length: len(data)},
				Trigger: qlog.PacketDropUnexpectedPacket,
			})
		}
		c.logger.Debugf("Ignoring Retry, since we already received a packet.")
		return false
	}
	destConnID := c.connIDManager.Get()
	if hdr.SrcConnectionID == destConnID {
		if c.qlogger != nil {
			c.qlogger.RecordEvent(qlog.PacketDropped{
				Header: qlog.PacketHeader{
					PacketType: qlog.PacketTypeRetry,
					SrcConnectionID: hdr.SrcConnectionID,
					DestConnectionID: hdr.DestConnectionID,
					Version: hdr.Version,
				},
				Raw: qlog.RawInfo{Length: len(data)},
				Trigger: qlog.PacketDropUnexpectedPacket,
			})
		}
c.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.")
		return false
	}
	// If we already accepted a Retry from the server (receivedRetry is set below when a
	// Retry is accepted), ignore this Retry packet.
	if c.receivedRetry {
		c.logger.Debugf("Ignoring Retry, since a Retry was already received.")
		return false
	}

	// The last 16 bytes of the packet are the Retry Integrity Tag. It is computed over the
	// rest of the packet and the destination connection ID this client used.
	// NOTE(review): assumes len(data) >= 16, presumably guaranteed by the Retry header parser — confirm.
	tag := handshake.GetRetryIntegrityTag(data[:len(data)-16], destConnID, hdr.Version)
	if !bytes.Equal(data[len(data)-16:], tag[:]) {
		if c.qlogger != nil {
			c.qlogger.RecordEvent(qlog.PacketDropped{
				Header: qlog.PacketHeader{
					PacketType: qlog.PacketTypeRetry,
					SrcConnectionID: hdr.SrcConnectionID,
					DestConnectionID: hdr.DestConnectionID,
					Version: hdr.Version,
				},
				Raw: qlog.RawInfo{Length: len(data)},
				Trigger: qlog.PacketDropPayloadDecryptError,
			})
		}
		c.logger.Debugf("Ignoring spoofed Retry. Integrity Tag doesn't match.")
		return false
	}

	newDestConnID := hdr.SrcConnectionID
	c.receivedRetry = true
	c.sentPacketHandler.ResetForRetry(rcvTime)
	c.handshakeDestConnID = newDestConnID
	c.retrySrcConnID = &newDestConnID
	c.cryptoStreamHandler.ChangeConnectionID(newDestConnID)
	c.packer.SetToken(hdr.Token)
	c.connIDManager.ChangeInitialConnID(newDestConnID)

	if c.logger.Debug() {
		c.logger.Debugf("<- Received Retry:")
		(&wire.ExtendedHeader{Header: *hdr}).Log(c.logger)
		c.logger.Debugf("Switching destination connection ID to: %s", hdr.SrcConnectionID)
	}
	if c.qlogger != nil {
		// NOTE(review): the drop events above log hdr.DestConnectionID, while this event logs
		// destConnID (the connection ID this client sent to) — confirm which is intended.
		c.qlogger.RecordEvent(qlog.PacketReceived{
			Header: qlog.PacketHeader{
				PacketType: qlog.PacketTypeRetry,
				DestConnectionID: destConnID,
				SrcConnectionID: newDestConnID,
				Version: hdr.Version,
				Token: &qlog.Token{Raw: hdr.Token},
			},
			Raw: qlog.RawInfo{Length: len(data)},
		})
	}

	c.scheduleSending()
	return true
}

// handleVersionNegotiationPacket processes a Version Negotiation packet.
// The packet is ignored (nil is returned) on the server, once any packet was received
// or a version was already negotiated, if it can't be parsed, or if it lists the version
// we offered. Otherwise a version is chosen from c.config.Versions and the server's list:
// if none matches, the connection is destroyed with a *VersionNegotiationError and nil
// is returned; if one matches, an *errCloseForRecreating carrying the next Initial packet
// number and the new version is returned.
func (c *Conn) handleVersionNegotiationPacket(p receivedPacket) error {
	if c.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets
		c.receivedFirstPacket || c.versionNegotiated { // ignore delayed / duplicated version negotiation packets
		if c.qlogger != nil {
			c.qlogger.RecordEvent(qlog.PacketDropped{
				Header: qlog.PacketHeader{PacketType: qlog.PacketTypeVersionNegotiation},
				Raw: qlog.RawInfo{Length: int(p.Size())},
				Trigger: qlog.PacketDropUnexpectedPacket,
			})
		}
		return nil
	}

	src, dest, supportedVersions, err := wire.ParseVersionNegotiationPacket(p.data)
	if err != nil {
		if c.qlogger != nil {
			c.qlogger.RecordEvent(qlog.PacketDropped{
				Header: qlog.PacketHeader{PacketType: qlog.PacketTypeVersionNegotiation},
				Raw: qlog.RawInfo{Length: int(p.Size())},
				Trigger: qlog.PacketDropHeaderParseError,
			})
		}
		c.logger.Debugf("Error parsing Version Negotiation packet: %s", err)
		return nil
	}

	if slices.Contains(supportedVersions, c.version) {
		if c.qlogger != nil {
			c.qlogger.RecordEvent(qlog.PacketDropped{
				Header: qlog.PacketHeader{PacketType: qlog.PacketTypeVersionNegotiation},
				Raw: qlog.RawInfo{Length: int(p.Size())},
				Trigger: qlog.PacketDropUnexpectedVersion,
			})
		}
		// The Version Negotiation packet contains the version that we offered.
		// This might be a packet sent by an attacker, or it was corrupted.
		return nil
	}

	c.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", supportedVersions)
	if c.qlogger != nil {
		c.qlogger.RecordEvent(qlog.VersionNegotiationReceived{
			Header: qlog.PacketHeaderVersionNegotiation{
				DestConnectionID: dest,
				SrcConnectionID: src,
			},
			SupportedVersions: supportedVersions,
		})
	}
	newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, supportedVersions)
	if !ok {
		c.destroyImpl(&VersionNegotiationError{
			Ours: c.config.Versions,
			Theirs: supportedVersions,
		})
		c.logger.Infof("No compatible QUIC version found.")
		return nil
	}
	if c.qlogger != nil {
		c.qlogger.RecordEvent(qlog.VersionInformation{
			ChosenVersion: newVersion,
			ClientVersions: c.config.Versions,
			ServerVersions: supportedVersions,
		})
	}

	c.logger.Infof("Switching to QUIC version %s.", newVersion)
	// NOTE(review): the second return value of PeekPacketNumber is discarded; presumably
	// the packet number length, which is not needed here — confirm.
	nextPN, _ := c.sentPacketHandler.PeekPacketNumber(protocol.EncryptionInitial)
	return &errCloseForRecreating{
		nextPacketNumber: nextPN,
		nextVersion: newVersion,
	}
}

func (c *Conn) handleUnpackedLongHeaderPacket(
	packet *unpackedPacket,
	ecn protocol.ECN,
	rcvTime monotime.Time,
	datagramID qlog.DatagramID, // only for logging
	packetSize protocol.ByteCount, // only for logging
) error {
	if !c.receivedFirstPacket {
		c.receivedFirstPacket = true
		if !c.versionNegotiated && c.qlogger != nil {
			var clientVersions, serverVersions []Version
			switch c.perspective {
			case protocol.PerspectiveClient:
				clientVersions = c.config.Versions
			case protocol.PerspectiveServer:
				serverVersions = c.config.Versions
			}
			c.qlogger.RecordEvent(qlog.VersionInformation{
				ChosenVersion: c.version,
				ClientVersions: clientVersions,
				ServerVersions: serverVersions,
			})
		}
		// The server can change the source connection ID with the first Handshake packet.
		if c.perspective == protocol.PerspectiveClient && packet.hdr.SrcConnectionID != c.handshakeDestConnID {
			cid := packet.hdr.SrcConnectionID
			c.logger.Debugf("Received first packet.
Switching destination connection ID to: %s", cid) c.handshakeDestConnID = cid c.connIDManager.ChangeInitialConnID(cid) } // We create the connection as soon as we receive the first packet from the client. // We do that before authenticating the packet. // That means that if the source connection ID was corrupted, // we might have created a connection with an incorrect source connection ID. // Once we authenticate the first packet, we need to update it. if c.perspective == protocol.PerspectiveServer { if packet.hdr.SrcConnectionID != c.handshakeDestConnID { c.handshakeDestConnID = packet.hdr.SrcConnectionID c.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID) } if c.qlogger != nil { var srcAddr, destAddr *net.UDPAddr if addr, ok := c.conn.LocalAddr().(*net.UDPAddr); ok { srcAddr = addr } if addr, ok := c.conn.RemoteAddr().(*net.UDPAddr); ok { destAddr = addr } c.qlogger.RecordEvent(startedConnectionEvent(srcAddr, destAddr)) } } } if c.perspective == protocol.PerspectiveServer && packet.encryptionLevel == protocol.EncryptionHandshake && !c.droppedInitialKeys { // On the server side, Initial keys are dropped as soon as the first Handshake packet is received. // See Section 4.9.1 of RFC 9001. 
if err := c.dropEncryptionLevel(protocol.EncryptionInitial, rcvTime); err != nil { return err } } c.lastPacketReceivedTime = rcvTime c.firstAckElicitingPacketAfterIdleSentTime = 0 c.keepAlivePingSent = false if packet.hdr.Type == protocol.PacketType0RTT { c.largestRcvdAppData = max(c.largestRcvdAppData, packet.hdr.PacketNumber) } var log func([]qlog.Frame) if c.qlogger != nil { log = func(frames []qlog.Frame) { var token *qlog.Token if len(packet.hdr.Token) > 0 { token = &qlog.Token{Raw: packet.hdr.Token} } c.qlogger.RecordEvent(qlog.PacketReceived{ Header: qlog.PacketHeader{ PacketType: toQlogPacketType(packet.hdr.Type), DestConnectionID: packet.hdr.DestConnectionID, SrcConnectionID: packet.hdr.SrcConnectionID, PacketNumber: packet.hdr.PacketNumber, Version: packet.hdr.Version, Token: token, }, Raw: qlog.RawInfo{ Length: int(packetSize), PayloadLength: int(packet.hdr.Length), }, DatagramID: datagramID, Frames: frames, ECN: toQlogECN(ecn), }) } } isAckEliciting, _, _, err := c.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log, rcvTime) if err != nil { return err } c.sentPacketHandler.ReceivedPacket(packet.encryptionLevel, rcvTime) return c.receivedPacketHandler.ReceivedPacket(packet.hdr.PacketNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting) } func (c *Conn) handleUnpackedShortHeaderPacket( destConnID protocol.ConnectionID, pn protocol.PacketNumber, data []byte, ecn protocol.ECN, rcvTime monotime.Time, log func([]qlog.Frame), ) (isNonProbing bool, pathChallenge *wire.PathChallengeFrame, _ error) { c.lastPacketReceivedTime = rcvTime c.firstAckElicitingPacketAfterIdleSentTime = 0 c.keepAlivePingSent = false isAckEliciting, isNonProbing, pathChallenge, err := c.handleFrames(data, destConnID, protocol.Encryption1RTT, log, rcvTime) if err != nil { return false, nil, err } c.sentPacketHandler.ReceivedPacket(protocol.Encryption1RTT, rcvTime) if err := c.receivedPacketHandler.ReceivedPacket(pn, ecn, protocol.Encryption1RTT, 
rcvTime, isAckEliciting); err != nil { return false, nil, err } return isNonProbing, pathChallenge, nil } // handleFrames parses the frames, one after the other, and handles them. // It returns the last PATH_CHALLENGE frame contained in the packet, if any. func (c *Conn) handleFrames( data []byte, destConnID protocol.ConnectionID, encLevel protocol.EncryptionLevel, log func([]qlog.Frame), rcvTime monotime.Time, ) (isAckEliciting, isNonProbing bool, pathChallenge *wire.PathChallengeFrame, _ error) { // Only used for tracing. // If we're not tracing, this slice will always remain empty. var frames []qlog.Frame if log != nil { frames = make([]qlog.Frame, 0, 4) } handshakeWasComplete := c.handshakeComplete var handleErr error var skipHandling bool for len(data) > 0 { frameType, l, err := c.frameParser.ParseType(data, encLevel) if err != nil { // The frame parser skips over PADDING frames, and returns an io.EOF if the PADDING // frames were the last frames in this packet. if err == io.EOF { break } return false, false, nil, err } data = data[l:] if ackhandler.IsFrameTypeAckEliciting(frameType) { isAckEliciting = true } if !wire.IsProbingFrameType(frameType) { isNonProbing = true } // We're inlining common cases, to avoid using interfaces // Fast path: STREAM, DATAGRAM and ACK if frameType.IsStreamFrameType() { streamFrame, l, err := c.frameParser.ParseStreamFrame(frameType, data, c.version) if err != nil { return false, false, nil, err } data = data[l:] if log != nil { frames = append(frames, toQlogFrame(streamFrame)) } // an error occurred handling a previous frame, don't handle the current frame if skipHandling { continue } wire.LogFrame(c.logger, streamFrame, false) handleErr = c.streamsMap.HandleStreamFrame(streamFrame, rcvTime) } else if frameType.IsAckFrameType() { ackFrame, l, err := c.frameParser.ParseAckFrame(frameType, data, encLevel, c.version) if err != nil { return false, false, nil, err } data = data[l:] if log != nil { frames = append(frames, 
toQlogFrame(ackFrame)) } // an error occurred handling a previous frame, don't handle the current frame if skipHandling { continue } wire.LogFrame(c.logger, ackFrame, false) handleErr = c.handleAckFrame(ackFrame, encLevel, rcvTime) } else if frameType.IsDatagramFrameType() { datagramFrame, l, err := c.frameParser.ParseDatagramFrame(frameType, data, c.version) if err != nil { return false, false, nil, err } data = data[l:] if log != nil { frames = append(frames, toQlogFrame(datagramFrame)) } // an error occurred handling a previous frame, don't handle the current frame if skipHandling { continue } wire.LogFrame(c.logger, datagramFrame, false) handleErr = c.handleDatagramFrame(datagramFrame) } else { frame, l, err := c.frameParser.ParseLessCommonFrame(frameType, data, c.version) if err != nil { return false, false, nil, err } data = data[l:] if log != nil { frames = append(frames, toQlogFrame(frame)) } // an error occurred handling a previous frame, don't handle the current frame if skipHandling { continue } pc, err := c.handleFrame(frame, encLevel, destConnID, rcvTime) if pc != nil { pathChallenge = pc } handleErr = err } if handleErr != nil { // if we're logging, we need to keep parsing (but not handling) all frames skipHandling = true if log == nil { return false, false, nil, handleErr } } } if log != nil { log(frames) if handleErr != nil { return false, false, nil, handleErr } } // Handle completion of the handshake after processing all the frames. // This ensures that we correctly handle the following case on the server side: // We receive a Handshake packet that contains the CRYPTO frame that allows us to complete the handshake, // and an ACK serialized after that CRYPTO frame. In this case, we still want to process the ACK frame. 
if !handshakeWasComplete && c.handshakeComplete { if err := c.handleHandshakeComplete(rcvTime); err != nil { return false, false, nil, err } } return } func (c *Conn) handleFrame( f wire.Frame, encLevel protocol.EncryptionLevel, destConnID protocol.ConnectionID, rcvTime monotime.Time, ) (pathChallenge *wire.PathChallengeFrame, _ error) { var err error wire.LogFrame(c.logger, f, false) switch frame := f.(type) { case *wire.CryptoFrame: err = c.handleCryptoFrame(frame, encLevel, rcvTime) case *wire.ConnectionCloseFrame: err = c.handleConnectionCloseFrame(frame) case *wire.ResetStreamFrame: err = c.streamsMap.HandleResetStreamFrame(frame, rcvTime) case *wire.MaxDataFrame: c.connFlowController.UpdateSendWindow(frame.MaximumData) case *wire.MaxStreamDataFrame: err = c.streamsMap.HandleMaxStreamDataFrame(frame) case *wire.MaxStreamsFrame: c.streamsMap.HandleMaxStreamsFrame(frame) case *wire.DataBlockedFrame: case *wire.StreamDataBlockedFrame: err = c.streamsMap.HandleStreamDataBlockedFrame(frame) case *wire.StreamsBlockedFrame: case *wire.StopSendingFrame: err = c.streamsMap.HandleStopSendingFrame(frame) case *wire.PingFrame: case *wire.PathChallengeFrame: c.handlePathChallengeFrame(frame) pathChallenge = frame case *wire.PathResponseFrame: err = c.handlePathResponseFrame(frame) case *wire.NewTokenFrame: err = c.handleNewTokenFrame(frame) case *wire.NewConnectionIDFrame: err = c.connIDManager.Add(frame) case *wire.RetireConnectionIDFrame: err = c.connIDGenerator.Retire(frame.SequenceNumber, destConnID, rcvTime.Add(3*c.rttStats.PTO(false))) case *wire.HandshakeDoneFrame: err = c.handleHandshakeDoneFrame(rcvTime) default: err = fmt.Errorf("unexpected frame type: %s", reflect.ValueOf(&frame).Elem().Type().Name()) } return pathChallenge, err } // handlePacket is called by the server with a new packet func (c *Conn) handlePacket(p receivedPacket) { c.receivedPacketMx.Lock() // Discard packets once the amount of queued packets is larger than // the channel size, 
protocol.MaxConnUnprocessedPackets if c.receivedPackets.Len() >= protocol.MaxConnUnprocessedPackets { if c.qlogger != nil { var datagramID qlog.DatagramID if wire.IsLongHeaderPacket(p.data[0]) { datagramID = qlog.CalculateDatagramID(p.data) } c.qlogger.RecordEvent(qlog.PacketDropped{ Raw: qlog.RawInfo{Length: int(p.Size())}, DatagramID: datagramID, Trigger: qlog.PacketDropDOSPrevention, }) } c.receivedPacketMx.Unlock() return } c.receivedPackets.PushBack(p) c.receivedPacketMx.Unlock() select { case c.notifyReceivedPacket <- struct{}{}: default: } } func (c *Conn) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame) error { if frame.IsApplicationError { return &qerr.ApplicationError{ Remote: true, ErrorCode: qerr.ApplicationErrorCode(frame.ErrorCode), ErrorMessage: frame.ReasonPhrase, } } return &qerr.TransportError{ Remote: true, ErrorCode: qerr.TransportErrorCode(frame.ErrorCode), FrameType: frame.FrameType, ErrorMessage: frame.ReasonPhrase, } } func (c *Conn) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel, rcvTime monotime.Time) error { if err := c.cryptoStreamManager.HandleCryptoFrame(frame, encLevel); err != nil { return err } for { data := c.cryptoStreamManager.GetCryptoData(encLevel) if data == nil { break } if err := c.cryptoStreamHandler.HandleMessage(data, encLevel); err != nil { return err } } return c.handleHandshakeEvents(rcvTime) } func (c *Conn) handleHandshakeEvents(now monotime.Time) error { for { ev := c.cryptoStreamHandler.NextEvent() var err error switch ev.Kind { case handshake.EventNoEvent: return nil case handshake.EventHandshakeComplete: // Don't call handleHandshakeComplete yet. // It's advantageous to process ACK frames that might be serialized after the CRYPTO frame first. 
c.handshakeComplete = true case handshake.EventReceivedTransportParameters: err = c.handleTransportParameters(ev.TransportParameters) case handshake.EventRestoredTransportParameters: c.restoreTransportParameters(ev.TransportParameters) close(c.earlyConnReadyChan) case handshake.EventReceivedReadKeys: // queue all previously undecryptable packets c.undecryptablePacketsToProcess = append(c.undecryptablePacketsToProcess, c.undecryptablePackets...) c.undecryptablePackets = nil case handshake.EventDiscard0RTTKeys: err = c.dropEncryptionLevel(protocol.Encryption0RTT, now) case handshake.EventWriteInitialData: _, err = c.initialStream.Write(ev.Data) case handshake.EventWriteHandshakeData: _, err = c.handshakeStream.Write(ev.Data) } if err != nil { return err } } } func (c *Conn) handlePathChallengeFrame(f *wire.PathChallengeFrame) { if c.perspective == protocol.PerspectiveClient { c.queueControlFrame(&wire.PathResponseFrame{Data: f.Data}) } } func (c *Conn) handlePathResponseFrame(f *wire.PathResponseFrame) error { switch c.perspective { case protocol.PerspectiveClient: return c.handlePathResponseFrameClient(f) case protocol.PerspectiveServer: return c.handlePathResponseFrameServer(f) default: panic("unreachable") } } func (c *Conn) handlePathResponseFrameClient(f *wire.PathResponseFrame) error { pm := c.pathManagerOutgoing.Load() if pm == nil { return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: "unexpected PATH_RESPONSE frame", } } pm.HandlePathResponseFrame(f) return nil } func (c *Conn) handlePathResponseFrameServer(f *wire.PathResponseFrame) error { if c.pathManager == nil { // since we didn't send PATH_CHALLENGEs yet, we don't expect PATH_RESPONSEs return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: "unexpected PATH_RESPONSE frame", } } c.pathManager.HandlePathResponseFrame(f) return nil } func (c *Conn) handleNewTokenFrame(frame *wire.NewTokenFrame) error { if c.perspective == protocol.PerspectiveServer { return 
&qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: "received NEW_TOKEN frame from the client", } } if c.config.TokenStore != nil { c.config.TokenStore.Put(c.tokenStoreKey, &ClientToken{data: frame.Token, rtt: c.rttStats.SmoothedRTT()}) } return nil } func (c *Conn) handleHandshakeDoneFrame(rcvTime monotime.Time) error { if c.perspective == protocol.PerspectiveServer { return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: "received a HANDSHAKE_DONE frame", } } if !c.handshakeConfirmed { return c.handleHandshakeConfirmed(rcvTime) } return nil } func (c *Conn) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime monotime.Time) error { acked1RTTPacket, err := c.sentPacketHandler.ReceivedAck(frame, encLevel, c.lastPacketReceivedTime) if err != nil { return err } if !acked1RTTPacket { return nil } // On the client side: If the packet acknowledged a 1-RTT packet, this confirms the handshake. // This is only possible if the ACK was sent in a 1-RTT packet. // This is an optimization over simply waiting for a HANDSHAKE_DONE frame, see section 4.1.2 of RFC 9001. if c.perspective == protocol.PerspectiveClient && !c.handshakeConfirmed { if err := c.handleHandshakeConfirmed(rcvTime); err != nil { return err } } // If one of the acknowledged packets was a Path MTU probe packet, this might have increased the Path MTU estimate. 
if c.mtuDiscoverer != nil { if mtu := c.mtuDiscoverer.CurrentSize(); mtu > protocol.ByteCount(c.currentMTUEstimate.Load()) { c.currentMTUEstimate.Store(uint32(mtu)) c.sentPacketHandler.SetMaxDatagramSize(mtu) } } return c.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked()) } func (c *Conn) handleDatagramFrame(f *wire.DatagramFrame) error { if f.Length(c.version) > wire.MaxDatagramSize { return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: "DATAGRAM frame too large", } } c.datagramQueue.HandleDatagramFrame(f) return nil } func (c *Conn) setCloseError(e *closeError) { c.closeErr.CompareAndSwap(nil, e) select { case c.closeChan <- struct{}{}: default: } } // closeLocal closes the connection and send a CONNECTION_CLOSE containing the error func (c *Conn) closeLocal(e error) { c.setCloseError(&closeError{err: e, immediate: false}) } // destroy closes the connection without sending the error on the wire func (c *Conn) destroy(e error) { c.destroyImpl(e) <-c.ctx.Done() } func (c *Conn) destroyImpl(e error) { c.setCloseError(&closeError{err: e, immediate: true}) } // CloseWithError closes the connection with an error. // The error string will be sent to the peer. 
func (c *Conn) CloseWithError(code ApplicationErrorCode, desc string) error { c.closeLocal(&qerr.ApplicationError{ ErrorCode: code, ErrorMessage: desc, }) <-c.ctx.Done() return nil } func (c *Conn) closeWithTransportError(code TransportErrorCode) { c.closeLocal(&qerr.TransportError{ErrorCode: code}) <-c.ctx.Done() } func (c *Conn) handleCloseError(closeErr *closeError) { if closeErr.immediate { if nerr, ok := closeErr.err.(net.Error); ok && nerr.Timeout() { c.logger.Errorf("Destroying connection: %s", closeErr.err) } else { c.logger.Errorf("Destroying connection with error: %s", closeErr.err) } } else { if closeErr.err == nil { c.logger.Infof("Closing connection.") } else { c.logger.Errorf("Closing connection with error: %s", closeErr.err) } } e := closeErr.err if e == nil { e = &qerr.ApplicationError{} } else { defer func() { closeErr.err = e }() } var ( statelessResetErr *StatelessResetError versionNegotiationErr *VersionNegotiationError recreateErr *errCloseForRecreating applicationErr *ApplicationError transportErr *TransportError ) var isRemoteClose bool var trigger qlog.ConnectionCloseTrigger var reason string var transportErrorCode *qlog.TransportErrorCode var applicationErrorCode *qlog.ApplicationErrorCode switch { case errors.Is(e, qerr.ErrIdleTimeout), errors.Is(e, qerr.ErrHandshakeTimeout): trigger = qlog.ConnectionCloseTriggerIdleTimeout case errors.As(e, &statelessResetErr): trigger = qlog.ConnectionCloseTriggerStatelessReset case errors.As(e, &versionNegotiationErr): trigger = qlog.ConnectionCloseTriggerVersionMismatch case errors.As(e, &recreateErr): case errors.As(e, &applicationErr): isRemoteClose = applicationErr.Remote reason = applicationErr.ErrorMessage applicationErrorCode = &applicationErr.ErrorCode case errors.As(e, &transportErr): isRemoteClose = transportErr.Remote reason = transportErr.ErrorMessage transportErrorCode = &transportErr.ErrorCode case closeErr.immediate: e = closeErr.err default: te := &qerr.TransportError{ ErrorCode: 
qerr.InternalError, ErrorMessage: e.Error(), } e = te reason = te.ErrorMessage code := te.ErrorCode transportErrorCode = &code } c.streamsMap.CloseWithError(e) if c.datagramQueue != nil { c.datagramQueue.CloseWithError(e) } // In rare instances, the connection ID manager might switch to a new connection ID // when sending the CONNECTION_CLOSE frame. // The connection ID manager removes the active stateless reset token from the packet // handler map when it is closed, so we need to make sure that this happens last. defer c.connIDManager.Close() if c.qlogger != nil && !errors.As(e, &recreateErr) { initiator := qlog.InitiatorLocal if isRemoteClose { initiator = qlog.InitiatorRemote } c.qlogger.RecordEvent(qlog.ConnectionClosed{ Initiator: initiator, ConnectionError: transportErrorCode, ApplicationError: applicationErrorCode, Trigger: trigger, Reason: reason, }) } // If this is a remote close we're done here if isRemoteClose { c.connIDGenerator.ReplaceWithClosed(nil, 3*c.rttStats.PTO(false)) return } if closeErr.immediate { c.connIDGenerator.RemoveAll() return } // Don't send out any CONNECTION_CLOSE if this is an error that occurred // before we even sent out the first packet. 
if c.perspective == protocol.PerspectiveClient && !c.sentFirstPacket { c.connIDGenerator.RemoveAll() return } connClosePacket, err := c.sendConnectionClose(e) if err != nil { c.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err) } c.connIDGenerator.ReplaceWithClosed(connClosePacket, 3*c.rttStats.PTO(false)) } func (c *Conn) dropEncryptionLevel(encLevel protocol.EncryptionLevel, now monotime.Time) error { c.sentPacketHandler.DropPackets(encLevel, now) c.receivedPacketHandler.DropPackets(encLevel) //nolint:exhaustive // only Initial and 0-RTT need special treatment switch encLevel { case protocol.EncryptionInitial: c.droppedInitialKeys = true c.cryptoStreamHandler.DiscardInitialKeys() case protocol.Encryption0RTT: c.streamsMap.ResetFor0RTT() c.framer.Handle0RTTRejection() return c.connFlowController.Reset() } return c.cryptoStreamManager.Drop(encLevel) } // is called for the client, when restoring transport parameters saved for 0-RTT func (c *Conn) restoreTransportParameters(params *wire.TransportParameters) { if c.logger.Debug() { c.logger.Debugf("Restoring Transport Parameters: %s", params) } if c.qlogger != nil { c.qlogger.RecordEvent(qlog.ParametersSet{ Restore: true, Initiator: qlog.InitiatorRemote, SentBy: c.perspective, OriginalDestinationConnectionID: params.OriginalDestinationConnectionID, InitialSourceConnectionID: params.InitialSourceConnectionID, RetrySourceConnectionID: params.RetrySourceConnectionID, StatelessResetToken: params.StatelessResetToken, DisableActiveMigration: params.DisableActiveMigration, MaxIdleTimeout: params.MaxIdleTimeout, MaxUDPPayloadSize: params.MaxUDPPayloadSize, AckDelayExponent: params.AckDelayExponent, MaxAckDelay: params.MaxAckDelay, ActiveConnectionIDLimit: params.ActiveConnectionIDLimit, InitialMaxData: params.InitialMaxData, InitialMaxStreamDataBidiLocal: params.InitialMaxStreamDataBidiLocal, InitialMaxStreamDataBidiRemote: params.InitialMaxStreamDataBidiRemote, InitialMaxStreamDataUni: params.InitialMaxStreamDataUni, 
InitialMaxStreamsBidi: int64(params.MaxBidiStreamNum), InitialMaxStreamsUni: int64(params.MaxUniStreamNum), MaxDatagramFrameSize: params.MaxDatagramFrameSize, EnableResetStreamAt: params.EnableResetStreamAt, }) } c.peerParams = params c.connIDGenerator.SetMaxActiveConnIDs(params.ActiveConnectionIDLimit) c.connFlowController.UpdateSendWindow(params.InitialMaxData) c.streamsMap.HandleTransportParameters(params) } func (c *Conn) handleTransportParameters(params *wire.TransportParameters) error { if c.qlogger != nil { c.qlogTransportParameters(params, c.perspective.Opposite(), false) } if err := c.checkTransportParameters(params); err != nil { return &qerr.TransportError{ ErrorCode: qerr.TransportParameterError, ErrorMessage: err.Error(), } } if c.perspective == protocol.PerspectiveClient && c.peerParams != nil && c.ConnectionState().Used0RTT && !params.ValidForUpdate(c.peerParams) { return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: "server sent reduced limits after accepting 0-RTT data", } } c.peerParams = params // On the client side we have to wait for handshake completion. // During a 0-RTT connection, we are only allowed to use the new transport parameters for 1-RTT packets. if c.perspective == protocol.PerspectiveServer { c.applyTransportParameters() // On the server side, the early connection is ready as soon as we processed // the client's transport parameters. 
close(c.earlyConnReadyChan) } return nil } func (c *Conn) checkTransportParameters(params *wire.TransportParameters) error { if c.logger.Debug() { c.logger.Debugf("Processed Transport Parameters: %s", params) } // check the initial_source_connection_id if params.InitialSourceConnectionID != c.handshakeDestConnID { return fmt.Errorf("expected initial_source_connection_id to equal %s, is %s", c.handshakeDestConnID, params.InitialSourceConnectionID) } if c.perspective == protocol.PerspectiveServer { return nil } // check the original_destination_connection_id if params.OriginalDestinationConnectionID != c.origDestConnID { return fmt.Errorf("expected original_destination_connection_id to equal %s, is %s", c.origDestConnID, params.OriginalDestinationConnectionID) } if c.retrySrcConnID != nil { // a Retry was performed if params.RetrySourceConnectionID == nil { return errors.New("missing retry_source_connection_id") } if *params.RetrySourceConnectionID != *c.retrySrcConnID { return fmt.Errorf("expected retry_source_connection_id to equal %s, is %s", c.retrySrcConnID, *params.RetrySourceConnectionID) } } else if params.RetrySourceConnectionID != nil { return errors.New("received retry_source_connection_id, although no Retry was performed") } return nil } func (c *Conn) applyTransportParameters() { params := c.peerParams // Our local idle timeout will always be > 0. c.idleTimeout = c.config.MaxIdleTimeout // If the peer advertised an idle timeout, take the minimum of the values. 
if params.MaxIdleTimeout > 0 { c.idleTimeout = min(c.idleTimeout, params.MaxIdleTimeout) } c.keepAliveInterval = min(c.config.KeepAlivePeriod, c.idleTimeout/2) c.streamsMap.HandleTransportParameters(params) c.frameParser.SetAckDelayExponent(params.AckDelayExponent) c.connFlowController.UpdateSendWindow(params.InitialMaxData) c.rttStats.SetMaxAckDelay(params.MaxAckDelay) c.connIDGenerator.SetMaxActiveConnIDs(params.ActiveConnectionIDLimit) if params.StatelessResetToken != nil { c.connIDManager.SetStatelessResetToken(*params.StatelessResetToken) } // We don't support connection migration yet, so we don't have any use for the preferred_address. if params.PreferredAddress != nil { // Retire the connection ID. c.connIDManager.AddFromPreferredAddress(params.PreferredAddress.ConnectionID, params.PreferredAddress.StatelessResetToken) } maxPacketSize := protocol.ByteCount(protocol.MaxPacketBufferSize) if params.MaxUDPPayloadSize > 0 && params.MaxUDPPayloadSize < maxPacketSize { maxPacketSize = params.MaxUDPPayloadSize } c.mtuDiscoverer = newMTUDiscoverer( c.rttStats, protocol.ByteCount(c.config.InitialPacketSize), maxPacketSize, c.qlogger, ) } func (c *Conn) triggerSending(now monotime.Time) error { c.pacingDeadline = 0 sendMode := c.sentPacketHandler.SendMode(now) switch sendMode { case ackhandler.SendAny: return c.sendPackets(now) case ackhandler.SendNone: c.blocked = blockModeHardBlocked return nil case ackhandler.SendPacingLimited: deadline := c.sentPacketHandler.TimeUntilSend() if deadline.IsZero() { deadline = deadlineSendImmediately } c.pacingDeadline = deadline // Allow sending of an ACK if we're pacing limit. // This makes sure that a peer that is mostly receiving data (and thus has an inaccurate cwnd estimate) // sends enough ACKs to allow its peer to utilize the bandwidth. return c.maybeSendAckOnlyPacket(now) case ackhandler.SendAck: // We can at most send a single ACK only packet. // There will only be a new ACK after receiving new packets. 
// SendAck is only returned when we're congestion limited, so we don't need to set the pacing timer. c.blocked = blockModeCongestionLimited return c.maybeSendAckOnlyPacket(now) case ackhandler.SendPTOInitial, ackhandler.SendPTOHandshake, ackhandler.SendPTOAppData: if err := c.sendProbePacket(sendMode, now); err != nil { return err } if c.sendQueue.WouldBlock() { c.scheduleSending() return nil } return c.triggerSending(now) default: return fmt.Errorf("BUG: invalid send mode %d", sendMode) } } func (c *Conn) sendPackets(now monotime.Time) error { if c.perspective == protocol.PerspectiveClient && c.handshakeConfirmed { if pm := c.pathManagerOutgoing.Load(); pm != nil { connID, frame, tr, ok := pm.NextPathToProbe() if ok { probe, buf, err := c.packer.PackPathProbePacket(connID, []ackhandler.Frame{frame}, c.version) if err != nil { return err } c.logger.Debugf("sending path probe packet from %s", c.LocalAddr()) c.logShortHeaderPacket(probe, protocol.ECNNon, buf.Len()) c.registerPackedShortHeaderPacket(probe, protocol.ECNNon, now) tr.WriteTo(buf.Data, c.conn.RemoteAddr()) // There's (likely) more data to send. Loop around again. c.scheduleSending() return nil } } } // Path MTU Discovery // Can't use GSO, since we need to send a single packet that's larger than our current maximum size. // Performance-wise, this doesn't matter, since we only send a very small (<10) number of // MTU probe packets per connection. if c.handshakeConfirmed && c.mtuDiscoverer != nil && c.mtuDiscoverer.ShouldSendProbe(now) { ping, size := c.mtuDiscoverer.GetPing(now) p, buf, err := c.packer.PackMTUProbePacket(ping, size, c.version) if err != nil { return err } ecn := c.sentPacketHandler.ECNMode(true) c.logShortHeaderPacket(p, ecn, buf.Len()) c.registerPackedShortHeaderPacket(p, ecn, now) c.sendQueue.Send(buf, 0, ecn) // There's (likely) more data to send. Loop around again. 
c.scheduleSending() return nil } if offset := c.connFlowController.GetWindowUpdate(now); offset > 0 { c.framer.QueueControlFrame(&wire.MaxDataFrame{MaximumData: offset}) } if cf := c.cryptoStreamManager.GetPostHandshakeData(protocol.MaxPostHandshakeCryptoFrameSize); cf != nil { c.queueControlFrame(cf) } if !c.handshakeConfirmed { packet, err := c.packer.PackCoalescedPacket(false, c.maxPacketSize(), now, c.version) if err != nil || packet == nil { return err } c.sentFirstPacket = true if err := c.sendPackedCoalescedPacket(packet, c.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket()), now); err != nil { return err } //nolint:exhaustive // only need to handle pacing-related events here switch c.sentPacketHandler.SendMode(now) { case ackhandler.SendPacingLimited: c.resetPacingDeadline() case ackhandler.SendAny: c.pacingDeadline = deadlineSendImmediately } return nil } if c.conn.capabilities().GSO { return c.sendPacketsWithGSO(now) } return c.sendPacketsWithoutGSO(now) } func (c *Conn) sendPacketsWithoutGSO(now monotime.Time) error { for { buf := getPacketBuffer() ecn := c.sentPacketHandler.ECNMode(true) if _, err := c.appendOneShortHeaderPacket(buf, c.maxPacketSize(), ecn, now); err != nil { if err == errNothingToPack { buf.Release() return nil } return err } c.sendQueue.Send(buf, 0, ecn) if c.sendQueue.WouldBlock() { return nil } sendMode := c.sentPacketHandler.SendMode(now) if sendMode == ackhandler.SendPacingLimited { c.resetPacingDeadline() return nil } if sendMode != ackhandler.SendAny { return nil } // Prioritize receiving of packets over sending out more packets. 
c.receivedPacketMx.Lock()
hasPackets := !c.receivedPackets.Empty()
c.receivedPacketMx.Unlock()
if hasPackets {
	c.pacingDeadline = deadlineSendImmediately
	return nil
}
}
}

// sendPacketsWithGSO packs 1-RTT short header packets back-to-back into one large packet
// buffer and hands that buffer to the send queue in a single Send call, passing maxSize as
// the second argument (presumably the GSO segment size; the other Send call sites in this
// file pass 0 — TODO confirm against the send queue implementation).
// A batch is flushed as soon as no further packet may be appended to it: the send mode no
// longer allows sending, the last packet was shorter than maxSize, the next packet needs a
// different ECN marking, or the buffer has no room for another full-size packet.
// New batches are started until there is nothing left to pack, the send mode stops further
// sending, the send queue would block, or received packets are waiting to be processed.
func (c *Conn) sendPacketsWithGSO(now monotime.Time) error {
	buf := getLargePacketBuffer()
	maxSize := c.maxPacketSize()

	ecn := c.sentPacketHandler.ECNMode(true)
	for {
		var dontSendMore bool
		size, err := c.appendOneShortHeaderPacket(buf, maxSize, ecn, now)
		if err != nil {
			if err != errNothingToPack {
				return err
			}
			// Nothing was packed into this batch at all: give the unused buffer back and stop.
			if buf.Len() == 0 {
				buf.Release()
				return nil
			}
			// The buffer already holds packets: send them below, then stop.
			dontSendMore = true
		}

		if !dontSendMore {
			sendMode := c.sentPacketHandler.SendMode(now)
			// When pacing-limited, record when the next packet may be sent.
			if sendMode == ackhandler.SendPacingLimited {
				c.resetPacingDeadline()
			}
			if sendMode != ackhandler.SendAny {
				dontSendMore = true
			}
		}

		// Don't send more packets in this batch if they require a different ECN marking than the previous ones.
		nextECN := c.sentPacketHandler.ECNMode(true)

		// Append another packet if
		// 1. The congestion controller and pacer allow sending more
		// 2. The last packet appended was a full-size packet
		// 3. The next packet will have the same ECN marking
		// 4. We still have enough space for another full-size packet in the buffer
		if !dontSendMore && size == maxSize && nextECN == ecn && buf.Len()+maxSize <= buf.Cap() {
			continue
		}

		c.sendQueue.Send(buf, uint16(maxSize), ecn)

		if dontSendMore {
			return nil
		}
		if c.sendQueue.WouldBlock() {
			return nil
		}

		// Prioritize receiving of packets over sending out more packets.
		c.receivedPacketMx.Lock()
		hasPackets := !c.receivedPackets.Empty()
		c.receivedPacketMx.Unlock()
		if hasPackets {
			c.pacingDeadline = deadlineSendImmediately
			return nil
		}

		// The next batch starts with the ECN marking the next packet requires.
		ecn = nextECN
		buf = getLargePacketBuffer()
	}
}

// resetPacingDeadline sets the pacing deadline to the time at which the sent packet handler
// next allows sending. A zero time from TimeUntilSend is mapped to deadlineSendImmediately.
func (c *Conn) resetPacingDeadline() {
	deadline := c.sentPacketHandler.TimeUntilSend()
	if deadline.IsZero() {
		deadline = deadlineSendImmediately
	}
	c.pacingDeadline = deadline
}

// maybeSendAckOnlyPacket sends a packet carrying an acknowledgement, if the packer produces one.
// Before the handshake is confirmed, it packs a (possibly coalesced) packet via
// PackCoalescedPacket. The leading `true` argument is presumably "ACK only" — TODO confirm.
// After confirmation, it packs a single 1-RTT ACK-only packet.
// Having nothing to send (a nil packet or errNothingToPack) is not an error.
func (c *Conn) maybeSendAckOnlyPacket(now monotime.Time) error {
	if !c.handshakeConfirmed {
		// The ECNMode argument appears to indicate whether only short header packets are sent
		// (compare ECNMode(packet.IsOnlyShortHeaderPacket()) in sendProbePacket).
		ecn := c.sentPacketHandler.ECNMode(false)
		packet, err := c.packer.PackCoalescedPacket(true, c.maxPacketSize(), now, c.version)
		if err != nil {
			return err
		}
		if packet == nil {
			return nil
		}
		return c.sendPackedCoalescedPacket(packet, ecn, now)
	}

	ecn := c.sentPacketHandler.ECNMode(true)
	p, buf, err := c.packer.PackAckOnlyPacket(c.maxPacketSize(), now, c.version)
	if err != nil {
		if err == errNothingToPack {
			return nil
		}
		return err
	}
	c.logShortHeaderPacket(p, ecn, buf.Len())
	c.registerPackedShortHeaderPacket(p, ecn, now)
	c.sendQueue.Send(buf, 0, ecn)
	return nil
}

// sendProbePacket sends a PTO probe packet at the encryption level that corresponds to
// sendMode (Initial, Handshake or 1-RTT). Any other send mode is reported as a connection bug.
// It first asks the sent packet handler to queue probe packets and packs until a packet is
// produced. If nothing could be queued, it calls PackPTOProbePacket with its boolean argument
// set to true (presumably to force a probe packet without queued data — TODO confirm).
// Ending up without any long or short header packet is reported as a connection bug.
func (c *Conn) sendProbePacket(sendMode ackhandler.SendMode, now monotime.Time) error {
	var encLevel protocol.EncryptionLevel
	//nolint:exhaustive // We only need to handle the PTO send modes here.
	switch sendMode {
	case ackhandler.SendPTOInitial:
		encLevel = protocol.EncryptionInitial
	case ackhandler.SendPTOHandshake:
		encLevel = protocol.EncryptionHandshake
	case ackhandler.SendPTOAppData:
		encLevel = protocol.Encryption1RTT
	default:
		return fmt.Errorf("connection BUG: unexpected send mode: %d", sendMode)
	}
	// Queue probe packets until we actually send out a packet,
	// or until there are no more packets to queue.
	var packet *coalescedPacket
	for packet == nil {
		if wasQueued := c.sentPacketHandler.QueueProbePacket(encLevel); !wasQueued {
			break
		}
		var err error
		packet, err = c.packer.PackPTOProbePacket(encLevel, c.maxPacketSize(), false, now, c.version)
		if err != nil {
			return err
		}
	}
	if packet == nil {
		var err error
		packet, err = c.packer.PackPTOProbePacket(encLevel, c.maxPacketSize(), true, now, c.version)
		if err != nil {
			return err
		}
	}
	if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) {
		return fmt.Errorf("connection BUG: couldn't pack %s probe packet: %v", encLevel, packet)
	}
	return c.sendPackedCoalescedPacket(packet, c.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket()), now)
}

// appendOneShortHeaderPacket appends a new packet to the given packetBuffer.
// If there was nothing to pack, the returned size is 0.
// On success, the packet is logged and registered with the sent packet handler, and the number
// of bytes appended to buf is returned. Packer errors are returned unchanged with a size of 0.
// This includes errNothingToPack, which sendPacketsWithGSO checks for.
func (c *Conn) appendOneShortHeaderPacket(buf *packetBuffer, maxSize protocol.ByteCount, ecn protocol.ECN, now monotime.Time) (protocol.ByteCount, error) {
	startLen := buf.Len()
	p, err := c.packer.AppendPacket(buf, maxSize, now, c.version)
	if err != nil {
		return 0, err
	}
	size := buf.Len() - startLen
	c.logShortHeaderPacket(p, ecn, size)
	c.registerPackedShortHeaderPacket(p, ecn, now)
	return size, nil
}

// registerPackedShortHeaderPacket records a packed 1-RTT packet with the sent packet handler.
// Path probe packets are registered with protocol.InvalidPacketNumber as the largest acked
// packet number and with the final SentPacket argument set to true. They do not update
// firstAckElicitingPacketAfterIdleSentTime and do not notify the connection ID manager.
func (c *Conn) registerPackedShortHeaderPacket(p shortHeaderPacket, ecn protocol.ECN, now monotime.Time) {
	if p.IsPathProbePacket {
		c.sentPacketHandler.SentPacket(
			now,
			p.PacketNumber,
			protocol.InvalidPacketNumber,
			p.StreamFrames,
			p.Frames,
			protocol.Encryption1RTT,
			ecn,
			p.Length,
			p.IsPathMTUProbePacket,
			true,
		)
		return
	}
	// Only the first ack-eliciting packet sent while the timestamp is unset records it.
	if c.firstAckElicitingPacketAfterIdleSentTime.IsZero() && (len(p.StreamFrames) > 0 || ackhandler.HasAckElicitingFrames(p.Frames)) {
		c.firstAckElicitingPacketAfterIdleSentTime = now
	}

	largestAcked := protocol.InvalidPacketNumber
	if p.Ack != nil {
		largestAcked = p.Ack.LargestAcked()
	}
	c.sentPacketHandler.SentPacket(
		now,
		p.PacketNumber,
		largestAcked,
		p.StreamFrames,
		p.Frames,
		protocol.Encryption1RTT,
		ecn,
		p.Length,
		p.IsPathMTUProbePacket,
		false,
	)
	c.connIDManager.SentPacket()
}

// sendPackedCoalescedPacket logs a coalesced packet and registers each of its long header
// packets, plus its optional short header packet, with the sent packet handler. It then
// queues the whole datagram on the send queue.
// On the client, registering the first Handshake packet drops the Initial encryption level.
// An error from dropping the keys aborts the send before the datagram is queued.
func (c *Conn) sendPackedCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN, now monotime.Time) error {
	c.logCoalescedPacket(packet, ecn)
	for _, p := range packet.longHdrPackets {
		if c.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() {
			c.firstAckElicitingPacketAfterIdleSentTime = now
		}
		largestAcked := protocol.InvalidPacketNumber
		if p.ack != nil {
			largestAcked = p.ack.LargestAcked()
		}
		c.sentPacketHandler.SentPacket(
			now,
			p.header.PacketNumber,
			largestAcked,
			p.streamFrames,
			p.frames,
			p.EncryptionLevel(),
			ecn,
			p.length,
			false,
			false,
		)
		if c.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake && !c.droppedInitialKeys {
			// On the client side, Initial keys are dropped as soon as the first Handshake packet is sent.
			// See Section 4.9.1 of RFC 9001.
			if err := c.dropEncryptionLevel(protocol.EncryptionInitial, now); err != nil {
				return err
			}
		}
	}
	if p := packet.shortHdrPacket; p != nil {
		if c.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() {
			c.firstAckElicitingPacketAfterIdleSentTime = now
		}
		largestAcked := protocol.InvalidPacketNumber
		if p.Ack != nil {
			largestAcked = p.Ack.LargestAcked()
		}
		c.sentPacketHandler.SentPacket(
			now,
			p.PacketNumber,
			largestAcked,
			p.StreamFrames,
			p.Frames,
			protocol.Encryption1RTT,
			ecn,
			p.Length,
			p.IsPathMTUProbePacket,
			false,
		)
	}
	c.connIDManager.SentPacket()
	c.sendQueue.Send(packet.buffer, 0, ecn)
	return nil
}

// sendConnectionClose packs a connection close packet for e and writes it directly to the
// underlying connection, bypassing the send queue. It returns the packet's raw bytes
// together with the result of the write.
// A *qerr.TransportError is packed with PackConnectionClose and a *qerr.ApplicationError
// with PackApplicationClose. Any other error is sent to the peer as a transport error with
// code qerr.InternalError.
func (c *Conn) sendConnectionClose(e error) ([]byte, error) {
	var packet *coalescedPacket
	var err error
	var transportErr *qerr.TransportError
	var applicationErr *qerr.ApplicationError
	if errors.As(e, &transportErr) {
		packet, err = c.packer.PackConnectionClose(transportErr, c.maxPacketSize(), c.version)
	} else if errors.As(e, &applicationErr) {
		packet, err = c.packer.PackApplicationClose(applicationErr, c.maxPacketSize(), c.version)
	} else {
		packet, err = c.packer.PackConnectionClose(&qerr.TransportError{
			ErrorCode:    qerr.InternalError,
			ErrorMessage: fmt.Sprintf("connection BUG: unspecified error type (msg: %s)", e.Error()),
		}, c.maxPacketSize(), c.version)
	}
	if err != nil {
		return nil, err
	}
	ecn := c.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket())
	c.logCoalescedPacket(packet, ecn)
	return packet.buffer.Data, c.conn.Write(packet.buffer.Data, 0, ecn)
}

// maxPacketSize returns the maximum size of packets this connection sends.
// Once an MTU discoverer exists, its current size is used. Before that, the value depends
// on the perspective (see below).
func (c *Conn) maxPacketSize() protocol.ByteCount {
	if c.mtuDiscoverer == nil {
		// Use the configured packet size on the client side.
		// If the server sends a max_udp_payload_size that's smaller than this size, we can ignore this:
		// Apparently the server still processed the (fully padded) Initial packet anyway.
		if c.perspective == protocol.PerspectiveClient {
			return protocol.ByteCount(c.config.InitialPacketSize)
		}
		// On the server side, there's no downside to using 1200 bytes until we received the client's transport
		// parameters:
		// * If the first packet didn't contain the entire ClientHello, all we can do is ACK that packet. We don't
		//   need a lot of bytes for that.
		// * If it did, we will have processed the transport parameters and initialized the MTU discoverer.
		return protocol.MinInitialPacketSize
	}
	return c.mtuDiscoverer.CurrentSize()
}

// AcceptStream returns the next stream opened by the peer, blocking until one is available.
func (c *Conn) AcceptStream(ctx context.Context) (*Stream, error) {
	return c.streamsMap.AcceptStream(ctx)
}

// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
func (c *Conn) AcceptUniStream(ctx context.Context) (*ReceiveStream, error) {
	return c.streamsMap.AcceptUniStream(ctx)
}

// OpenStream opens a new bidirectional QUIC stream.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream,
// or the stream has been reset or closed.
// When reaching the peer's stream limit, it is not possible to open a new stream until the // peer raises the stream limit. In that case, a [StreamLimitReachedError] is returned. func (c *Conn) OpenStream() (*Stream, error) { return c.streamsMap.OpenStream() } // OpenStreamSync opens a new bidirectional QUIC stream. // It blocks until a new stream can be opened. // There is no signaling to the peer about new streams: // The peer can only accept the stream after data has been sent on the stream, // or the stream has been reset or closed. func (c *Conn) OpenStreamSync(ctx context.Context) (*Stream, error) { return c.streamsMap.OpenStreamSync(ctx) } // OpenUniStream opens a new outgoing unidirectional QUIC stream. // There is no signaling to the peer about new streams: // The peer can only accept the stream after data has been sent on the stream, // or the stream has been reset or closed. // When reaching the peer's stream limit, it is not possible to open a new stream until the // peer raises the stream limit. In that case, a [StreamLimitReachedError] is returned. func (c *Conn) OpenUniStream() (*SendStream, error) { return c.streamsMap.OpenUniStream() } // OpenUniStreamSync opens a new outgoing unidirectional QUIC stream. // It blocks until a new stream can be opened. // There is no signaling to the peer about new streams: // The peer can only accept the stream after data has been sent on the stream, // or the stream has been reset or closed. 
func (c *Conn) OpenUniStreamSync(ctx context.Context) (*SendStream, error) { return c.streamsMap.OpenUniStreamSync(ctx) } func (c *Conn) newFlowController(id protocol.StreamID) flowcontrol.StreamFlowController { initialSendWindow := c.peerParams.InitialMaxStreamDataUni if id.Type() == protocol.StreamTypeBidi { if id.InitiatedBy() == c.perspective { initialSendWindow = c.peerParams.InitialMaxStreamDataBidiRemote } else { initialSendWindow = c.peerParams.InitialMaxStreamDataBidiLocal } } return flowcontrol.NewStreamFlowController( id, c.connFlowController, protocol.ByteCount(c.config.InitialStreamReceiveWindow), protocol.ByteCount(c.config.MaxStreamReceiveWindow), initialSendWindow, c.rttStats, c.logger, ) } // scheduleSending signals that we have data for sending func (c *Conn) scheduleSending() { select { case c.sendingScheduled <- struct{}{}: default: } } // tryQueueingUndecryptablePacket queues a packet for which we're missing the decryption keys. // The qlogevents.PacketType is only used for logging purposes. func (c *Conn) tryQueueingUndecryptablePacket(p receivedPacket, pt qlog.PacketType, datagramID qlog.DatagramID) { if c.handshakeComplete { panic("shouldn't queue undecryptable packets after handshake completion") } if len(c.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { if c.qlogger != nil { c.qlogger.RecordEvent(qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: pt, PacketNumber: protocol.InvalidPacketNumber, }, Raw: qlog.RawInfo{Length: int(p.Size())}, DatagramID: datagramID, Trigger: qlog.PacketDropDOSPrevention, }) } c.logger.Infof("Dropping undecryptable packet (%d bytes). 
Undecryptable packet queue full.", p.Size()) return } c.logger.Infof("Queueing packet (%d bytes) for later decryption", p.Size()) if c.qlogger != nil { c.qlogger.RecordEvent(qlog.PacketBuffered{ Header: qlog.PacketHeader{ PacketType: pt, PacketNumber: protocol.InvalidPacketNumber, }, Raw: qlog.RawInfo{Length: int(p.Size())}, DatagramID: datagramID, }) } c.undecryptablePackets = append(c.undecryptablePackets, receivedPacketWithDatagramID{receivedPacket: p, datagramID: datagramID}) } func (c *Conn) queueControlFrame(f wire.Frame) { c.framer.QueueControlFrame(f) c.scheduleSending() } func (c *Conn) onHasConnectionData() { c.scheduleSending() } func (c *Conn) onHasStreamData(id protocol.StreamID, str *SendStream) { c.framer.AddActiveStream(id, str) c.scheduleSending() } func (c *Conn) onHasStreamControlFrame(id protocol.StreamID, str streamControlFrameGetter) { c.framer.AddStreamWithControlFrames(id, str) c.scheduleSending() } func (c *Conn) onStreamCompleted(id protocol.StreamID) { if err := c.streamsMap.DeleteStream(id); err != nil { c.closeLocal(err) } c.framer.RemoveActiveStream(id) } // SendDatagram sends a message using a QUIC datagram, as specified in RFC 9221, // if the peer enabled datagram support. // There is no delivery guarantee for DATAGRAM frames, they are not retransmitted if lost. // The payload of the datagram needs to fit into a single QUIC packet. // In addition, a datagram may be dropped before being sent out if the available packet size suddenly decreases. // If the payload is too large to be sent at the current time, a DatagramTooLargeError is returned. func (c *Conn) SendDatagram(p []byte) error { if !c.supportsDatagrams() { return errors.New("datagram support disabled") } f := &wire.DatagramFrame{DataLenPresent: true} // The payload size estimate is conservative. // Under many circumstances we could send a few more bytes. 
maxDataLen := min( f.MaxDataLen(c.peerParams.MaxDatagramFrameSize, c.version), protocol.ByteCount(c.currentMTUEstimate.Load()), ) if protocol.ByteCount(len(p)) > maxDataLen { return &DatagramTooLargeError{MaxDatagramPayloadSize: int64(maxDataLen)} } f.Data = make([]byte, len(p)) copy(f.Data, p) return c.datagramQueue.Add(f) } // ReceiveDatagram gets a message received in a QUIC datagram, as specified in RFC 9221. func (c *Conn) ReceiveDatagram(ctx context.Context) ([]byte, error) { if !c.config.EnableDatagrams { return nil, errors.New("datagram support disabled") } return c.datagramQueue.Receive(ctx) } // LocalAddr returns the local address of the QUIC connection. func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() } // RemoteAddr returns the remote address of the QUIC connection. func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } // getPathManager lazily initializes the Conn's pathManagerOutgoing. // May create multiple pathManagerOutgoing objects if called concurrently. func (c *Conn) getPathManager() *pathManagerOutgoing { old := c.pathManagerOutgoing.Load() if old != nil { // Path manager is already initialized return old } // Initialize the path manager new := newPathManagerOutgoing( c.connIDManager.GetConnIDForPath, c.connIDManager.RetireConnIDForPath, c.scheduleSending, ) if c.pathManagerOutgoing.CompareAndSwap(old, new) { return new } // Swap failed. A concurrent writer wrote first, use their value. 
return c.pathManagerOutgoing.Load() } func (c *Conn) AddPath(t *Transport) (*Path, error) { if c.perspective == protocol.PerspectiveServer { return nil, errors.New("server cannot initiate connection migration") } if c.peerParams.DisableActiveMigration { return nil, errors.New("server disabled connection migration") } if err := t.init(false); err != nil { return nil, err } return c.getPathManager().NewPath( t, 200*time.Millisecond, // initial RTT estimate func() { runner := (*packetHandlerMap)(t) c.connIDGenerator.AddConnRunner( runner, connRunnerCallbacks{ AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, c) }, RemoveConnectionID: runner.Remove, ReplaceWithClosed: runner.ReplaceWithClosed, }, ) }, ), nil } // HandshakeComplete blocks until the handshake completes (or fails). // For the client, data sent before completion of the handshake is encrypted with 0-RTT keys. // For the server, data sent before completion of the handshake is encrypted with 1-RTT keys, // however the client's identity is only verified once the handshake completes. func (c *Conn) HandshakeComplete() <-chan struct{} { return c.handshakeCompleteChan } // QlogTrace returns the qlog trace of the QUIC connection. // It is nil if qlog is not enabled. func (c *Conn) QlogTrace() qlogwriter.Trace { return c.qlogTrace } // NextConnection transitions a connection to be usable after a 0-RTT rejection. // It waits for the handshake to complete and then enables the connection for normal use. // This should be called when the server rejects 0-RTT and the application receives // [Err0RTTRejected] errors. // // Note that 0-RTT rejection invalidates all data sent in 0-RTT packets. It is the // application's responsibility to handle this (for example by resending the data). func (c *Conn) NextConnection(ctx context.Context) (*Conn, error) { // The handshake might fail after the server rejected 0-RTT. // This could happen if the Finished message is malformed or never received. 
select { case <-ctx.Done(): return nil, context.Cause(ctx) case <-c.Context().Done(): case <-c.HandshakeComplete(): c.streamsMap.UseResetMaps() } return c, nil } // estimateMaxPayloadSize estimates the maximum payload size for short header packets. // It is not very sophisticated: it just subtracts the size of header (assuming the maximum // connection ID length), and the size of the encryption tag. func estimateMaxPayloadSize(mtu protocol.ByteCount) protocol.ByteCount { return mtu - 1 /* type byte */ - 20 /* maximum connection ID length */ - 16 /* tag size */ } quic-go-0.59.0/connection_logging.go000066400000000000000000000232211513066070600173510ustar00rootroot00000000000000package quic import ( "net" "net/netip" "slices" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/qlog" ) // ConvertFrame converts a wire.Frame into a logging.Frame. // This makes it possible for external packages to access the frames. // Furthermore, it removes the data slices from CRYPTO and STREAM frames. func toQlogFrame(frame wire.Frame) qlog.Frame { switch f := frame.(type) { case *wire.AckFrame: // We use a pool for ACK frames. // Implementations of the tracer interface may hold on to frames, so we need to make a copy here. 
return qlog.Frame{Frame: toQlogAckFrame(f)} case *wire.CryptoFrame: return qlog.Frame{ Frame: &qlog.CryptoFrame{ Offset: int64(f.Offset), Length: int64(len(f.Data)), }, } case *wire.StreamFrame: return qlog.Frame{ Frame: &qlog.StreamFrame{ StreamID: f.StreamID, Offset: int64(f.Offset), Length: int64(f.DataLen()), Fin: f.Fin, }, } case *wire.DatagramFrame: return qlog.Frame{ Frame: &qlog.DatagramFrame{ Length: int64(len(f.Data)), }, } default: return qlog.Frame{Frame: frame} } } func toQlogAckFrame(f *wire.AckFrame) *qlog.AckFrame { ack := &qlog.AckFrame{ AckRanges: slices.Clone(f.AckRanges), DelayTime: f.DelayTime, ECNCE: f.ECNCE, ECT0: f.ECT0, ECT1: f.ECT1, } return ack } func (c *Conn) logLongHeaderPacket(p *longHeaderPacket, ecn protocol.ECN, datagramID qlog.DatagramID) { // quic-go logging if c.logger.Debug() { p.header.Log(c.logger) if p.ack != nil { wire.LogFrame(c.logger, p.ack, true) } for _, frame := range p.frames { wire.LogFrame(c.logger, frame.Frame, true) } for _, frame := range p.streamFrames { wire.LogFrame(c.logger, frame.Frame, true) } } // tracing if c.qlogger != nil { numFrames := len(p.frames) + len(p.streamFrames) if p.ack != nil { numFrames++ } frames := make([]qlog.Frame, 0, numFrames) if p.ack != nil { frames = append(frames, toQlogFrame(p.ack)) } for _, f := range p.frames { frames = append(frames, toQlogFrame(f.Frame)) } for _, f := range p.streamFrames { frames = append(frames, toQlogFrame(f.Frame)) } c.qlogger.RecordEvent(qlog.PacketSent{ Header: qlog.PacketHeader{ PacketType: toQlogPacketType(p.header.Type), KeyPhaseBit: p.header.KeyPhase, PacketNumber: p.header.PacketNumber, Version: p.header.Version, SrcConnectionID: p.header.SrcConnectionID, DestConnectionID: p.header.DestConnectionID, }, Raw: qlog.RawInfo{ Length: int(p.length), PayloadLength: int(p.header.Length), }, DatagramID: datagramID, Frames: frames, ECN: toQlogECN(ecn), }) } } func (c *Conn) logShortHeaderPacket(p shortHeaderPacket, ecn protocol.ECN, size protocol.ByteCount) 
{ c.logShortHeaderPacketWithDatagramID(p, ecn, size, false, 0) } func (c *Conn) logShortHeaderPacketWithDatagramID(p shortHeaderPacket, ecn protocol.ECN, size protocol.ByteCount, isCoalesced bool, datagramID qlog.DatagramID) { if c.logger.Debug() && !isCoalesced { c.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT (ECN: %s)", p.PacketNumber, size, c.logID, ecn) } // quic-go logging if c.logger.Debug() { wire.LogShortHeader(c.logger, p.DestConnID, p.PacketNumber, p.PacketNumberLen, p.KeyPhase) if p.Ack != nil { wire.LogFrame(c.logger, p.Ack, true) } for _, f := range p.Frames { wire.LogFrame(c.logger, f.Frame, true) } for _, f := range p.StreamFrames { wire.LogFrame(c.logger, f.Frame, true) } } // tracing if c.qlogger != nil { numFrames := len(p.Frames) + len(p.StreamFrames) if p.Ack != nil { numFrames++ } fs := make([]qlog.Frame, 0, numFrames) if p.Ack != nil { fs = append(fs, toQlogFrame(p.Ack)) } for _, f := range p.Frames { fs = append(fs, toQlogFrame(f.Frame)) } for _, f := range p.StreamFrames { fs = append(fs, toQlogFrame(f.Frame)) } c.qlogger.RecordEvent(qlog.PacketSent{ Header: qlog.PacketHeader{ PacketType: qlog.PacketType1RTT, KeyPhaseBit: p.KeyPhase, PacketNumber: p.PacketNumber, Version: c.version, DestConnectionID: p.DestConnID, }, Raw: qlog.RawInfo{ Length: int(size), PayloadLength: int(size - wire.ShortHeaderLen(p.DestConnID, p.PacketNumberLen)), }, DatagramID: datagramID, Frames: fs, ECN: toQlogECN(ecn), }) } } func (c *Conn) logCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN) { var datagramID qlog.DatagramID if c.qlogger != nil { datagramID = qlog.CalculateDatagramID(packet.buffer.Data) } if c.logger.Debug() { // There's a short period between dropping both Initial and Handshake keys and completion of the handshake, // during which we might call PackCoalescedPacket but just pack a short header packet. 
if len(packet.longHdrPackets) == 0 && packet.shortHdrPacket != nil { c.logShortHeaderPacketWithDatagramID( *packet.shortHdrPacket, ecn, packet.shortHdrPacket.Length, false, datagramID, ) return } if len(packet.longHdrPackets) > 1 { c.logger.Debugf("-> Sending coalesced packet (%d parts, %d bytes) for connection %s", len(packet.longHdrPackets), packet.buffer.Len(), c.logID) } else { c.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.longHdrPackets[0].header.PacketNumber, packet.buffer.Len(), c.logID, packet.longHdrPackets[0].EncryptionLevel()) } } for _, p := range packet.longHdrPackets { c.logLongHeaderPacket(p, ecn, datagramID) } if p := packet.shortHdrPacket; p != nil { c.logShortHeaderPacketWithDatagramID(*p, ecn, p.Length, true, datagramID) } } func (c *Conn) qlogTransportParameters(tp *wire.TransportParameters, sentBy protocol.Perspective, restore bool) { ev := qlog.ParametersSet{ Restore: restore, OriginalDestinationConnectionID: tp.OriginalDestinationConnectionID, InitialSourceConnectionID: tp.InitialSourceConnectionID, RetrySourceConnectionID: tp.RetrySourceConnectionID, StatelessResetToken: tp.StatelessResetToken, DisableActiveMigration: tp.DisableActiveMigration, MaxIdleTimeout: tp.MaxIdleTimeout, MaxUDPPayloadSize: tp.MaxUDPPayloadSize, AckDelayExponent: tp.AckDelayExponent, MaxAckDelay: tp.MaxAckDelay, ActiveConnectionIDLimit: tp.ActiveConnectionIDLimit, InitialMaxData: tp.InitialMaxData, InitialMaxStreamDataBidiLocal: tp.InitialMaxStreamDataBidiLocal, InitialMaxStreamDataBidiRemote: tp.InitialMaxStreamDataBidiRemote, InitialMaxStreamDataUni: tp.InitialMaxStreamDataUni, InitialMaxStreamsBidi: int64(tp.MaxBidiStreamNum), InitialMaxStreamsUni: int64(tp.MaxUniStreamNum), MaxDatagramFrameSize: tp.MaxDatagramFrameSize, EnableResetStreamAt: tp.EnableResetStreamAt, } if sentBy == c.perspective { ev.Initiator = qlog.InitiatorLocal } else { ev.Initiator = qlog.InitiatorRemote } if tp.PreferredAddress != nil { ev.PreferredAddress = 
&qlog.PreferredAddress{ IPv4: tp.PreferredAddress.IPv4, IPv6: tp.PreferredAddress.IPv6, ConnectionID: tp.PreferredAddress.ConnectionID, StatelessResetToken: tp.PreferredAddress.StatelessResetToken, } } c.qlogger.RecordEvent(ev) } func toQlogECN(ecn protocol.ECN) qlog.ECN { //nolint:exhaustive // only need to handle the 3 valid values switch ecn { case protocol.ECT0: return qlog.ECT0 case protocol.ECT1: return qlog.ECT1 case protocol.ECNCE: return qlog.ECNCE default: return qlog.ECNUnsupported } } func toQlogPacketType(pt protocol.PacketType) qlog.PacketType { var qpt qlog.PacketType switch pt { case protocol.PacketTypeInitial: qpt = qlog.PacketTypeInitial case protocol.PacketTypeHandshake: qpt = qlog.PacketTypeHandshake case protocol.PacketType0RTT: qpt = qlog.PacketType0RTT case protocol.PacketTypeRetry: qpt = qlog.PacketTypeRetry } return qpt } func toPathEndpointInfo(addr *net.UDPAddr) qlog.PathEndpointInfo { if addr == nil { return qlog.PathEndpointInfo{} } var info qlog.PathEndpointInfo if addr.IP == nil || addr.IP.To4() != nil { addrPort := netip.AddrPortFrom(netip.AddrFrom4([4]byte(addr.IP.To4())), uint16(addr.Port)) if addrPort.IsValid() { info.IPv4 = addrPort } } else { addrPort := netip.AddrPortFrom(netip.AddrFrom16([16]byte(addr.IP.To16())), uint16(addr.Port)) if addrPort.IsValid() { info.IPv6 = addrPort } } return info } // startedConnectionEvent builds a StartedConnection event using consistent logic // for both endpoints. If the local address is unspecified (e.g., dual-stack // listener), it selects the family based on the remote address and uses the // unspecified address of that family with the local port. func startedConnectionEvent(local, remote *net.UDPAddr) qlog.StartedConnection { var localInfo, remoteInfo qlog.PathEndpointInfo if remote != nil { remoteInfo = toPathEndpointInfo(remote) } if local != nil { if local.IP == nil || local.IP.IsUnspecified() { // Choose local family based on the remote address family. 
if remote != nil && remote.IP.To4() != nil { ap := netip.AddrPortFrom(netip.AddrFrom4([4]byte{}), uint16(local.Port)) if ap.IsValid() { localInfo.IPv4 = ap } } else if remote != nil && remote.IP.To16() != nil && remote.IP.To4() == nil { ap := netip.AddrPortFrom(netip.AddrFrom16([16]byte{}), uint16(local.Port)) if ap.IsValid() { localInfo.IPv6 = ap } } } else { localInfo = toPathEndpointInfo(local) } } return qlog.StartedConnection{Local: localInfo, Remote: remoteInfo} } quic-go-0.59.0/connection_logging_test.go000066400000000000000000000076651513066070600204260ustar00rootroot00000000000000package quic import ( "net" "net/netip" "testing" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/qlog" "github.com/stretchr/testify/require" ) func TestConnectionLoggingCryptoFrame(t *testing.T) { f := toQlogFrame(&wire.CryptoFrame{ Offset: 1234, Data: []byte("foobar"), }) require.Equal(t, &qlog.CryptoFrame{ Offset: 1234, Length: 6, }, f.Frame) } func TestConnectionLoggingStreamFrame(t *testing.T) { f := toQlogFrame(&wire.StreamFrame{ StreamID: 42, Offset: 1234, Data: []byte("foo"), Fin: true, }) require.Equal(t, &qlog.StreamFrame{ StreamID: 42, Offset: 1234, Length: 3, Fin: true, }, f.Frame) } func TestConnectionLoggingAckFrame(t *testing.T) { ack := &wire.AckFrame{ AckRanges: []wire.AckRange{ {Smallest: 1, Largest: 3}, {Smallest: 6, Largest: 7}, }, DelayTime: 42, ECNCE: 123, ECT0: 456, ECT1: 789, } f := toQlogFrame(ack) // now modify the ACK range in the original frame ack.AckRanges[0].Smallest = 2 require.Equal(t, &qlog.AckFrame{ AckRanges: []wire.AckRange{ {Smallest: 1, Largest: 3}, // unchanged, since the ACK ranges were cloned {Smallest: 6, Largest: 7}, }, DelayTime: 42, ECNCE: 123, ECT0: 456, ECT1: 789, }, f.Frame) } func TestConnectionLoggingDatagramFrame(t *testing.T) { f := toQlogFrame(&wire.DatagramFrame{Data: []byte("foobar")}) require.Equal(t, &qlog.DatagramFrame{Length: 6}, f.Frame) } func TestConnectionLoggingOtherFrames(t *testing.T) { f := 
toQlogFrame(&wire.MaxDataFrame{MaximumData: 1234}) require.Equal(t, &qlog.MaxDataFrame{MaximumData: 1234}, f.Frame) } func TestConnectionLoggingStartedConnectionEvent(t *testing.T) { tests := []struct { name string local *net.UDPAddr remote *net.UDPAddr wantLocalIP string wantLocalPort uint16 wantRemote netip.AddrPort }{ { name: "unspecified local, remote IPv4 -> 0.0.0.0", local: &net.UDPAddr{Port: 58451}, remote: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6121}, wantLocalIP: "0.0.0.0", wantLocalPort: 58451, wantRemote: netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 6121), }, { name: "unspecified local, remote IPv6 -> ::", local: &net.UDPAddr{Port: 4242}, remote: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 6121}, wantLocalIP: "::", wantLocalPort: 4242, wantRemote: func() netip.AddrPort { a, _ := netip.ParseAddr("2001:db8::1"); return netip.AddrPortFrom(a, 6121) }(), }, { name: "specified local IPv4", local: &net.UDPAddr{IP: net.IPv4(192, 168, 1, 10), Port: 9999}, remote: &net.UDPAddr{IP: net.IPv4(10, 0, 0, 1), Port: 1234}, wantLocalIP: "192.168.1.10", wantLocalPort: 9999, wantRemote: netip.AddrPortFrom(netip.AddrFrom4([4]byte{10, 0, 0, 1}), 1234), }, { name: "specified local IPv6", local: &net.UDPAddr{IP: net.ParseIP("fe80::1"), Port: 999}, remote: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 6121}, wantLocalIP: "fe80::1", wantLocalPort: 999, wantRemote: func() netip.AddrPort { a, _ := netip.ParseAddr("2001:db8::1"); return netip.AddrPortFrom(a, 6121) }(), }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { ev := startedConnectionEvent(tc.local, tc.remote) var gotIP string var gotPort uint16 if ev.Local.IPv4.IsValid() { gotIP = ev.Local.IPv4.Addr().String() gotPort = ev.Local.IPv4.Port() } else if ev.Local.IPv6.IsValid() { gotIP = ev.Local.IPv6.Addr().String() gotPort = ev.Local.IPv6.Port() } require.Equal(t, tc.wantLocalIP, gotIP) require.Equal(t, tc.wantLocalPort, gotPort) var gotRemote netip.AddrPort if 
ev.Remote.IPv4.IsValid() { gotRemote = ev.Remote.IPv4 } else if ev.Remote.IPv6.IsValid() { gotRemote = ev.Remote.IPv6 } require.Equal(t, tc.wantRemote, gotRemote) }) } } quic-go-0.59.0/connection_test.go000066400000000000000000003420501513066070600167060ustar00rootroot00000000000000package quic import ( "bytes" "context" "crypto/rand" "crypto/tls" "errors" "fmt" "net" "net/netip" "strconv" "testing" "time" "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/flowcontrol" "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/mocks" mockackhandler "github.com/quic-go/quic-go/internal/mocks/ackhandler" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/quic-go/testutils/events" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) type testConnectionOpt func(*Conn) func connectionOptCryptoSetup(cs *mocks.MockCryptoSetup) testConnectionOpt { return func(conn *Conn) { conn.cryptoStreamHandler = cs } } func connectionOptConnFlowController(cfc flowcontrol.ConnectionFlowController) testConnectionOpt { return func(conn *Conn) { conn.connFlowController = cfc } } func connectionOptTracer(r qlogwriter.Recorder) testConnectionOpt { return func(conn *Conn) { conn.qlogger = r } } func connectionOptSentPacketHandler(sph ackhandler.SentPacketHandler) testConnectionOpt { return func(conn *Conn) { conn.sentPacketHandler = sph } } func connectionOptUnpacker(u unpacker) testConnectionOpt { return func(conn *Conn) { conn.unpacker = u } } func connectionOptSender(s sender) testConnectionOpt { return func(conn *Conn) { conn.sendQueue = s } } func 
connectionOptHandshakeConfirmed() testConnectionOpt { return func(conn *Conn) { conn.handshakeComplete = true conn.handshakeConfirmed = true } } func connectionOptRTT(rtt time.Duration) testConnectionOpt { rttStats := utils.NewRTTStats() rttStats.UpdateRTT(rtt, 0) return func(conn *Conn) { conn.rttStats = rttStats } } func connectionOptRetrySrcConnID(rcid protocol.ConnectionID) testConnectionOpt { return func(conn *Conn) { conn.retrySrcConnID = &rcid } } type testConnection struct { conn *Conn connRunner *MockConnRunner sendConn *MockSendConn packer *MockPacker destConnID protocol.ConnectionID srcConnID protocol.ConnectionID remoteAddr *net.UDPAddr } func (tc *testConnection) receivedPacketHandler() *ackhandler.ReceivedPacketHandler { return &tc.conn.receivedPacketHandler } func newServerTestConnection( t *testing.T, mockCtrl *gomock.Controller, config *Config, gso bool, opts ...testConnectionOpt, ) *testConnection { if mockCtrl == nil { mockCtrl = gomock.NewController(t) } remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321} localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} connRunner := NewMockConnRunner(mockCtrl) sendConn := NewMockSendConn(mockCtrl) sendConn.EXPECT().capabilities().Return(connCapabilities{GSO: gso}).AnyTimes() sendConn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes() sendConn.EXPECT().LocalAddr().Return(localAddr).AnyTimes() packer := NewMockPacker(mockCtrl) b := make([]byte, 12) rand.Read(b) origDestConnID := protocol.ParseConnectionID(b[:6]) srcConnID := protocol.ParseConnectionID(b[6:12]) ctx, cancel := context.WithCancelCause(context.Background()) if config == nil { config = &Config{DisablePathMTUDiscovery: true} } wc := newConnection( ctx, cancel, sendConn, connRunner, origDestConnID, nil, protocol.ConnectionID{}, protocol.ConnectionID{}, srcConnID, &protocol.DefaultConnectionIDGenerator{}, newStatelessResetter(nil), populateConfig(config), &tls.Config{}, 
handshake.NewTokenGenerator(handshake.TokenProtectorKey{}), false, 1337*time.Millisecond, nil, utils.DefaultLogger, protocol.Version1, ) require.Nil(t, wc.testHooks) conn := wc.Conn conn.packer = packer for _, opt := range opts { opt(conn) } return &testConnection{ conn: conn, connRunner: connRunner, sendConn: sendConn, packer: packer, destConnID: origDestConnID, srcConnID: srcConnID, remoteAddr: remoteAddr, } } func newClientTestConnection( t *testing.T, mockCtrl *gomock.Controller, config *Config, enable0RTT bool, opts ...testConnectionOpt, ) *testConnection { if mockCtrl == nil { mockCtrl = gomock.NewController(t) } remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321} localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} connRunner := NewMockConnRunner(mockCtrl) sendConn := NewMockSendConn(mockCtrl) sendConn.EXPECT().capabilities().Return(connCapabilities{}).AnyTimes() sendConn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes() sendConn.EXPECT().LocalAddr().Return(localAddr).AnyTimes() packer := NewMockPacker(mockCtrl) b := make([]byte, 12) rand.Read(b) destConnID := protocol.ParseConnectionID(b[:6]) srcConnID := protocol.ParseConnectionID(b[6:12]) if config == nil { config = &Config{DisablePathMTUDiscovery: true} } conn := newClientConnection( context.Background(), sendConn, connRunner, destConnID, srcConnID, &protocol.DefaultConnectionIDGenerator{}, newStatelessResetter(nil), populateConfig(config), &tls.Config{ServerName: "quic-go.net"}, 0, enable0RTT, false, nil, utils.DefaultLogger, protocol.Version1, ) require.Nil(t, conn.testHooks) conn.packer = packer for _, opt := range opts { opt(conn.Conn) } return &testConnection{ conn: conn.Conn, connRunner: connRunner, sendConn: sendConn, packer: packer, destConnID: destConnID, srcConnID: srcConnID, } } func TestConnectionHandleStreamRelatedFrames(t *testing.T) { const id protocol.StreamID = 5 connID := protocol.ConnectionID{} tests := []struct { name string frame wire.Frame }{ {name: 
"RESET_STREAM", frame: &wire.ResetStreamFrame{StreamID: id, ErrorCode: 42, FinalSize: 1337}}, {name: "STOP_SENDING", frame: &wire.StopSendingFrame{StreamID: id, ErrorCode: 42}}, {name: "MAX_STREAM_DATA", frame: &wire.MaxStreamDataFrame{StreamID: id, MaximumStreamData: 1337}}, {name: "STREAM_DATA_BLOCKED", frame: &wire.StreamDataBlockedFrame{StreamID: id, MaximumStreamData: 42}}, {name: "STREAM_FRAME", frame: &wire.StreamFrame{StreamID: id, Data: []byte{1, 2, 3, 4, 5, 6, 7, 8}, Offset: 1337}}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { tc := newServerTestConnection(t, gomock.NewController(t), nil, false) data, err := test.frame.Append(nil, protocol.Version1) require.NoError(t, err) _, _, _, err = tc.conn.handleFrames(data, connID, protocol.Encryption1RTT, nil, monotime.Now()) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) }) } } func TestConnectionHandleConnectionFlowControlFrames(t *testing.T) { mockCtrl := gomock.NewController(t) connFC := flowcontrol.NewConnectionFlowController(0, 0, nil, utils.NewRTTStats(), utils.DefaultLogger) require.Zero(t, connFC.SendWindowSize()) tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptConnFlowController(connFC)) now := monotime.Now() connID := protocol.ConnectionID{} // MAX_DATA frame _, err := tc.conn.handleFrame(&wire.MaxDataFrame{MaximumData: 1337}, protocol.Encryption1RTT, connID, now) require.NoError(t, err) require.Equal(t, protocol.ByteCount(1337), connFC.SendWindowSize()) // DATA_BLOCKED frame _, err = tc.conn.handleFrame(&wire.DataBlockedFrame{MaximumData: 1337}, protocol.Encryption1RTT, connID, now) require.NoError(t, err) } func TestConnectionServerInvalidFrames(t *testing.T) { mockCtrl := gomock.NewController(t) tc := newServerTestConnection(t, mockCtrl, nil, false) for _, test := range []struct { Name string Frame wire.Frame }{ {Name: "NEW_TOKEN", Frame: &wire.NewTokenFrame{Token: []byte("foobar")}}, {Name: "HANDSHAKE_DONE", Frame: 
&wire.HandshakeDoneFrame{}}, {Name: "PATH_RESPONSE", Frame: &wire.PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}}, } { t.Run(test.Name, func(t *testing.T) { _, err := tc.conn.handleFrame(test.Frame, protocol.Encryption1RTT, protocol.ConnectionID{}, monotime.Now()) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation}) }) } } func TestConnectionClose(t *testing.T) { t.Run("transport error", func(t *testing.T) { expectedErr := &qerr.TransportError{ ErrorCode: 1337, FrameType: 42, ErrorMessage: "foobar", } testConnectionClose(t, false, expectedErr) }) t.Run("application error", func(t *testing.T) { expectedErr := &qerr.ApplicationError{ ErrorCode: 1337, ErrorMessage: "foobar", } testConnectionClose(t, true, expectedErr) }) } func testConnectionClose(t *testing.T, useApplicationClose bool, expectedErr error) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) var eventRecorder events.Recorder tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(&eventRecorder)) errChan := make(chan error, 1) tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() b := getPacketBuffer() b.Data = append(b.Data, []byte("connection close")...) 
if useApplicationClose { tc.packer.EXPECT().PackApplicationClose(expectedErr, gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: b}, nil) } else { tc.packer.EXPECT().PackConnectionClose(expectedErr, gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: b}, nil) } tc.sendConn.EXPECT().Write([]byte("connection close"), gomock.Any(), gomock.Any()) tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() go func() { errChan <- tc.conn.run() }() tc.conn.closeLocal(expectedErr) synctest.Wait() var want qlog.ConnectionClosed if useApplicationClose { code := expectedErr.(*qerr.ApplicationError).ErrorCode want = qlog.ConnectionClosed{ Initiator: qlog.InitiatorLocal, ApplicationError: &code, Reason: expectedErr.(*qerr.ApplicationError).ErrorMessage, } } else { code := expectedErr.(*qerr.TransportError).ErrorCode want = qlog.ConnectionClosed{ Initiator: qlog.InitiatorLocal, ConnectionError: &code, Reason: expectedErr.(*qerr.TransportError).ErrorMessage, } } require.Equal(t, []qlogwriter.Event{want}, eventRecorder.Events(qlog.ConnectionClosed{}), ) eventRecorder.Clear() select { case err := <-errChan: require.ErrorIs(t, err, expectedErr) default: t.Fatal("connection was not closed") } // further calls to CloseWithError don't do anything tc.conn.CloseWithError(42, "another error") require.Empty(t, eventRecorder.Events(qlog.ConnectionClosed{})) }) } func TestConnectionStatelessReset(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) var eventRecorder events.Recorder tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(&eventRecorder)) errChan := make(chan error, 1) tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() go func() { errChan <- tc.conn.run() }() tc.conn.destroy(&StatelessResetError{}) synctest.Wait() require.Equal(t, []qlogwriter.Event{qlog.ConnectionClosed{Initiator: qlog.InitiatorLocal, Trigger: qlog.ConnectionCloseTriggerStatelessReset}}, 
eventRecorder.Events(qlog.ConnectionClosed{}), ) }) } func getLongHeaderPacket(t *testing.T, remoteAddr net.Addr, extHdr *wire.ExtendedHeader, data []byte) receivedPacket { t.Helper() b, err := extHdr.Append(nil, protocol.Version1) require.NoError(t, err) return receivedPacket{ remoteAddr: remoteAddr, data: append(b, data...), buffer: getPacketBuffer(), rcvTime: monotime.Now(), } } func getShortHeaderPacket(t *testing.T, remoteAddr net.Addr, connID protocol.ConnectionID, pn protocol.PacketNumber, data []byte) receivedPacket { t.Helper() b, err := wire.AppendShortHeader(nil, connID, pn, protocol.PacketNumberLen2, protocol.KeyPhaseOne) require.NoError(t, err) return receivedPacket{ remoteAddr: remoteAddr, data: append(b, data...), buffer: getPacketBuffer(), rcvTime: monotime.Now(), } } func TestConnectionServerInvalidPackets(t *testing.T) { t.Run("Retry", func(t *testing.T) { mockCtrl := gomock.NewController(t) var eventRecorder events.Recorder tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(&eventRecorder)) p := getLongHeaderPacket(t, tc.remoteAddr, &wire.ExtendedHeader{Header: wire.Header{ Type: protocol.PacketTypeRetry, DestConnectionID: tc.conn.origDestConnID, SrcConnectionID: tc.srcConnID, Version: tc.conn.version, Token: []byte("foobar"), }}, make([]byte, 16), /* Retry integrity tag */ ) wasProcessed, err := tc.conn.handleOnePacket(p, 0) require.NoError(t, err) require.False(t, wasProcessed) require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeRetry, SrcConnectionID: tc.srcConnID, DestConnectionID: tc.conn.origDestConnID, Version: tc.conn.version, }, Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropUnexpectedPacket, }, }, eventRecorder.Events(qlog.PacketDropped{}), ) }) t.Run("version negotiation", func(t *testing.T) { mockCtrl := gomock.NewController(t) var eventRecorder events.Recorder tc := newServerTestConnection(t, mockCtrl, nil, false, 
connectionOptTracer(&eventRecorder)) b := wire.ComposeVersionNegotiation( protocol.ArbitraryLenConnectionID(tc.srcConnID.Bytes()), protocol.ArbitraryLenConnectionID(tc.conn.origDestConnID.Bytes()), []Version{Version1}, ) wasProcessed, err := tc.conn.handleOnePacket(receivedPacket{data: b, buffer: getPacketBuffer()}, 0) require.NoError(t, err) require.False(t, wasProcessed) require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Header: qlog.PacketHeader{PacketType: qlog.PacketTypeVersionNegotiation}, Raw: qlog.RawInfo{Length: len(b)}, Trigger: qlog.PacketDropUnexpectedPacket, }, }, eventRecorder.Events(qlog.PacketDropped{}), ) }) t.Run("unsupported version", func(t *testing.T) { mockCtrl := gomock.NewController(t) var eventRecorder events.Recorder tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(&eventRecorder)) p := getLongHeaderPacket(t, tc.remoteAddr, &wire.ExtendedHeader{ Header: wire.Header{Type: protocol.PacketTypeHandshake, Version: 1234}, PacketNumberLen: protocol.PacketNumberLen2, }, nil, ) wasProcessed, err := tc.conn.handleOnePacket(p, 42) require.NoError(t, err) require.False(t, wasProcessed) require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Header: qlog.PacketHeader{Version: 1234}, Raw: qlog.RawInfo{Length: int(p.Size())}, DatagramID: 42, Trigger: qlog.PacketDropUnsupportedVersion, }, }, eventRecorder.Events(qlog.PacketDropped{}), ) }) t.Run("invalid header", func(t *testing.T) { mockCtrl := gomock.NewController(t) var eventRecorder events.Recorder tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(&eventRecorder)) p := getLongHeaderPacket(t, tc.remoteAddr, &wire.ExtendedHeader{ Header: wire.Header{Type: protocol.PacketTypeHandshake, Version: Version1}, PacketNumberLen: protocol.PacketNumberLen2, }, nil, ) p.data[0] ^= 0x40 // unset the QUIC bit wasProcessed, err := tc.conn.handleOnePacket(p, 42) require.NoError(t, err) require.False(t, wasProcessed) require.Equal(t, []qlogwriter.Event{ 
qlog.PacketDropped{ Header: qlog.PacketHeader{}, Raw: qlog.RawInfo{Length: int(p.Size())}, DatagramID: 42, Trigger: qlog.PacketDropHeaderParseError, }, }, eventRecorder.Events(qlog.PacketDropped{}), ) }) } func TestConnectionClientDrop0RTT(t *testing.T) { mockCtrl := gomock.NewController(t) var eventRecorder events.Recorder tc := newClientTestConnection(t, mockCtrl, nil, false, connectionOptTracer(&eventRecorder)) p := getLongHeaderPacket(t, tc.remoteAddr, &wire.ExtendedHeader{ Header: wire.Header{Type: protocol.PacketType0RTT, Length: 2, Version: protocol.Version1}, PacketNumberLen: protocol.PacketNumberLen2, }, nil, ) wasProcessed, err := tc.conn.handleOnePacket(p, 1234) require.NoError(t, err) require.False(t, wasProcessed) require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: qlog.PacketType0RTT, PacketNumber: protocol.InvalidPacketNumber, }, Raw: qlog.RawInfo{Length: int(p.Size())}, DatagramID: 1234, Trigger: qlog.PacketDropUnexpectedPacket, }, }, eventRecorder.Events(qlog.PacketDropped{}), ) } func TestConnectionUnpacking(t *testing.T) { mockCtrl := gomock.NewController(t) unpacker := NewMockUnpacker(mockCtrl) var eventRecorder events.Recorder tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptUnpacker(unpacker), connectionOptTracer(&eventRecorder), ) // receive a long header packet hdr := &wire.ExtendedHeader{ Header: wire.Header{ Type: protocol.PacketTypeInitial, DestConnectionID: tc.srcConnID, Version: protocol.Version1, Length: 1, }, PacketNumber: 0x37, PacketNumberLen: protocol.PacketNumberLen1, } unpackedHdr := *hdr unpackedHdr.PacketNumber = 0x1337 packet := getLongHeaderPacket(t, tc.remoteAddr, hdr, nil) packet.ecn = protocol.ECNCE rcvTime := monotime.Now().Add(-10 * time.Second) packet.rcvTime = rcvTime unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ encryptionLevel: protocol.EncryptionInitial, hdr: &unpackedHdr, data: []byte{0}, // one PADDING frame }, 
nil) wasProcessed, err := tc.conn.handleOnePacket(packet, 42) require.NoError(t, err) require.True(t, wasProcessed) require.Equal(t, []qlogwriter.Event{ qlog.PacketReceived{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeInitial, DestConnectionID: tc.srcConnID, PacketNumber: protocol.PacketNumber(0x1337), Version: protocol.Version1, }, Frames: []qlog.Frame{}, ECN: qlog.ECNCE, Raw: qlog.RawInfo{Length: int(packet.Size()), PayloadLength: 1}, DatagramID: 42, }, }, eventRecorder.Events(qlog.PacketReceived{}, qlog.PacketDropped{}), ) eventRecorder.Clear() // receive a duplicate of this packet packet = getLongHeaderPacket(t, tc.remoteAddr, hdr, nil) unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ encryptionLevel: protocol.EncryptionInitial, hdr: &unpackedHdr, data: []byte{0}, // one PADDING frame }, nil) wasProcessed, err = tc.conn.handleOnePacket(packet, 43) require.NoError(t, err) require.False(t, wasProcessed) require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeInitial, DestConnectionID: tc.srcConnID, PacketNumber: protocol.PacketNumber(0x1337), Version: protocol.Version1, }, Raw: qlog.RawInfo{Length: int(packet.Size()), PayloadLength: 1}, DatagramID: 43, Trigger: qlog.PacketDropDuplicate, }, }, eventRecorder.Events(qlog.PacketReceived{}, qlog.PacketDropped{}), ) eventRecorder.Clear() // receive a short header packet packet = getShortHeaderPacket(t, tc.remoteAddr, tc.srcConnID, 0x37, nil) packet.ecn = protocol.ECT1 packet.rcvTime = rcvTime unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return( protocol.PacketNumber(0x1337), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* PADDING */, nil, ) wasProcessed, err = tc.conn.handleOnePacket(packet, 0) require.NoError(t, err) require.Equal(t, []qlogwriter.Event{ qlog.PacketReceived{ Header: qlog.PacketHeader{ PacketType: qlog.PacketType1RTT, DestConnectionID: tc.srcConnID, PacketNumber: 
protocol.PacketNumber(0x1337), KeyPhaseBit: protocol.KeyPhaseZero, }, Raw: qlog.RawInfo{Length: int(packet.Size())}, Frames: []qlog.Frame{}, ECN: qlog.ECT1, }, }, eventRecorder.Events(qlog.PacketReceived{}, qlog.PacketDropped{}), ) require.True(t, wasProcessed) } func TestConnectionUnpackCoalescedPacket(t *testing.T) { mockCtrl := gomock.NewController(t) unpacker := NewMockUnpacker(mockCtrl) var eventRecorder events.Recorder tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptUnpacker(unpacker), connectionOptTracer(&eventRecorder), ) hdr1 := &wire.ExtendedHeader{ Header: wire.Header{ Type: protocol.PacketTypeInitial, DestConnectionID: tc.srcConnID, Version: protocol.Version1, Length: 1, }, PacketNumber: 37, PacketNumberLen: protocol.PacketNumberLen1, } hdr2 := &wire.ExtendedHeader{ Header: wire.Header{ Type: protocol.PacketTypeHandshake, DestConnectionID: tc.srcConnID, Version: protocol.Version1, Length: 1, }, PacketNumber: 38, PacketNumberLen: protocol.PacketNumberLen1, } // add a packet with a different source connection ID incorrectSrcConnID := protocol.ParseConnectionID([]byte{0xa, 0xb, 0xc}) hdr3 := &wire.ExtendedHeader{ Header: wire.Header{ Type: protocol.PacketTypeHandshake, DestConnectionID: incorrectSrcConnID, Version: protocol.Version1, Length: 1, }, PacketNumber: 0x42, PacketNumberLen: protocol.PacketNumberLen1, } unpackedHdr1 := *hdr1 unpackedHdr1.PacketNumber = 1337 unpackedHdr2 := *hdr2 unpackedHdr2.PacketNumber = 1338 packet := getLongHeaderPacket(t, tc.remoteAddr, hdr1, nil) firstPacketLen := packet.Size() packet2 := getLongHeaderPacket(t, tc.remoteAddr, hdr2, nil) packet3 := getLongHeaderPacket(t, tc.remoteAddr, hdr3, nil) packet.data = append(packet.data, packet2.data...) packet.data = append(packet.data, packet3.data...) 
packet.ecn = protocol.ECT1 rcvTime := monotime.Now() packet.rcvTime = rcvTime unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ encryptionLevel: protocol.EncryptionInitial, hdr: &unpackedHdr1, data: []byte{0}, // one PADDING frame }, nil) unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, hdr: &unpackedHdr2, data: []byte{1}, // one PING frame }, nil) wasProcessed, err := tc.conn.handleOnePacket(packet, 42) require.NoError(t, err) require.True(t, wasProcessed) require.Equal(t, []qlogwriter.Event{ qlog.PacketReceived{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeInitial, DestConnectionID: tc.srcConnID, PacketNumber: protocol.PacketNumber(1337), Version: protocol.Version1, }, Raw: qlog.RawInfo{Length: int(firstPacketLen), PayloadLength: 1}, DatagramID: 42, Frames: []qlog.Frame{}, ECN: qlog.ECT1, }, qlog.PacketReceived{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeHandshake, DestConnectionID: tc.srcConnID, PacketNumber: protocol.PacketNumber(1338), Version: protocol.Version1, }, Raw: qlog.RawInfo{Length: int(packet2.Size()), PayloadLength: 1}, DatagramID: 42, Frames: []qlog.Frame{{Frame: &wire.PingFrame{}}}, ECN: qlog.ECT1, }, qlog.PacketDropped{ Header: qlog.PacketHeader{DestConnectionID: incorrectSrcConnID}, Raw: qlog.RawInfo{Length: int(packet3.Size())}, DatagramID: 42, Trigger: qlog.PacketDropUnknownConnectionID, }, }, eventRecorder.Events(qlog.PacketReceived{}, qlog.PacketDropped{}), ) } func TestConnectionUnpackFailuresFatal(t *testing.T) { t.Run("other errors", func(t *testing.T) { require.ErrorIs(t, testConnectionUnpackFailureFatal(t, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}), &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}, ) }) t.Run("invalid reserved bits", func(t *testing.T) { require.ErrorIs(t, testConnectionUnpackFailureFatal(t, wire.ErrInvalidReservedBits), 
&qerr.TransportError{ErrorCode: qerr.ProtocolViolation}, ) }) } func testConnectionUnpackFailureFatal(t *testing.T, unpackErr error) error { mockCtrl := gomock.NewController(t) unpacker := NewMockUnpacker(mockCtrl) tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptUnpacker(unpacker), ) tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()) unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, unpackErr) tc.packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tc.conn.handlePacket(getShortHeaderPacket(t, tc.remoteAddr, tc.srcConnID, 0x42, nil)) select { case err := <-errChan: require.Error(t, err) return err case <-time.After(time.Second): t.Fatal("timeout") } return nil } func TestConnectionUnpackFailureDropped(t *testing.T) { t.Run("keys dropped", func(t *testing.T) { testConnectionUnpackFailureDropped(t, handshake.ErrKeysDropped, qlog.PacketDropKeyUnavailable) }) t.Run("decryption failed", func(t *testing.T) { testConnectionUnpackFailureDropped(t, handshake.ErrDecryptionFailed, qlog.PacketDropPayloadDecryptError) }) t.Run("header parse error", func(t *testing.T) { testConnectionUnpackFailureDropped(t, &headerParseError{err: assert.AnError}, qlog.PacketDropHeaderParseError) }) } func testConnectionUnpackFailureDropped(t *testing.T, unpackErr error, packetDropReason qlog.PacketDropReason) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) unpacker := NewMockUnpacker(mockCtrl) var eventRecorder events.Recorder tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptUnpacker(unpacker), connectionOptTracer(&eventRecorder), ) 
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, unpackErr) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() packet := getShortHeaderPacket(t, tc.remoteAddr, tc.srcConnID, 0x42, nil) tc.conn.handlePacket(packet) synctest.Wait() require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: qlog.PacketType1RTT, DestConnectionID: tc.srcConnID, PacketNumber: protocol.InvalidPacketNumber, }, Raw: qlog.RawInfo{Length: int(packet.Size())}, Trigger: packetDropReason, }, }, eventRecorder.Events(qlog.PacketDropped{}), ) // test teardown tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() tc.conn.destroy(nil) synctest.Wait() select { case <-errChan: default: t.Fatal("timeout") } }) } func TestConnectionMaxUnprocessedPackets(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) var eventRecorder events.Recorder tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(&eventRecorder)) for range protocol.MaxConnUnprocessedPackets { // nothing here should block tc.conn.handlePacket(receivedPacket{data: []byte("foobar")}) } tc.conn.handlePacket(receivedPacket{data: []byte("foobar")}) synctest.Wait() require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Raw: qlog.RawInfo{Length: 6}, Trigger: qlog.PacketDropDOSPrevention, }, }, eventRecorder.Events(qlog.PacketDropped{}), ) }) } func TestConnectionRemoteClose(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) var eventRecorder events.Recorder unpacker := NewMockUnpacker(mockCtrl) tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(&eventRecorder), connectionOptUnpacker(unpacker), ) ccf, err := (&wire.ConnectionCloseFrame{ ErrorCode: uint64(qerr.StreamLimitError), ReasonPhrase: "foobar", }).Append(nil, protocol.Version1) require.NoError(t, err) 
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(1), protocol.PacketNumberLen2, protocol.KeyPhaseBit(0), ccf, nil) tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() p := getShortHeaderPacket(t, tc.remoteAddr, tc.srcConnID, 1, []byte("encrypted")) tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: monotime.Now()}) synctest.Wait() expectedErr := &qerr.TransportError{ErrorCode: qerr.StreamLimitError, ErrorMessage: "foobar", Remote: true} select { case err := <-errChan: require.ErrorIs(t, err, expectedErr) default: t.Fatal("timeout") } code := expectedErr.ErrorCode require.Equal(t, []qlogwriter.Event{ qlog.ConnectionClosed{ Initiator: qlog.InitiatorRemote, ConnectionError: &code, Reason: expectedErr.ErrorMessage, }, }, eventRecorder.Events(qlog.ConnectionClosed{}), ) }) } func TestConnectionIdleTimeoutDuringHandshake(t *testing.T) { synctest.Test(t, func(t *testing.T) { const timeout = 7 * time.Second mockCtrl := gomock.NewController(t) var eventRecorder events.Recorder tc := newServerTestConnection(t, mockCtrl, &Config{HandshakeIdleTimeout: timeout}, false, connectionOptTracer(&eventRecorder), ) tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).AnyTimes() tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() start := monotime.Now() errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() synctest.Wait() select { case err := <-errChan: require.ErrorIs(t, err, &IdleTimeoutError{}) require.Equal(t, timeout, monotime.Since(start)) case <-time.After(timeout + time.Nanosecond): t.Fatal("timeout") } require.Equal(t, []qlogwriter.Event{ qlog.ConnectionClosed{ Initiator: qlog.InitiatorLocal, Trigger: qlog.ConnectionCloseTriggerIdleTimeout, }, }, eventRecorder.Events(qlog.ConnectionClosed{}), ) }) } func TestConnectionHandshakeIdleTimeout(t 
*testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) var eventRecorder events.Recorder tc := newServerTestConnection(t, mockCtrl, &Config{HandshakeIdleTimeout: 7 * time.Second}, false, connectionOptTracer(&eventRecorder), func(c *Conn) { c.creationTime = monotime.Now().Add(-20 * time.Second) }, ) tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).AnyTimes() tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() synctest.Wait() select { case err := <-errChan: require.ErrorIs(t, err, &HandshakeTimeoutError{}) case <-time.After(time.Second): t.Fatal("timeout") } require.Equal(t, []qlogwriter.Event{ qlog.ConnectionClosed{ Initiator: qlog.InitiatorLocal, Trigger: qlog.ConnectionCloseTriggerIdleTimeout, }, }, eventRecorder.Events(qlog.ConnectionClosed{}), ) }) } func TestConnectionTransportParameters(t *testing.T) { mockCtrl := gomock.NewController(t) var eventRecorder events.Recorder connFC := flowcontrol.NewConnectionFlowController(0, 0, nil, utils.NewRTTStats(), utils.DefaultLogger) require.Zero(t, connFC.SendWindowSize()) tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(&eventRecorder), connectionOptConnFlowController(connFC), ) _, err := tc.conn.OpenStream() require.ErrorIs(t, err, &StreamLimitReachedError{}) _, err = tc.conn.OpenUniStream() require.ErrorIs(t, err, &StreamLimitReachedError{}) params := &wire.TransportParameters{ MaxIdleTimeout: 90 * time.Second, InitialMaxStreamDataBidiLocal: 0x5000, InitialMaxData: 1337, ActiveConnectionIDLimit: 3, // marshaling always sets it to this value MaxUDPPayloadSize: protocol.MaxPacketBufferSize, OriginalDestinationConnectionID: tc.destConnID, MaxBidiStreamNum: 1, MaxUniStreamNum: 1, } require.NoError(t, tc.conn.handleTransportParameters(params)) require.Equal(t, protocol.ByteCount(1337), connFC.SendWindowSize()) _, err = tc.conn.OpenStream() 
require.NoError(t, err) _, err = tc.conn.OpenUniStream() require.NoError(t, err) require.Equal(t, []qlogwriter.Event{ qlog.ParametersSet{ Initiator: qlog.InitiatorRemote, MaxIdleTimeout: 90 * time.Second, InitialMaxStreamDataBidiLocal: 0x5000, InitialMaxData: 1337, ActiveConnectionIDLimit: 3, // marshaling always sets it to this value MaxUDPPayloadSize: protocol.MaxPacketBufferSize, OriginalDestinationConnectionID: tc.destConnID, InitialMaxStreamsBidi: 1, InitialMaxStreamsUni: 1, }, }, eventRecorder.Events(qlog.ParametersSet{}), ) } func TestConnectionHandleMaxStreamsFrame(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) connFC := flowcontrol.NewConnectionFlowController(0, 0, nil, utils.NewRTTStats(), utils.DefaultLogger) tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptConnFlowController(connFC)) tc.conn.handleTransportParameters(&wire.TransportParameters{}) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() uniStreamChan := make(chan error) go func() { _, err := tc.conn.OpenUniStreamSync(ctx) uniStreamChan <- err }() bidiStreamChan := make(chan error) go func() { _, err := tc.conn.OpenStreamSync(ctx) bidiStreamChan <- err }() synctest.Wait() select { case <-uniStreamChan: t.Fatal("uni stream should be blocked") case <-bidiStreamChan: t.Fatal("bidi stream should be blocked") default: } // MAX_STREAMS frame for bidirectional stream _, err := tc.conn.handleFrame( &wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 10}, protocol.Encryption1RTT, protocol.ConnectionID{}, monotime.Now(), ) require.NoError(t, err) synctest.Wait() select { case <-uniStreamChan: t.Fatal("uni stream should be blocked") default: } select { case err := <-bidiStreamChan: require.NoError(t, err) default: t.Fatal("bidi stream should be unblocked") } // MAX_STREAMS frame for bidirectional stream _, err = tc.conn.handleFrame( &wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: 
10}, protocol.Encryption1RTT, protocol.ConnectionID{}, monotime.Now(), ) require.NoError(t, err) synctest.Wait() select { case err := <-uniStreamChan: require.NoError(t, err) default: t.Fatal("timeout") } }) } func TestConnectionTransportParameterValidationFailureServer(t *testing.T) { tc := newServerTestConnection(t, nil, nil, false) err := tc.conn.handleTransportParameters(&wire.TransportParameters{ InitialSourceConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), }) assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError}) assert.ErrorContains(t, err, "expected initial_source_connection_id to equal") } func TestConnectionTransportParameterValidationFailureClient(t *testing.T) { t.Run("initial_source_connection_id", func(t *testing.T) { tc := newClientTestConnection(t, nil, nil, false) err := tc.conn.handleTransportParameters(&wire.TransportParameters{ InitialSourceConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), }) assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError}) assert.ErrorContains(t, err, "expected initial_source_connection_id to equal") }) t.Run("original_destination_connection_id", func(t *testing.T) { tc := newClientTestConnection(t, nil, nil, false) err := tc.conn.handleTransportParameters(&wire.TransportParameters{ InitialSourceConnectionID: tc.destConnID, OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), }) assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError}) assert.ErrorContains(t, err, "expected original_destination_connection_id to equal") }) t.Run("retry_source_connection_id if no retry", func(t *testing.T) { tc := newClientTestConnection(t, nil, nil, false) rcid := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) params := &wire.TransportParameters{ InitialSourceConnectionID: tc.destConnID, OriginalDestinationConnectionID: tc.destConnID, RetrySourceConnectionID: &rcid, } err := 
tc.conn.handleTransportParameters(params) assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError}) assert.ErrorContains(t, err, "received retry_source_connection_id, although no Retry was performed") }) t.Run("retry_source_connection_id missing", func(t *testing.T) { tc := newClientTestConnection(t, nil, nil, false, connectionOptRetrySrcConnID(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})), ) params := &wire.TransportParameters{ InitialSourceConnectionID: tc.destConnID, OriginalDestinationConnectionID: tc.destConnID, } err := tc.conn.handleTransportParameters(params) assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError}) assert.ErrorContains(t, err, "missing retry_source_connection_id") }) t.Run("retry_source_connection_id incorrect", func(t *testing.T) { tc := newClientTestConnection(t, nil, nil, false, connectionOptRetrySrcConnID(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})), ) wrongCID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) params := &wire.TransportParameters{ InitialSourceConnectionID: tc.destConnID, OriginalDestinationConnectionID: tc.destConnID, RetrySourceConnectionID: &wrongCID, } err := tc.conn.handleTransportParameters(params) assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError}) assert.ErrorContains(t, err, "expected retry_source_connection_id to equal") }) } func TestConnectionHandshakeServer(t *testing.T) { mockCtrl := gomock.NewController(t) cs := mocks.NewMockCryptoSetup(mockCtrl) unpacker := NewMockUnpacker(mockCtrl) tc := newServerTestConnection( t, mockCtrl, nil, false, connectionOptCryptoSetup(cs), connectionOptUnpacker(unpacker), ) // the state transition is driven by processing of a CRYPTO frame hdr := &wire.ExtendedHeader{ Header: wire.Header{Type: protocol.PacketTypeHandshake, Version: protocol.Version1}, PacketNumberLen: protocol.PacketNumberLen2, } data, err := (&wire.CryptoFrame{Data: 
[]byte("foobar")}).Append(nil, protocol.Version1) require.NoError(t, err) cs.EXPECT().DiscardInitialKeys().Times(2) gomock.InOrder( cs.EXPECT().StartHandshake(gomock.Any()), cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}), unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return( &unpackedPacket{hdr: hdr, encryptionLevel: protocol.EncryptionHandshake, data: data}, nil, ), cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake), cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete}), cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}), cs.EXPECT().SetHandshakeConfirmed(), cs.EXPECT().GetSessionTicket().Return([]byte("session ticket"), nil), ) tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() p := getLongHeaderPacket(t, tc.remoteAddr, hdr, nil) tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: monotime.Now()}) select { case <-tc.conn.HandshakeComplete(): case <-tc.conn.Context().Done(): t.Fatal("connection context done") case <-time.After(time.Second): t.Fatal("timeout") } var foundSessionTicket, foundHandshakeDone, foundNewToken bool frames, _, _ := tc.conn.framer.Append(nil, nil, protocol.MaxByteCount, monotime.Now(), protocol.Version1) for _, frame := range frames { switch f := frame.Frame.(type) { case *wire.CryptoFrame: assert.Equal(t, []byte("session ticket"), f.Data) foundSessionTicket = true case *wire.HandshakeDoneFrame: foundHandshakeDone = true case *wire.NewTokenFrame: assert.NotEmpty(t, f.Token) foundNewToken = true } } assert.True(t, foundSessionTicket) assert.True(t, foundHandshakeDone) assert.True(t, foundNewToken) // test teardown cs.EXPECT().Close() tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() tc.conn.destroy(nil) select { case err 
:= <-errChan: require.NoError(t, err) case <-time.After(time.Second): t.Fatal("timeout") } } func TestConnectionHandshakeClient(t *testing.T) { t.Run("without preferred address", func(t *testing.T) { testConnectionHandshakeClient(t, false) }) t.Run("with preferred address", func(t *testing.T) { testConnectionHandshakeClient(t, true) }) } func testConnectionHandshakeClient(t *testing.T, usePreferredAddress bool) { mockCtrl := gomock.NewController(t) cs := mocks.NewMockCryptoSetup(mockCtrl) unpacker := NewMockUnpacker(mockCtrl) tc := newClientTestConnection(t, mockCtrl, nil, false, connectionOptCryptoSetup(cs), connectionOptUnpacker(unpacker)) tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() // the state transition is driven by processing of a CRYPTO frame hdr := &wire.ExtendedHeader{ Header: wire.Header{Type: protocol.PacketTypeHandshake, Version: protocol.Version1}, PacketNumberLen: protocol.PacketNumberLen2, } data, err := (&wire.CryptoFrame{Data: []byte("foobar")}).Append(nil, protocol.Version1) require.NoError(t, err) tp := &wire.TransportParameters{ OriginalDestinationConnectionID: tc.destConnID, MaxIdleTimeout: time.Hour, } preferredAddressConnID := protocol.ParseConnectionID([]byte{10, 8, 6, 4}) preferredAddressResetToken := protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1} if usePreferredAddress { tp.PreferredAddress = &wire.PreferredAddress{ IPv4: netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 42), IPv6: netip.AddrPortFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}), 13), ConnectionID: preferredAddressConnID, StatelessResetToken: preferredAddressResetToken, } } packedFirstPacket := make(chan struct{}) gomock.InOrder( cs.EXPECT().StartHandshake(gomock.Any()), cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}), tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn( 
func(b bool, bc protocol.ByteCount, t monotime.Time, v protocol.Version) (*coalescedPacket, error) { close(packedFirstPacket) return &coalescedPacket{buffer: getPacketBuffer(), longHdrPackets: []*longHeaderPacket{{header: hdr}}}, nil }, ), // initial keys are dropped when the first handshake packet is sent cs.EXPECT().DiscardInitialKeys(), // no more data to send unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return( &unpackedPacket{hdr: hdr, encryptionLevel: protocol.EncryptionHandshake, data: data}, nil, ), cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake), cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventReceivedTransportParameters, TransportParameters: tp}), cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete}), cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}), ) tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).Return(nil, nil).AnyTimes() errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() select { case <-packedFirstPacket: case <-time.After(time.Second): t.Fatal("timeout") } p := getLongHeaderPacket(t, tc.remoteAddr, hdr, nil) tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: monotime.Now()}) select { case <-tc.conn.HandshakeComplete(): case <-tc.conn.Context().Done(): t.Fatal("connection context done") case <-time.After(time.Second): t.Fatal("timeout") } require.True(t, mockCtrl.Satisfied()) // the handshake isn't confirmed until we receive a HANDSHAKE_DONE frame from the server data, err = (&wire.HandshakeDoneFrame{}).Append(nil, protocol.Version1) require.NoError(t, err) done := make(chan struct{}) tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).Return(nil, nil).AnyTimes() gomock.InOrder( unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return( &unpackedPacket{hdr: hdr, encryptionLevel: 
protocol.Encryption1RTT, data: data}, nil, ), cs.EXPECT().DiscardInitialKeys(), cs.EXPECT().SetHandshakeConfirmed(), tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(buf *packetBuffer, _ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, error) { close(done) return shortHeaderPacket{}, errNothingToPack }, ), ) tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() p = getLongHeaderPacket(t, tc.remoteAddr, hdr, nil) tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: monotime.Now()}) select { case <-done: case <-time.After(time.Second): t.Fatal("timeout") } if usePreferredAddress { tc.connRunner.EXPECT().AddResetToken(preferredAddressResetToken, gomock.Any()) } nextConnID := tc.conn.connIDManager.Get() if usePreferredAddress { require.Equal(t, preferredAddressConnID, nextConnID) } // test teardown cs.EXPECT().Close() tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() if usePreferredAddress { tc.connRunner.EXPECT().RemoveResetToken(preferredAddressResetToken) } tc.conn.destroy(nil) select { case err := <-errChan: require.NoError(t, err) case <-time.After(time.Second): t.Fatal("timeout") } } func TestConnection0RTTTransportParameters(t *testing.T) { mockCtrl := gomock.NewController(t) cs := mocks.NewMockCryptoSetup(mockCtrl) unpacker := NewMockUnpacker(mockCtrl) tc := newClientTestConnection(t, mockCtrl, nil, false, connectionOptCryptoSetup(cs), connectionOptUnpacker(unpacker)) tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() // the state transition is driven by processing of a CRYPTO frame hdr := &wire.ExtendedHeader{ Header: wire.Header{Type: protocol.PacketTypeHandshake, Version: protocol.Version1}, PacketNumberLen: protocol.PacketNumberLen2, } data, err := (&wire.CryptoFrame{Data: []byte("foobar")}).Append(nil, protocol.Version1) 
require.NoError(t, err) restored := &wire.TransportParameters{ ActiveConnectionIDLimit: 3, InitialMaxData: 0x5000, InitialMaxStreamDataBidiLocal: 0x5000, InitialMaxStreamDataBidiRemote: 1000, InitialMaxStreamDataUni: 1000, MaxBidiStreamNum: 500, MaxUniStreamNum: 500, } new := *restored new.MaxBidiStreamNum-- // the server is not allowed to reduce the limit new.OriginalDestinationConnectionID = tc.destConnID packedFirstPacket := make(chan struct{}) gomock.InOrder( cs.EXPECT().StartHandshake(gomock.Any()), cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventRestoredTransportParameters, TransportParameters: restored}), cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}), tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn( func(b bool, bc protocol.ByteCount, t monotime.Time, v protocol.Version) (*coalescedPacket, error) { close(packedFirstPacket) return &coalescedPacket{buffer: getPacketBuffer(), longHdrPackets: []*longHeaderPacket{{header: hdr}}}, nil }, ), // initial keys are dropped when the first handshake packet is sent cs.EXPECT().DiscardInitialKeys(), // no more data to send unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return( &unpackedPacket{hdr: hdr, encryptionLevel: protocol.EncryptionHandshake, data: data}, nil, ), cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake), cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventReceivedTransportParameters, TransportParameters: &new}), cs.EXPECT().ConnectionState().Return(handshake.ConnectionState{Used0RTT: true}), // cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}), cs.EXPECT().Close(), ) tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).Return(nil, nil).AnyTimes() tc.packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: 
getPacketBuffer()}, nil) tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() select { case <-packedFirstPacket: case <-time.After(time.Second): t.Fatal("timeout") } p := getLongHeaderPacket(t, tc.remoteAddr, hdr, nil) tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: monotime.Now()}) select { case err := <-errChan: require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation}) require.ErrorContains(t, err, "server sent reduced limits after accepting 0-RTT data") case <-time.After(time.Second): t.Fatal("timeout") } } func TestConnectionReceivePrioritization(t *testing.T) { for _, handshakeComplete := range []bool{true, false} { t.Run(fmt.Sprintf("handshake complete: %t", handshakeComplete), func(t *testing.T) { events := testConnectionReceivePrioritization(t, handshakeComplete, 5) require.Equal(t, []string{"unpack", "unpack", "unpack", "unpack", "unpack", "pack"}, events) }) } } func testConnectionReceivePrioritization(t *testing.T, handshakeComplete bool, numPackets int) []string { mockCtrl := gomock.NewController(t) unpacker := NewMockUnpacker(mockCtrl) opts := []testConnectionOpt{connectionOptUnpacker(unpacker)} if handshakeComplete { opts = append(opts, connectionOptHandshakeConfirmed()) } tc := newServerTestConnection(t, mockCtrl, nil, false, opts...) 
var events []string var counter int var testDone bool done := make(chan struct{}) unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn( func(rcvTime monotime.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) { counter++ if counter == numPackets { testDone = true } events = append(events, "unpack") return protocol.PacketNumber(counter), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0, 1} /* PADDING, PING */, nil }, ).Times(numPackets) switch handshakeComplete { case false: tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(b bool, bc protocol.ByteCount, t monotime.Time, v protocol.Version) (*coalescedPacket, error) { events = append(events, "pack") if testDone { close(done) } return nil, nil }, ).AnyTimes() case true: tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(b *packetBuffer, bc protocol.ByteCount, t monotime.Time, v protocol.Version) (shortHeaderPacket, error) { events = append(events, "pack") if testDone { close(done) } return shortHeaderPacket{}, errNothingToPack }, ).AnyTimes() } for i := range numPackets { tc.conn.handlePacket(getShortHeaderPacket(t, tc.remoteAddr, tc.srcConnID, protocol.PacketNumber(i), []byte("foobar"))) } tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() select { case <-done: case <-time.After(time.Second): t.Fatal("timeout") } // test teardown tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() tc.conn.destroy(nil) select { case err := <-errChan: require.NoError(t, err) case <-time.After(time.Second): t.Fatal("timeout") } return events } func TestConnectionPacketBuffering(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) unpacker := NewMockUnpacker(mockCtrl) cs := mocks.NewMockCryptoSetup(mockCtrl) var 
eventRecorder events.Recorder tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptUnpacker(unpacker), connectionOptCryptoSetup(cs), connectionOptTracer(&eventRecorder), ) cs.EXPECT().DiscardInitialKeys() hdr1 := wire.ExtendedHeader{ Header: wire.Header{ Type: protocol.PacketTypeHandshake, DestConnectionID: tc.srcConnID, SrcConnectionID: tc.destConnID, Length: 8, Version: protocol.Version1, }, PacketNumberLen: protocol.PacketNumberLen1, PacketNumber: 1, } hdr2 := hdr1 hdr2.PacketNumber = 2 cs.EXPECT().StartHandshake(gomock.Any()) cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable).Times(2) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() hdrs := make(map[string]*wire.ExtendedHeader) packet1 := getLongHeaderPacket(t, tc.remoteAddr, &hdr1, []byte("packet1")) datagramID1 := qlog.CalculateDatagramID(packet1.data) hdrs["packet1"] = &hdr1 tc.conn.handlePacket(packet1) packet2 := getLongHeaderPacket(t, tc.remoteAddr, &hdr2, []byte("packet2")) datagramID2 := qlog.CalculateDatagramID(packet2.data) hdrs["packet2"] = &hdr2 tc.conn.handlePacket(packet2) synctest.Wait() require.Equal(t, []qlogwriter.Event{ qlog.PacketBuffered{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeHandshake, PacketNumber: protocol.InvalidPacketNumber, }, Raw: qlog.RawInfo{Length: int(packet1.Size())}, DatagramID: datagramID1, }, qlog.PacketBuffered{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeHandshake, PacketNumber: protocol.InvalidPacketNumber, }, Raw: qlog.RawInfo{Length: int(packet2.Size())}, DatagramID: datagramID2, }, }, eventRecorder.Events(qlog.PacketBuffered{}), ) eventRecorder.Clear() // Now send another packet. // In reality, this packet would contain a CRYPTO frame that advances the TLS handshake // such that new keys become available. 
var packets []string hdr3 := hdr1 hdr3.PacketNumber = 3 hdrs["packet3"] = &hdr3 tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventReceivedReadKeys}) cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) gomock.InOrder( // packet 3 contains a CRYPTO frame and triggers the keys to become available unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn( func(hdr *wire.Header, data []byte) (*unpackedPacket, error) { id := string(data[len(data)-7:]) packets = append(packets, id) cf := &wire.CryptoFrame{Data: []byte("foobar")} b, _ := cf.Append(nil, protocol.Version1) extHdr, ok := hdrs[id] if !ok { panic(fmt.Sprintf("unknown header: %v", id)) } return &unpackedPacket{hdr: extHdr, encryptionLevel: protocol.EncryptionHandshake, data: b}, nil }, ), cs.EXPECT().HandleMessage(gomock.Any(), gomock.Any()), unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn( func(hdr *wire.Header, data []byte) (*unpackedPacket, error) { id := string(data[len(data)-7:]) extHdr, ok := hdrs[id] if !ok { panic(fmt.Sprintf("unknown header: %v", id)) } packets = append(packets, id) return &unpackedPacket{hdr: extHdr, encryptionLevel: protocol.EncryptionHandshake, data: []byte{0} /* PADDING */}, nil }, ).Times(2), ) packet3 := getLongHeaderPacket(t, tc.remoteAddr, &hdr3, []byte("packet3")) datagramID3 := qlog.CalculateDatagramID(packet3.data) tc.conn.handlePacket(packet3) synctest.Wait() // packet3 triggered the keys to become available // packet1 and packet2 are processed from the buffer in order require.Equal(t, []string{"packet3", "packet1", "packet2"}, packets) require.Equal(t, []qlogwriter.Event{ qlog.PacketReceived{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeHandshake, DestConnectionID: tc.srcConnID, SrcConnectionID: tc.destConnID, PacketNumber: 3, Version: protocol.Version1, }, 
Raw: qlog.RawInfo{Length: int(packet3.Size()), PayloadLength: 8}, DatagramID: datagramID3, Frames: []qlog.Frame{{Frame: &qlog.CryptoFrame{Length: 6}}}, }, qlog.PacketReceived{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeHandshake, DestConnectionID: tc.srcConnID, SrcConnectionID: tc.destConnID, PacketNumber: 1, Version: protocol.Version1, }, Raw: qlog.RawInfo{Length: int(packet1.Size()), PayloadLength: 8}, DatagramID: datagramID1, Frames: []qlog.Frame{}, }, qlog.PacketReceived{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeHandshake, DestConnectionID: tc.srcConnID, SrcConnectionID: tc.destConnID, PacketNumber: 2, Version: protocol.Version1, }, Raw: qlog.RawInfo{Length: int(packet1.Size()), PayloadLength: 8}, DatagramID: datagramID2, Frames: []qlog.Frame{}, }, }, eventRecorder.Events(qlog.PacketReceived{}, qlog.PacketBuffered{}), ) // test teardown tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() cs.EXPECT().Close() tc.conn.destroy(nil) synctest.Wait() select { case err := <-errChan: require.NoError(t, err) case <-time.After(time.Second): t.Fatal("timeout") } }) } func TestConnectionPacketPacing(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sender := NewMockSender(mockCtrl) tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptSentPacketHandler(sph), connectionOptSender(sender), connectionOptHandshakeConfirmed(), ) sender.EXPECT().Run() const step = 50 * time.Millisecond sph.EXPECT().GetLossDetectionTimeout().Return(monotime.Now().Add(time.Hour)).AnyTimes() gomock.InOrder( // 1. 
allow 2 packets to be sent sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), // 2. become pacing limited for 25ms sph.EXPECT().TimeUntilSend().DoAndReturn(func() monotime.Time { return monotime.Now().Add(step) }), // 3. send another packet sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), // 4. become pacing limited for 25ms... sph.EXPECT().TimeUntilSend().DoAndReturn(func() monotime.Time { return monotime.Now().Add(step) }), // ... but this time we're still pacing limited when waking up. // In this case, we can only send an ACK. sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), // 5. stop the test by becoming pacing limited forever sph.EXPECT().TimeUntilSend().Return(monotime.Now().Add(time.Hour)), sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), ) sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() for i := range 3 { tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), Version1).DoAndReturn( func(buf *packetBuffer, _ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, error) { buf.Data = append(buf.Data, []byte("packet"+strconv.Itoa(i+1))...) 
return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i + 1)}, nil }, ) } tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(_ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) { buf := getPacketBuffer() buf.Data = []byte("ack") return shortHeaderPacket{PacketNumber: 1}, buf, nil }, ) sender.EXPECT().WouldBlock().AnyTimes() type sentPacket struct { time monotime.Time data []byte } sendChan := make(chan sentPacket, 10) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) { sendChan <- sentPacket{time: monotime.Now(), data: b.Data} }).Times(4) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() tc.conn.scheduleSending() synctest.Wait() var times []monotime.Time for i := range 3 { select { case b := <-sendChan: require.Equal(t, []byte("packet"+strconv.Itoa(i+1)), b.data) times = append(times, b.time) case <-time.After(time.Hour): t.Fatal("should have sent a packet") } } select { case b := <-sendChan: require.Equal(t, []byte("ack"), b.data) times = append(times, b.time) case <-time.After(time.Second): t.Fatal("timeout") } require.Equal(t, times[0], times[1]) require.Equal(t, times[2], times[1].Add(step)) require.Equal(t, times[3], times[2].Add(step)) synctest.Wait() // make sure that no more packets are sent require.True(t, mockCtrl.Satisfied()) // test teardown sender.EXPECT().Close() tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() tc.conn.destroy(nil) synctest.Wait() select { case <-sendChan: t.Fatal("should not have sent any more packets") case err := <-errChan: require.NoError(t, err) default: t.Fatal("should have timed out") } }) } // When the send queue blocks, we need to reset the pacing timer, otherwise the run loop might busy-loop. // See https://github.com/quic-go/quic-go/pull/4943 for more details. 
func TestConnectionPacingAndSendQueue(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sender := NewMockSender(mockCtrl) tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptSentPacketHandler(sph), connectionOptSender(sender), connectionOptHandshakeConfirmed(), ) sender.EXPECT().Run() sendQueueAvailable := make(chan struct{}) pacingDeadline := monotime.Now().Add(-time.Millisecond) var counter int // allow exactly one packet to be sent, then become blocked sender.EXPECT().WouldBlock().Return(false) sender.EXPECT().WouldBlock().DoAndReturn(func() bool { counter++; return true }).AnyTimes() sender.EXPECT().Available().Return(sendQueueAvailable).AnyTimes() sph.EXPECT().GetLossDetectionTimeout().Return(monotime.Now().Add(time.Hour)).AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited).AnyTimes() sph.EXPECT().TimeUntilSend().Return(pacingDeadline).AnyTimes() sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECNNon).AnyTimes() tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).Return( shortHeaderPacket{}, nil, errNothingToPack, ) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() tc.conn.scheduleSending() synctest.Wait() // test teardown tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() sender.EXPECT().Close() tc.conn.destroy(nil) synctest.Wait() select { case err := <-errChan: require.NoError(t, err) default: t.Fatal("should have timed out") } // make sure the run loop didn't do too many iterations require.Less(t, counter, 3) }) } func TestConnectionIdleTimeout(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) tc := newServerTestConnection(t, mockCtrl, &Config{MaxIdleTimeout: time.Minute}, false, connectionOptHandshakeConfirmed(), connectionOptSentPacketHandler(sph), 
connectionOptRTT(time.Millisecond), ) // the idle timeout is set when the transport parameters are received const idleTimeout = 500 * time.Millisecond require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{ MaxIdleTimeout: idleTimeout, })) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() var lastSendTime monotime.Time tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(buf *packetBuffer, _ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, error) { buf.Data = append(buf.Data, []byte("foobar")...) lastSendTime = monotime.Now() return shortHeaderPacket{Frames: []ackhandler.Frame{{Frame: &wire.PingFrame{}}}, Length: 6}, nil }, ) tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() tc.conn.scheduleSending() synctest.Wait() select { case err := <-errChan: require.ErrorIs(t, err, &IdleTimeoutError{}) require.NotZero(t, lastSendTime) require.Equal(t, idleTimeout, monotime.Since(lastSendTime)) case <-time.After(time.Hour): t.Fatal("should have timed out") } }) } func TestConnectionKeepAlive(t *testing.T) { t.Run("enabled", func(t *testing.T) { testConnectionKeepAlive(t, true, true) }) t.Run("disabled", func(t *testing.T) { testConnectionKeepAlive(t, false, false) }) } func testConnectionKeepAlive(t *testing.T, enable, expectKeepAlive bool) { synctest.Test(t, func(t *testing.T) { var keepAlivePeriod time.Duration if 
enable { keepAlivePeriod = time.Second } mockCtrl := gomock.NewController(t) unpacker := NewMockUnpacker(mockCtrl) tc := newServerTestConnection(t, mockCtrl, &Config{MaxIdleTimeout: time.Second, KeepAlivePeriod: keepAlivePeriod}, false, connectionOptUnpacker(unpacker), connectionOptHandshakeConfirmed(), connectionOptRTT(time.Millisecond), ) // the idle timeout is set when the transport parameters are received const idleTimeout = 50 * time.Millisecond require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{ MaxIdleTimeout: idleTimeout, })) // Receive a packet. This starts the keep-alive timer. buf := getPacketBuffer() var err error buf.Data, err = wire.AppendShortHeader(buf.Data, tc.srcConnID, 1, protocol.PacketNumberLen1, protocol.KeyPhaseZero) require.NoError(t, err) buf.Data = append(buf.Data, []byte("packet")...) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() var unpackTime, packTime monotime.Time done := make(chan struct{}) unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn( func(t monotime.Time, bytes []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) { unpackTime = monotime.Now() return protocol.PacketNumber(1), protocol.PacketNumberLen1, protocol.KeyPhaseZero, []byte{0} /* PADDING */, nil }, ) tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) switch expectKeepAlive { case true: // record the time of the keep-alive is sent tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { packTime = monotime.Now() close(done) return shortHeaderPacket{}, errNothingToPack }, ) tc.conn.handlePacket(receivedPacket{data: buf.Data, buffer: buf, rcvTime: monotime.Now(), remoteAddr: tc.remoteAddr}) select { case 
<-done: // the keep-alive packet should be sent after half the idle timeout require.Equal(t, unpackTime.Add(idleTimeout/2), packTime) case <-time.After(idleTimeout): t.Fatal("timeout") } case false: // if keep-alives are disabled, the connection will run into an idle timeout tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() tc.conn.handlePacket(receivedPacket{data: buf.Data, buffer: buf, rcvTime: monotime.Now(), remoteAddr: tc.remoteAddr}) } // test teardown if expectKeepAlive { tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() tc.conn.destroy(nil) } synctest.Wait() select { case err := <-errChan: if expectKeepAlive { require.NoError(t, err) } else { require.ErrorIs(t, err, &IdleTimeoutError{}) } case <-time.After(time.Hour): t.Fatal("timeout") } }) } func TestConnectionACKTimer(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) tc := newServerTestConnection(t, mockCtrl, &Config{MaxIdleTimeout: time.Second}, false, connectionOptHandshakeConfirmed(), connectionOptSentPacketHandler(sph), ) const alarmTimeout = 500 * time.Millisecond sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() // Set initial alarm timeout far in the future _ = tc.receivedPacketHandler().ReceivedPacket(1, protocol.ECNNon, protocol.Encryption1RTT, monotime.Now().Add(time.Hour), true) var times []monotime.Time done := make(chan struct{}, 5) var calls []any for range 2 { calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(buf *packetBuffer, _ 
protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, error) { buf.Data = append(buf.Data, []byte("foobar")...) times = append(times, monotime.Now()) rph := tc.receivedPacketHandler() if len(times) == 1 { // After first packet is sent, set alarm timeout for the next iteration // Get the ACK frame to reset state, then receive a new packet to set alarm _ = rph.GetAckFrame(protocol.Encryption1RTT, monotime.Now(), false) alarmRcvTime := monotime.Now().Add(alarmTimeout - protocol.MaxAckDelay) _ = rph.ReceivedPacket(2, protocol.ECNNon, protocol.Encryption1RTT, alarmRcvTime, true) } else { // After second packet is sent, set alarm timeout far in the future _ = rph.GetAckFrame(protocol.Encryption1RTT, monotime.Now(), false) _ = rph.ReceivedPacket(3, protocol.ECNNon, protocol.Encryption1RTT, monotime.Now().Add(time.Hour), true) } return shortHeaderPacket{Frames: []ackhandler.Frame{{Frame: &wire.PingFrame{}}}, Length: 6}, nil }, )) calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(*packetBuffer, protocol.ByteCount, monotime.Time, protocol.Version) (shortHeaderPacket, error) { done <- struct{}{} return shortHeaderPacket{}, errNothingToPack }, )) } gomock.InOrder(calls...) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() tc.conn.scheduleSending() for range 2 { synctest.Wait() select { case <-done: case <-time.After(time.Hour): t.Fatal("timeout") } } assert.Len(t, times, 2) require.Equal(t, times[0].Add(alarmTimeout), times[1]) // test teardown tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() tc.conn.destroy(nil) synctest.Wait() select { case err := <-errChan: require.NoError(t, err) default: t.Fatal("should have timed out") } }) } // Send a GSO batch, until we have no more data to send. 
func TestConnectionGSOBatch(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) tc := newServerTestConnection(t, mockCtrl, nil, true, connectionOptHandshakeConfirmed(), connectionOptSentPacketHandler(sph), ) // allow packets to be sent sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes() maxPacketSize := tc.conn.maxPacketSize() var expectedData []byte for i := range 4 { data := bytes.Repeat([]byte{byte(i)}, int(maxPacketSize)) expectedData = append(expectedData, data...) tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { buffer.Data = append(buffer.Data, data...) 
return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i)}, nil }, ) } done := make(chan struct{}) tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1).DoAndReturn( func([]byte, uint16, protocol.ECN) error { close(done); return nil }, ) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() tc.conn.scheduleSending() synctest.Wait() select { case <-done: default: t.Fatal("should have sent a packet") } // test teardown tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() tc.conn.destroy(nil) synctest.Wait() select { case err := <-errChan: require.NoError(t, err) default: t.Fatal("should have timed out") } }) } // Send a GSO batch, until a packet smaller than the maximum size is packed func TestConnectionGSOBatchPacketSize(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) tc := newServerTestConnection(t, mockCtrl, nil, true, connectionOptHandshakeConfirmed(), connectionOptSentPacketHandler(sph), ) // allow packets to be sent sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes() maxPacketSize := tc.conn.maxPacketSize() var expectedData []byte var calls []any for i := range 4 { var data []byte if i == 3 { data = bytes.Repeat([]byte{byte(i)}, int(maxPacketSize-1)) } else { data = bytes.Repeat([]byte{byte(i)}, int(maxPacketSize)) } expectedData = append(expectedData, data...) 
calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { buffer.Data = append(buffer.Data, data...) return shortHeaderPacket{PacketNumber: protocol.PacketNumber(10 + i)}, nil }, )) } // The smaller (fourth) packet concluded this GSO batch, but the send loop will immediately start composing the next batch. // We therefore send a "foobar", so we can check that we're actually generating two GSO batches. calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { buffer.Data = append(buffer.Data, []byte("foobar")...) return shortHeaderPacket{PacketNumber: protocol.PacketNumber(14)}, nil }, ), ) calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack), ) gomock.InOrder(calls...) 
done := make(chan struct{}) gomock.InOrder( tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1), tc.sendConn.EXPECT().Write([]byte("foobar"), uint16(maxPacketSize), protocol.ECT1).DoAndReturn( func([]byte, uint16, protocol.ECN) error { close(done); return nil }, ), ) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() tc.conn.scheduleSending() synctest.Wait() select { case <-done: default: t.Fatal("should have sent a packet") } // test teardown tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() tc.conn.destroy(nil) synctest.Wait() select { case err := <-errChan: require.NoError(t, err) default: t.Fatal("should have timed out") } }) } func TestConnectionGSOBatchECN(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) tc := newServerTestConnection(t, mockCtrl, nil, true, connectionOptHandshakeConfirmed(), connectionOptSentPacketHandler(sph), ) // allow packets to be sent ecnMode := protocol.ECT1 sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().ECNMode(gomock.Any()).DoAndReturn(func(bool) protocol.ECN { return ecnMode }).AnyTimes() // 3. Send a GSO batch, until the ECN marking changes. var expectedData []byte var calls []any maxPacketSize := tc.conn.maxPacketSize() for i := range 3 { data := bytes.Repeat([]byte{byte(i)}, int(maxPacketSize)) expectedData = append(expectedData, data...) 
calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { buffer.Data = append(buffer.Data, data...) if i == 2 { ecnMode = protocol.ECNCE } return shortHeaderPacket{PacketNumber: protocol.PacketNumber(20 + i)}, nil }, )) } // The smaller (fourth) packet concluded this GSO batch, but the send loop will immediately start composing the next batch. // We therefore send a "foobar", so we can check that we're actually generating two GSO batches. calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { buffer.Data = append(buffer.Data, []byte("foobar")...) return shortHeaderPacket{PacketNumber: protocol.PacketNumber(24)}, nil }, ), ) calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack), ) gomock.InOrder(calls...) 
done3 := make(chan struct{}) tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1) tc.sendConn.EXPECT().Write([]byte("foobar"), uint16(maxPacketSize), protocol.ECNCE).DoAndReturn( func([]byte, uint16, protocol.ECN) error { close(done3); return nil }, ) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() tc.conn.scheduleSending() synctest.Wait() select { case <-done3: default: t.Fatal("should have sent a packet") } // test teardown tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() tc.conn.destroy(nil) synctest.Wait() select { case err := <-errChan: require.NoError(t, err) default: t.Fatal("should have timed out") } }) } func TestConnectionPTOProbePackets(t *testing.T) { t.Run("Initial", func(t *testing.T) { testConnectionPTOProbePackets(t, protocol.EncryptionInitial) }) t.Run("Handshake", func(t *testing.T) { testConnectionPTOProbePackets(t, protocol.EncryptionHandshake) }) t.Run("1-RTT", func(t *testing.T) { testConnectionPTOProbePackets(t, protocol.Encryption1RTT) }) } func testConnectionPTOProbePackets(t *testing.T, encLevel protocol.EncryptionLevel) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptSentPacketHandler(sph), ) var sendMode ackhandler.SendMode switch encLevel { case protocol.EncryptionInitial: sendMode = ackhandler.SendPTOInitial case protocol.EncryptionHandshake: sendMode = ackhandler.SendPTOHandshake case protocol.Encryption1RTT: sendMode = ackhandler.SendPTOAppData } sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(sendMode) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) sph.EXPECT().ECNMode(gomock.Any()) sph.EXPECT().QueueProbePacket(encLevel).Return(false) sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), 
gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tc.packer.EXPECT().PackPTOProbePacket(encLevel, gomock.Any(), true, gomock.Any(), protocol.Version1).DoAndReturn( func(protocol.EncryptionLevel, protocol.ByteCount, bool, monotime.Time, protocol.Version) (*coalescedPacket, error) { return &coalescedPacket{ buffer: getPacketBuffer(), shortHdrPacket: &shortHeaderPacket{PacketNumber: 1}, }, nil }, ) done := make(chan struct{}) tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do( func([]byte, uint16, protocol.ECN) error { close(done); return nil }, ) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() tc.conn.scheduleSending() select { case <-done: case <-time.After(time.Second): t.Fatal("timeout") } // test teardown tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() tc.conn.destroy(nil) synctest.Wait() select { case err := <-errChan: require.NoError(t, err) default: t.Fatal("should have timed out") } }) } func TestConnectionCongestionControl(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptHandshakeConfirmed(), connectionOptSentPacketHandler(sph), ) sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().ECNMode(true).AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck).MaxTimes(1) sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) // Since we're already sending out packets, we don't expect any calls to PackAckOnlyPacket for i := range 2 { tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(buffer *packetBuffer, count 
protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) { buffer.Data = append(buffer.Data, []byte("foobar")...) return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i)}, nil }, ) } tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) done1 := make(chan struct{}) tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do( func([]byte, uint16, protocol.ECN) error { close(done1); return nil }, ) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() tc.conn.scheduleSending() synctest.Wait() select { case <-done1: default: t.Fatal("should have sent a packet") } require.True(t, mockCtrl.Satisfied()) // Now that we're congestion limited, we can only send an ack-only packet done2 := make(chan struct{}) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck) tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(protocol.ByteCount, monotime.Time, protocol.Version) (shortHeaderPacket, *packetBuffer, error) { close(done2) return shortHeaderPacket{}, nil, errNothingToPack }, ) tc.conn.scheduleSending() synctest.Wait() select { case <-done2: default: t.Fatal("should have sent an ack-only packet") } require.True(t, mockCtrl.Satisfied()) // If the send mode is "none", we can't even send an ack-only packet sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) tc.conn.scheduleSending() synctest.Wait() // make sure there are no calls to the packer // test teardown tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() tc.conn.destroy(nil) synctest.Wait() select { case err := <-errChan: require.NoError(t, err) default: t.Fatal("timeout") } }) } func TestConnectionSendQueue(t *testing.T) { t.Run("with GSO", func(t *testing.T) { testConnectionSendQueue(t, true) }) t.Run("without GSO", func(t *testing.T) { testConnectionSendQueue(t, false) }) } func testConnectionSendQueue(t *testing.T, enableGSO bool) { synctest.Test(t, func(t *testing.T) 
{ mockCtrl := gomock.NewController(t) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sender := NewMockSender(mockCtrl) tc := newServerTestConnection(t, mockCtrl, nil, enableGSO, connectionOptSender(sender), connectionOptHandshakeConfirmed(), connectionOptSentPacketHandler(sph), ) sender.EXPECT().Run().MaxTimes(1) sender.EXPECT().WouldBlock() sender.EXPECT().WouldBlock().Return(true).Times(2) available := make(chan struct{}) blocked := make(chan struct{}) sender.EXPECT().Available().DoAndReturn( func() <-chan struct{} { close(blocked) return available }, ) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().ECNMode(gomock.Any()).AnyTimes() tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return( shortHeaderPacket{PacketNumber: protocol.PacketNumber(1)}, nil, ) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() tc.conn.scheduleSending() synctest.Wait() select { case <-blocked: default: t.Fatal("should have blocked") } require.True(t, mockCtrl.Satisfied()) // now make room in the send queue sender.EXPECT().WouldBlock().AnyTimes() unblocked := make(chan struct{}) tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(*packetBuffer, protocol.ByteCount, monotime.Time, protocol.Version) (shortHeaderPacket, error) { close(unblocked) return shortHeaderPacket{}, errNothingToPack }, ) available <- struct{}{} synctest.Wait() select { case <-unblocked: default: t.Fatal("should have unblocked") } // test teardown sender.EXPECT().Close() tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() tc.conn.destroy(nil) synctest.Wait() select { case err := 
<-errChan: require.NoError(t, err) default: t.Fatal("timeout") } }) } func getVersionNegotiationPacket(src, dest protocol.ConnectionID, versions []protocol.Version) receivedPacket { b := wire.ComposeVersionNegotiation( protocol.ArbitraryLenConnectionID(src.Bytes()), protocol.ArbitraryLenConnectionID(dest.Bytes()), versions, ) return receivedPacket{ rcvTime: monotime.Now(), data: b, buffer: getPacketBuffer(), } } func TestConnectionVersionNegotiation(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) var eventRecorder events.Recorder tc := newClientTestConnection(t, mockCtrl, nil, false, connectionOptTracer(&eventRecorder)) tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() tc.connRunner.EXPECT().Remove(gomock.Any()) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() vnp := getVersionNegotiationPacket( tc.destConnID, tc.srcConnID, []protocol.Version{1234, protocol.Version2}, ) // the version negotiation packet might contained greased versions _, _, vnpVersions, err := wire.ParseVersionNegotiationPacket(vnp.data) require.NoError(t, err) tc.conn.handlePacket(vnp) synctest.Wait() select { case err := <-errChan: var rerr *errCloseForRecreating require.ErrorAs(t, err, &rerr) require.Equal(t, rerr.nextVersion, protocol.Version2) default: t.Fatal("should have received a Version Negotiation packet") } require.Equal(t, []qlogwriter.Event{ qlog.VersionNegotiationReceived{ Header: qlog.PacketHeaderVersionNegotiation{ SrcConnectionID: protocol.ArbitraryLenConnectionID(tc.destConnID.Bytes()), DestConnectionID: protocol.ArbitraryLenConnectionID(tc.srcConnID.Bytes()), }, SupportedVersions: vnpVersions, }, qlog.VersionInformation{ ServerVersions: vnpVersions, ClientVersions: []qlog.Version{protocol.Version1, protocol.Version2}, ChosenVersion: protocol.Version2, }, }, eventRecorder.Events(qlog.VersionNegotiationReceived{}, qlog.VersionInformation{}), ) 
}) } func TestConnectionVersionNegotiationNoMatch(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) var eventRecorder events.Recorder tc := newClientTestConnection(t, mockCtrl, &Config{Versions: []protocol.Version{protocol.Version1}}, false, connectionOptTracer(&eventRecorder), ) tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() tc.connRunner.EXPECT().Remove(gomock.Any()) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() vnp := getVersionNegotiationPacket( tc.destConnID, tc.srcConnID, []protocol.Version{protocol.Version2}, ) _, _, vnpVersions, err := wire.ParseVersionNegotiationPacket(vnp.data) require.NoError(t, err) tc.conn.handlePacket(vnp) synctest.Wait() select { case err := <-errChan: var verr *VersionNegotiationError require.ErrorAs(t, err, &verr) require.Contains(t, verr.Theirs, protocol.Version2) require.Equal(t, []qlogwriter.Event{ qlog.VersionNegotiationReceived{ Header: qlog.PacketHeaderVersionNegotiation{ SrcConnectionID: protocol.ArbitraryLenConnectionID(tc.destConnID.Bytes()), DestConnectionID: protocol.ArbitraryLenConnectionID(tc.srcConnID.Bytes()), }, SupportedVersions: vnpVersions, }, qlog.ConnectionClosed{ Initiator: qlog.InitiatorLocal, Trigger: qlog.ConnectionCloseTriggerVersionMismatch, }, }, eventRecorder.Events(qlog.VersionNegotiationReceived{}, qlog.ConnectionClosed{}), ) default: t.Fatal("should have received a Version Negotiation packet") } }) } func TestConnectionVersionNegotiationInvalidPackets(t *testing.T) { mockCtrl := gomock.NewController(t) var eventRecorder events.Recorder tc := newClientTestConnection(t, mockCtrl, nil, false, connectionOptTracer(&eventRecorder), ) // offers the current version vnp := getVersionNegotiationPacket( tc.destConnID, tc.srcConnID, []protocol.Version{1234, protocol.Version1}, ) wasProcessed, err := tc.conn.handleOnePacket(vnp, 0) require.NoError(t, err) require.False(t, 
wasProcessed) require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Header: qlog.PacketHeader{PacketType: qlog.PacketTypeVersionNegotiation}, Raw: qlog.RawInfo{Length: int(vnp.Size())}, Trigger: qlog.PacketDropUnexpectedVersion, }, }, eventRecorder.Events(qlog.PacketDropped{}), ) require.True(t, mockCtrl.Satisfied()) eventRecorder.Clear() // unparseable, since it's missing 2 bytes vnp.data = vnp.data[:len(vnp.data)-2] wasProcessed, err = tc.conn.handleOnePacket(vnp, 0) require.NoError(t, err) require.False(t, wasProcessed) require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Header: qlog.PacketHeader{PacketType: qlog.PacketTypeVersionNegotiation}, Raw: qlog.RawInfo{Length: int(vnp.Size())}, Trigger: qlog.PacketDropHeaderParseError, }, }, eventRecorder.Events(qlog.PacketDropped{}), ) } func getRetryPacket(t *testing.T, src, dest, origDest protocol.ConnectionID, token []byte) receivedPacket { hdr := wire.Header{ Type: protocol.PacketTypeRetry, SrcConnectionID: src, DestConnectionID: dest, Token: token, Version: protocol.Version1, } b, err := (&wire.ExtendedHeader{Header: hdr}).Append(nil, protocol.Version1) require.NoError(t, err) tag := handshake.GetRetryIntegrityTag(b, origDest, protocol.Version1) b = append(b, tag[:]...) 
return receivedPacket{ rcvTime: monotime.Now(), data: b, buffer: getPacketBuffer(), } } func TestConnectionRetryDrops(t *testing.T) { mockCtrl := gomock.NewController(t) var eventRecorder events.Recorder unpacker := NewMockUnpacker(mockCtrl) tc := newClientTestConnection(t, mockCtrl, nil, false, connectionOptTracer(&eventRecorder), connectionOptUnpacker(unpacker), ) newConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) // invalid integrity tag retry := getRetryPacket(t, newConnID, tc.srcConnID, tc.destConnID, []byte("foobar")) retry.data[len(retry.data)-1]++ wasProcessed, err := tc.conn.handleOnePacket(retry, 0) require.NoError(t, err) require.False(t, wasProcessed) require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeRetry, SrcConnectionID: newConnID, DestConnectionID: tc.srcConnID, Version: protocol.Version1, }, Raw: qlog.RawInfo{Length: int(retry.Size())}, Trigger: qlog.PacketDropPayloadDecryptError, }, }, eventRecorder.Events(qlog.PacketDropped{}), ) eventRecorder.Clear() // receive a retry that doesn't change the connection ID retry = getRetryPacket(t, tc.destConnID, tc.srcConnID, tc.destConnID, []byte("foobar")) wasProcessed, err = tc.conn.handleOnePacket(retry, 0) require.NoError(t, err) require.False(t, wasProcessed) require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeRetry, SrcConnectionID: tc.destConnID, DestConnectionID: tc.srcConnID, Version: protocol.Version1, }, Raw: qlog.RawInfo{Length: int(retry.Size())}, Trigger: qlog.PacketDropUnexpectedPacket, }, }, eventRecorder.Events(qlog.PacketDropped{}), ) } func TestConnectionRetryAfterReceivedPacket(t *testing.T) { mockCtrl := gomock.NewController(t) var eventRecorder events.Recorder unpacker := NewMockUnpacker(mockCtrl) tc := newClientTestConnection(t, mockCtrl, nil, false, connectionOptTracer(&eventRecorder), connectionOptUnpacker(unpacker), ) // receive a regular packet 
regular := getPacketWithPacketType(t, tc.srcConnID, protocol.PacketTypeInitial, 200) unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return( &unpackedPacket{ hdr: &wire.ExtendedHeader{Header: wire.Header{Type: protocol.PacketTypeInitial}}, encryptionLevel: protocol.EncryptionInitial, }, nil, ) wasProcessed, err := tc.conn.handleOnePacket(receivedPacket{ data: regular, buffer: getPacketBuffer(), rcvTime: monotime.Now(), remoteAddr: tc.remoteAddr, }, 0) require.NoError(t, err) require.True(t, wasProcessed) require.Len(t, eventRecorder.Events(qlog.PacketReceived{}), 1) require.Equal(t, []qlogwriter.Event{ qlog.VersionInformation{ ChosenVersion: protocol.Version1, ClientVersions: tc.conn.config.Versions, }, }, eventRecorder.Events(qlog.VersionInformation{}), ) eventRecorder.Clear() // receive a retry retry := getRetryPacket(t, tc.destConnID, tc.srcConnID, tc.destConnID, []byte("foobar")) wasProcessed, err = tc.conn.handleOnePacket(retry, 0) require.NoError(t, err) require.False(t, wasProcessed) require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeRetry, SrcConnectionID: tc.conn.origDestConnID, DestConnectionID: tc.srcConnID, Version: tc.conn.version, }, Raw: qlog.RawInfo{Length: int(retry.Size())}, Trigger: qlog.PacketDropUnexpectedPacket, }, }, eventRecorder.Events(qlog.PacketDropped{}), ) eventRecorder.Clear() } func TestConnectionConnectionIDChanges(t *testing.T) { t.Run("with retry", func(t *testing.T) { testConnectionConnectionIDChanges(t, true) }) t.Run("without retry", func(t *testing.T) { testConnectionConnectionIDChanges(t, false) }) } func testConnectionConnectionIDChanges(t *testing.T, sendRetry bool) { synctest.Test(t, func(t *testing.T) { makeInitialPacket := func(t *testing.T, hdr *wire.ExtendedHeader) []byte { t.Helper() data, err := hdr.Append(nil, protocol.Version1) require.NoError(t, err) data = append(data, make([]byte, hdr.Length-protocol.ByteCount(hdr.PacketNumberLen))...) 
return data } mockCtrl := gomock.NewController(t) var eventRecorder events.Recorder unpacker := NewMockUnpacker(mockCtrl) tc := newClientTestConnection(t, mockCtrl, nil, false, connectionOptTracer(&eventRecorder), connectionOptUnpacker(unpacker), ) dstConnID := tc.destConnID b := make([]byte, 3*10) rand.Read(b) newConnID := protocol.ParseConnectionID(b[:11]) newConnID2 := protocol.ParseConnectionID(b[11:20]) tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() require.Equal(t, dstConnID, tc.conn.connIDManager.Get()) var retryConnID protocol.ConnectionID if sendRetry { retryConnID = protocol.ParseConnectionID(b[20:30]) tc.packer.EXPECT().SetToken([]byte("foobar")) retry := getRetryPacket(t, retryConnID, tc.srcConnID, tc.destConnID, []byte("foobar")) tc.conn.handlePacket(retry) synctest.Wait() require.Equal(t, []qlogwriter.Event{ qlog.PacketReceived{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeRetry, SrcConnectionID: retryConnID, DestConnectionID: dstConnID, Version: protocol.Version1, Token: &qlog.Token{Raw: []byte("foobar")}, }, Raw: qlog.RawInfo{Length: int(retry.Size())}, }, }, eventRecorder.Events(qlog.PacketReceived{}, qlog.PacketDropped{}), ) } eventRecorder.Clear() // Send the first packet. The server changes the connection ID to newConnID. 
hdr1 := wire.ExtendedHeader{ Header: wire.Header{ SrcConnectionID: newConnID, DestConnectionID: tc.srcConnID, Type: protocol.PacketTypeInitial, Length: 200, Version: protocol.Version1, }, PacketNumber: 1, PacketNumberLen: protocol.PacketNumberLen2, } hdr2 := hdr1 hdr2.SrcConnectionID = newConnID2 unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return( &unpackedPacket{hdr: &hdr1, encryptionLevel: protocol.EncryptionInitial}, nil, ) eventRecorder.Clear() packet1 := getLongHeaderPacket(t, tc.remoteAddr, &hdr1, make([]byte, 198)) tc.conn.handlePacket(packet1) synctest.Wait() require.Equal(t, []qlogwriter.Event{ qlog.PacketReceived{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeInitial, SrcConnectionID: newConnID, DestConnectionID: tc.srcConnID, PacketNumber: 1, Version: protocol.Version1, }, Raw: qlog.RawInfo{Length: int(packet1.Size()), PayloadLength: int(hdr1.Length)}, DatagramID: qlog.CalculateDatagramID(packet1.data), Frames: []qlog.Frame{}, }, }, eventRecorder.Events(qlog.PacketReceived{}, qlog.PacketDropped{}), ) eventRecorder.Clear() // Send the second packet. We refuse to accept it, because the connection ID is changed again. 
packet2 := receivedPacket{data: makeInitialPacket(t, &hdr2), buffer: getPacketBuffer(), rcvTime: monotime.Now(), remoteAddr: tc.remoteAddr} tc.conn.handlePacket(packet2) synctest.Wait() require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeInitial, PacketNumber: protocol.InvalidPacketNumber, }, Raw: qlog.RawInfo{Length: int(packet2.Size())}, DatagramID: qlog.CalculateDatagramID(packet2.data), Trigger: qlog.PacketDropUnknownConnectionID, }, }, eventRecorder.Events(qlog.PacketDropped{}, qlog.PacketReceived{}), ) // the connection ID should not have changed require.Equal(t, newConnID, tc.conn.connIDManager.Get()) // test teardown tc.connRunner.EXPECT().Remove(gomock.Any()) tc.conn.destroy(nil) synctest.Wait() select { case err := <-errChan: require.NoError(t, err) default: t.Fatal("should have shut down") } }) } // When the connection is closed before sending the first packet, // we don't send a CONNECTION_CLOSE. // This can happen if there's something wrong the tls.Config, and // crypto/tls refuses to start the handshake. 
func TestConnectionEarlyClose(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) var eventRecorder events.Recorder cryptoSetup := mocks.NewMockCryptoSetup(mockCtrl) tc := newClientTestConnection(t, mockCtrl, nil, false, connectionOptTracer(&eventRecorder), connectionOptCryptoSetup(cryptoSetup), ) tc.conn.sentFirstPacket = false cryptoSetup.EXPECT().StartHandshake(gomock.Any()).Do(func(context.Context) error { tc.conn.closeLocal(errors.New("early error")) return nil }) cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) cryptoSetup.EXPECT().Close() tc.connRunner.EXPECT().Remove(gomock.Any()) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() synctest.Wait() select { case err := <-errChan: require.Error(t, err) require.ErrorContains(t, err, "early error") code := qerr.InternalError require.Equal(t, []qlogwriter.Event{ qlog.ConnectionClosed{ Initiator: qlog.InitiatorLocal, ConnectionError: &code, Reason: "early error", }, }, eventRecorder.Events(qlog.ConnectionClosed{}), ) default: t.Fatal("should have shut down") } }) } func TestConnectionPathValidation(t *testing.T) { t.Run("NAT rebinding", func(t *testing.T) { testConnectionPathValidation(t, true) }) t.Run("intentional migration", func(t *testing.T) { testConnectionPathValidation(t, false) }) } func testConnectionPathValidation(t *testing.T, isNATRebinding bool) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) unpacker := NewMockUnpacker(mockCtrl) tc := newServerTestConnection( t, mockCtrl, nil, false, connectionOptUnpacker(unpacker), connectionOptHandshakeConfirmed(), connectionOptRTT(time.Second), ) require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{MaxUDPPayloadSize: 1456})) newRemoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 1, 1), Port: 1234} require.NotEqual(t, tc.remoteAddr, newRemoteAddr) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() 
probeSent := make(chan struct{}) var pathChallenge *wire.PathChallengeFrame payload := []byte{0} // PADDING frame if isNATRebinding { payload = []byte{1} // PING frame } gomock.InOrder( unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return( protocol.PacketNumber(10), protocol.PacketNumberLen2, protocol.KeyPhaseZero, payload, nil, ), tc.packer.EXPECT().PackPathProbePacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(_ protocol.ConnectionID, frames []ackhandler.Frame, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) { pathChallenge = frames[0].Frame.(*wire.PathChallengeFrame) return shortHeaderPacket{IsPathProbePacket: true}, getPacketBuffer(), nil }, ), tc.sendConn.EXPECT().WriteTo(gomock.Any(), newRemoteAddr).DoAndReturn( func([]byte, net.Addr) error { close(probeSent); return nil }, ), tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return( shortHeaderPacket{}, errNothingToPack, ), ) tc.conn.handlePacket(receivedPacket{ data: make([]byte, 10), buffer: getPacketBuffer(), remoteAddr: newRemoteAddr, rcvTime: monotime.Now(), }) synctest.Wait() select { case <-probeSent: case <-time.After(time.Second): t.Fatal("timeout") } // Receive a packed containing a PATH_RESPONSE frame. // Only if the first packet received on the path was a probing packet // (i.e. we're dealing with a NAT rebinding), this makes us switch to the new path. 
migrated := make(chan struct{}) data, err := (&wire.PathResponseFrame{Data: pathChallenge.Data}).Append(nil, protocol.Version1) require.NoError(t, err) calls := []any{ unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return( protocol.PacketNumber(11), protocol.PacketNumberLen2, protocol.KeyPhaseZero, data, nil, ), } if isNATRebinding { calls = append(calls, tc.sendConn.EXPECT().ChangeRemoteAddr(newRemoteAddr, gomock.Any()).Do( func(net.Addr, packetInfo) { close(migrated) }, ), ) } calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return( shortHeaderPacket{}, errNothingToPack, ).MaxTimes(1), ) gomock.InOrder(calls...) require.Equal(t, tc.remoteAddr, tc.conn.RemoteAddr()) // the PATH_RESPONSE can be sent on the old path, if the client is just probing the new path addr := tc.remoteAddr if isNATRebinding { addr = newRemoteAddr } tc.conn.handlePacket(receivedPacket{ data: make([]byte, 100), buffer: getPacketBuffer(), remoteAddr: addr, rcvTime: monotime.Now(), }) synctest.Wait() if !isNATRebinding { // If the first packet was a probing packet, we only switch to the new path when we // receive a non-probing packet on that path. 
select { case <-migrated: t.Fatal("didn't expect a migration yet") default: } payload := []byte{1} // PING frame payload, err = (&wire.PathResponseFrame{Data: pathChallenge.Data}).Append(payload, protocol.Version1) require.NoError(t, err) gomock.InOrder( unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return( protocol.PacketNumber(12), protocol.PacketNumberLen2, protocol.KeyPhaseZero, payload, nil, ), tc.sendConn.EXPECT().ChangeRemoteAddr(newRemoteAddr, gomock.Any()).Do( func(net.Addr, packetInfo) { close(migrated) }, ), tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return( shortHeaderPacket{}, errNothingToPack, ).MaxTimes(1), ) tc.conn.handlePacket(receivedPacket{ data: make([]byte, 100), buffer: getPacketBuffer(), remoteAddr: newRemoteAddr, rcvTime: monotime.Now(), }) } synctest.Wait() select { case <-migrated: default: t.Fatal("should have migrated") } // test teardown tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() tc.conn.destroy(nil) synctest.Wait() select { case err := <-errChan: require.NoError(t, err) default: t.Fatal("should have shut down") } }) } func TestConnectionMigrationServer(t *testing.T) { tc := newServerTestConnection(t, nil, nil, false) _, err := tc.conn.AddPath(&Transport{}) require.Error(t, err) require.ErrorContains(t, err, "server cannot initiate connection migration") } func TestConnectionMigration(t *testing.T) { t.Run("disabled", func(t *testing.T) { testConnectionMigration(t, false) }) t.Run("enabled", func(t *testing.T) { testConnectionMigration(t, true) }) } func testConnectionMigration(t *testing.T, enabled bool) { tc := newClientTestConnection(t, nil, nil, false, connectionOptHandshakeConfirmed()) require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{ InitialSourceConnectionID: tc.destConnID, OriginalDestinationConnectionID: tc.destConnID, DisableActiveMigration: !enabled, })) tr := &Transport{ Conn: newUDPConnLocalhost(t), StatelessResetKey: 
&StatelessResetKey{}, } defer tr.Close() path, err := tc.conn.AddPath(tr) if !enabled { require.Error(t, err) require.ErrorContains(t, err, "server disabled connection migration") return } require.NoError(t, err) require.NotNil(t, path) tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return( shortHeaderPacket{}, errNothingToPack, ).AnyTimes() packedProbe := make(chan struct{}) tc.packer.EXPECT().PackPathProbePacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(protocol.ConnectionID, []ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error) { defer close(packedProbe) return shortHeaderPacket{IsPathProbePacket: true}, getPacketBuffer(), nil }, ).AnyTimes() tc.connRunner.EXPECT().AddResetToken(gomock.Any(), gomock.Any()) // add a new connection ID, so the path can be probed _, err = tc.conn.handleFrame(&wire.NewConnectionIDFrame{ SequenceNumber: 1, ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), }, protocol.EncryptionInitial, tc.destConnID, monotime.Now()) require.NoError(t, err) errChan := make(chan error, 1) go func() { errChan <- tc.conn.run() }() // Adding the path initialized the transport. // We can test this by triggering a stateless reset. 
conn := newUDPConnLocalhost(t) _, err = conn.WriteTo(append([]byte{0x40}, make([]byte, 100)...), tr.Conn.LocalAddr()) require.NoError(t, err) conn.SetReadDeadline(time.Now().Add(time.Second)) _, _, err = conn.ReadFrom(make([]byte, 100)) require.NoError(t, err) go func() { path.Probe(context.Background()) }() select { case <-packedProbe: case <-time.After(time.Second): t.Fatal("timeout") } // teardown tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() tc.connRunner.EXPECT().RemoveResetToken(gomock.Any()).MaxTimes(1) tc.conn.destroy(nil) select { case <-errChan: case <-time.After(time.Second): t.Fatal("timeout") } } func TestConnectionDatagrams(t *testing.T) { t.Run("disabled", func(t *testing.T) { testConnectionDatagrams(t, false) }) t.Run("enabled", func(t *testing.T) { testConnectionDatagrams(t, true) }) } func testConnectionDatagrams(t *testing.T, enabled bool) { tc := newServerTestConnection(t, nil, &Config{EnableDatagrams: enabled}, false) data, err := (&wire.DatagramFrame{Data: []byte("foo"), DataLenPresent: true}).Append(nil, protocol.Version1) require.NoError(t, err) data, err = (&wire.DatagramFrame{Data: []byte("bar")}).Append(data, protocol.Version1) require.NoError(t, err) _, _, _, err = tc.conn.handleFrames(data, protocol.ConnectionID{}, protocol.Encryption1RTT, nil, monotime.Now()) if !enabled { require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.FrameEncodingError, FrameType: uint64(wire.FrameTypeDatagramWithLength)}) return } require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() d, err := tc.conn.ReceiveDatagram(ctx) require.NoError(t, err) require.Equal(t, []byte("foo"), d) d, err = tc.conn.ReceiveDatagram(ctx) require.NoError(t, err) require.Equal(t, []byte("bar"), d) } quic-go-0.59.0/crypto_stream.go000066400000000000000000000150001513066070600163730ustar00rootroot00000000000000package quic import ( "errors" "fmt" "io" "os" "slices" "strconv" 
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/wire" ) const disableClientHelloScramblingEnv = "QUIC_GO_DISABLE_CLIENTHELLO_SCRAMBLING" // The baseCryptoStream is used by the cryptoStream and the initialCryptoStream. // This allows us to implement different logic for PopCryptoFrame for the two streams. type baseCryptoStream struct { queue frameSorter highestOffset protocol.ByteCount finished bool writeOffset protocol.ByteCount writeBuf []byte } func newCryptoStream() *cryptoStream { return &cryptoStream{baseCryptoStream{queue: *newFrameSorter()}} } func (s *baseCryptoStream) HandleCryptoFrame(f *wire.CryptoFrame) error { highestOffset := f.Offset + protocol.ByteCount(len(f.Data)) if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset { return &qerr.TransportError{ ErrorCode: qerr.CryptoBufferExceeded, ErrorMessage: fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset), } } if s.finished { if highestOffset > s.highestOffset { // reject crypto data received after this stream was already finished return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: "received crypto data after change of encryption level", } } // ignore data with a smaller offset than the highest received // could e.g. 
be a retransmission return nil } s.highestOffset = max(s.highestOffset, highestOffset) return s.queue.Push(f.Data, f.Offset, nil) } // GetCryptoData retrieves data that was received in CRYPTO frames func (s *baseCryptoStream) GetCryptoData() []byte { _, data, _ := s.queue.Pop() return data } func (s *baseCryptoStream) Finish() error { if s.queue.HasMoreData() { return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: "encryption level changed, but crypto stream has more data to read", } } s.finished = true return nil } // Writes writes data that should be sent out in CRYPTO frames func (s *baseCryptoStream) Write(p []byte) (int, error) { s.writeBuf = append(s.writeBuf, p...) return len(p), nil } func (s *baseCryptoStream) HasData() bool { return len(s.writeBuf) > 0 } func (s *baseCryptoStream) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame { f := &wire.CryptoFrame{Offset: s.writeOffset} n := min(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf))) if n <= 0 { return nil } f.Data = s.writeBuf[:n] s.writeBuf = s.writeBuf[n:] s.writeOffset += n return f } type cryptoStream struct { baseCryptoStream } type clientHelloCut struct { start protocol.ByteCount end protocol.ByteCount } type initialCryptoStream struct { baseCryptoStream scramble bool end protocol.ByteCount cuts [2]clientHelloCut } func newInitialCryptoStream(isClient bool) *initialCryptoStream { var scramble bool if isClient { disabled, err := strconv.ParseBool(os.Getenv(disableClientHelloScramblingEnv)) scramble = err != nil || !disabled } s := &initialCryptoStream{ baseCryptoStream: baseCryptoStream{queue: *newFrameSorter()}, scramble: scramble, } for i := range len(s.cuts) { s.cuts[i].start = protocol.InvalidByteCount s.cuts[i].end = protocol.InvalidByteCount } return s } func (s *initialCryptoStream) HasData() bool { // The ClientHello might be written in multiple parts. // In order to correctly split the ClientHello, we need the entire ClientHello has been queued. 
if s.scramble && s.writeOffset == 0 && s.cuts[0].start == protocol.InvalidByteCount { return false } return s.baseCryptoStream.HasData() } func (s *initialCryptoStream) Write(p []byte) (int, error) { s.writeBuf = append(s.writeBuf, p...) if !s.scramble { return len(p), nil } if s.cuts[0].start == protocol.InvalidByteCount { sniPos, sniLen, echPos, err := findSNIAndECH(s.writeBuf) if errors.Is(err, io.ErrUnexpectedEOF) { return len(p), nil } if err != nil { return len(p), err } if sniPos == -1 && echPos == -1 { // Neither SNI nor ECH found. // There's nothing to scramble. s.scramble = false return len(p), nil } s.end = protocol.ByteCount(len(s.writeBuf)) s.cuts[0].start = protocol.ByteCount(sniPos + sniLen/2) // right in the middle s.cuts[0].end = protocol.ByteCount(sniPos + sniLen) if echPos > 0 { // ECH extension found, cut the ECH extension type value (a uint16) in half start := protocol.ByteCount(echPos + 1) s.cuts[1].start = start // cut somewhere (16 bytes), most likely in the ECH extension value s.cuts[1].end = min(start+16, s.end) } slices.SortFunc(s.cuts[:], func(a, b clientHelloCut) int { if a.start == protocol.InvalidByteCount { return 1 } if a.start > b.start { return 1 } return -1 }) } return len(p), nil } func (s *initialCryptoStream) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame { if !s.scramble { return s.baseCryptoStream.PopCryptoFrame(maxLen) } // send out the skipped parts if s.writeOffset == s.end { var foundCuts bool var f *wire.CryptoFrame for i, c := range s.cuts { if c.start == protocol.InvalidByteCount { continue } foundCuts = true if f != nil { break } f = &wire.CryptoFrame{Offset: c.start} n := min(f.MaxDataLen(maxLen), c.end-c.start) if n <= 0 { return nil } f.Data = s.writeBuf[c.start : c.start+n] s.cuts[i].start += n if s.cuts[i].start == c.end { s.cuts[i].start = protocol.InvalidByteCount s.cuts[i].end = protocol.InvalidByteCount foundCuts = false } } if !foundCuts { // no more cuts found, we're done sending out 
everything up until s.end s.writeBuf = s.writeBuf[s.end:] s.end = protocol.InvalidByteCount s.scramble = false } return f } nextCut := clientHelloCut{start: protocol.InvalidByteCount, end: protocol.InvalidByteCount} for _, c := range s.cuts { if c.start == protocol.InvalidByteCount { continue } if c.start > s.writeOffset { nextCut = c break } } f := &wire.CryptoFrame{Offset: s.writeOffset} maxOffset := nextCut.start if maxOffset == protocol.InvalidByteCount { maxOffset = s.end } n := min(f.MaxDataLen(maxLen), maxOffset-s.writeOffset) if n <= 0 { return nil } f.Data = s.writeBuf[s.writeOffset : s.writeOffset+n] // Don't reslice the writeBuf yet. // This is done once all parts have been sent out. s.writeOffset += n if s.writeOffset == nextCut.start { s.writeOffset = nextCut.end } return f } quic-go-0.59.0/crypto_stream_go124_test.go000066400000000000000000000052361513066070600203600ustar00rootroot00000000000000//go:build go1.24 package quic import ( "fmt" mrand "math/rand/v2" "slices" "strings" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/require" ) func randomDomainName(length int) string { const alphabet = "abcdefghijklmnopqrstuvwxyz" b := make([]byte, length) for i := range b { if i > 0 && i < length-1 && mrand.IntN(5) == 0 && b[i-1] != '.' { b[i] = '.' 
} else { b[i] = alphabet[mrand.IntN(len(alphabet))] } } return string(b) } func TestInitialCryptoStreamClientRandomizedSizes(t *testing.T) { skipIfDisableScramblingEnvSet(t) for i := range 100 { t.Run(fmt.Sprintf("run %d", i), func(t *testing.T) { var serverName string if mrand.Int()%4 > 0 { serverName = randomDomainName(6 + mrand.IntN(20)) } var clientHello []byte if serverName == "" || !strings.Contains(serverName, ".") || mrand.Int()%2 == 0 { t.Logf("using a ClientHello without ECH, hostname: %q", serverName) clientHello = getClientHello(t, serverName) } else { t.Logf("using a ClientHello with ECH, hostname: %q", serverName) clientHello = getClientHelloWithECH(t, serverName) } testInitialCryptoStreamClientRandomizedSizes(t, clientHello, serverName) }) } } func testInitialCryptoStreamClientRandomizedSizes(t *testing.T, clientHello []byte, expectedServerName string) { str := newInitialCryptoStream(true) b := slices.Clone(clientHello) for len(b) > 0 { n := min(len(b), mrand.IntN(2*len(b))) _, err := str.Write(b[:n]) require.NoError(t, err) b = b[n:] } require.True(t, str.HasData()) _, err := str.Write([]byte("foobar")) require.NoError(t, err) segments := make(map[protocol.ByteCount][]byte) var frames []*wire.CryptoFrame for str.HasData() { // fmt.Println("popping a frame") var maxSize protocol.ByteCount if mrand.Int()%4 == 0 { maxSize = protocol.ByteCount(mrand.IntN(512) + 1) } else { maxSize = protocol.ByteCount(mrand.IntN(32) + 1) } f := str.PopCryptoFrame(maxSize) if f == nil { continue } frames = append(frames, f) require.LessOrEqual(t, f.Length(protocol.Version1), maxSize) } t.Logf("received %d frames", len(frames)) for _, f := range frames { t.Logf("offset %d: %d bytes", f.Offset, len(f.Data)) if expectedServerName != "" { require.NotContainsf(t, string(f.Data), expectedServerName, "frame at offset %d contains the server name", f.Offset) } segments[f.Offset] = f.Data } reassembled := reassembleCryptoData(t, segments) require.Equal(t, append(clientHello, 
[]byte("foobar")...), reassembled) if expectedServerName != "" { require.Contains(t, string(reassembled), expectedServerName) } } quic-go-0.59.0/crypto_stream_manager.go000066400000000000000000000042501513066070600200720ustar00rootroot00000000000000package quic import ( "fmt" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" ) type cryptoStreamManager struct { initialStream *initialCryptoStream handshakeStream *cryptoStream oneRTTStream *cryptoStream } func newCryptoStreamManager( initialStream *initialCryptoStream, handshakeStream *cryptoStream, oneRTTStream *cryptoStream, ) *cryptoStreamManager { return &cryptoStreamManager{ initialStream: initialStream, handshakeStream: handshakeStream, oneRTTStream: oneRTTStream, } } func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { //nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets. switch encLevel { case protocol.EncryptionInitial: return m.initialStream.HandleCryptoFrame(frame) case protocol.EncryptionHandshake: return m.handshakeStream.HandleCryptoFrame(frame) case protocol.Encryption1RTT: return m.oneRTTStream.HandleCryptoFrame(frame) default: return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel) } } func (m *cryptoStreamManager) GetCryptoData(encLevel protocol.EncryptionLevel) []byte { //nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets. 
switch encLevel { case protocol.EncryptionInitial: return m.initialStream.GetCryptoData() case protocol.EncryptionHandshake: return m.handshakeStream.GetCryptoData() case protocol.Encryption1RTT: return m.oneRTTStream.GetCryptoData() default: panic(fmt.Sprintf("received CRYPTO frame with unexpected encryption level: %s", encLevel)) } } func (m *cryptoStreamManager) GetPostHandshakeData(maxSize protocol.ByteCount) *wire.CryptoFrame { if !m.oneRTTStream.HasData() { return nil } return m.oneRTTStream.PopCryptoFrame(maxSize) } func (m *cryptoStreamManager) Drop(encLevel protocol.EncryptionLevel) error { //nolint:exhaustive // 1-RTT keys should never get dropped. switch encLevel { case protocol.EncryptionInitial: return m.initialStream.Finish() case protocol.EncryptionHandshake: return m.handshakeStream.Finish() default: panic(fmt.Sprintf("dropped unexpected encryption level: %s", encLevel)) } } quic-go-0.59.0/crypto_stream_manager_test.go000066400000000000000000000054661513066070600211430ustar00rootroot00000000000000package quic import ( "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/require" ) func TestCryptoStreamManager(t *testing.T) { t.Run("Initial", func(t *testing.T) { testCryptoStreamManager(t, protocol.EncryptionInitial) }) t.Run("Handshake", func(t *testing.T) { testCryptoStreamManager(t, protocol.EncryptionHandshake) }) t.Run("1-RTT", func(t *testing.T) { testCryptoStreamManager(t, protocol.Encryption1RTT) }) } func testCryptoStreamManager(t *testing.T, encLevel protocol.EncryptionLevel) { initialStream := newInitialCryptoStream(true) handshakeStream := newCryptoStream() oneRTTStream := newCryptoStream() csm := newCryptoStreamManager(initialStream, handshakeStream, oneRTTStream) require.NoError(t, csm.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("foo")}, encLevel)) require.NoError(t, csm.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("bar"), Offset: 3}, encLevel)) var 
data []byte for { b := csm.GetCryptoData(encLevel) if len(b) == 0 { break } data = append(data, b...) } require.Equal(t, []byte("foobar"), data) } func TestCryptoStreamManagerInvalidEncryptionLevel(t *testing.T) { csm := newCryptoStreamManager(nil, nil, nil) require.ErrorContains(t, csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption0RTT), "received CRYPTO frame with unexpected encryption level", ) } func TestCryptoStreamManagerDropEncryptionLevel(t *testing.T) { t.Run("Initial", func(t *testing.T) { testCryptoStreamManagerDropEncryptionLevel(t, protocol.EncryptionInitial) }) t.Run("Handshake", func(t *testing.T) { testCryptoStreamManagerDropEncryptionLevel(t, protocol.EncryptionHandshake) }) } func testCryptoStreamManagerDropEncryptionLevel(t *testing.T, encLevel protocol.EncryptionLevel) { initialStream := newInitialCryptoStream(true) handshakeStream := newCryptoStream() oneRTTStream := newCryptoStream() csm := newCryptoStreamManager(initialStream, handshakeStream, oneRTTStream) require.NoError(t, csm.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("foo")}, encLevel)) require.ErrorContains(t, csm.Drop(encLevel), "encryption level changed, but crypto stream has more data to read") require.Equal(t, []byte("foo"), csm.GetCryptoData(encLevel)) require.NoError(t, csm.Drop(encLevel)) } func TestCryptoStreamManagerPostHandshake(t *testing.T) { initialStream := newInitialCryptoStream(true) handshakeStream := newCryptoStream() oneRTTStream := newCryptoStream() csm := newCryptoStreamManager(initialStream, handshakeStream, oneRTTStream) _, err := oneRTTStream.Write([]byte("foo")) require.NoError(t, err) _, err = oneRTTStream.Write([]byte("bar")) require.NoError(t, err) require.Equal(t, &wire.CryptoFrame{Data: []byte("foobar")}, csm.GetPostHandshakeData(protocol.ByteCount(10)), ) } quic-go-0.59.0/crypto_stream_test.go000066400000000000000000000141461513066070600174440ustar00rootroot00000000000000package quic import ( "os" "strconv" "testing" 
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/require" ) func TestCryptoStreamDataAssembly(t *testing.T) { str := newCryptoStream() require.NoError(t, str.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("bar"), Offset: 3})) require.NoError(t, str.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("foo")})) // receive a retransmission require.NoError(t, str.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("bar"), Offset: 3})) var data []byte for { b := str.GetCryptoData() if b == nil { break } data = append(data, b...) } require.Equal(t, []byte("foobar"), data) } func TestCryptoStreamMaxOffset(t *testing.T) { str := newCryptoStream() require.NoError(t, str.HandleCryptoFrame(&wire.CryptoFrame{ Offset: protocol.MaxCryptoStreamOffset - 5, Data: []byte("foo"), })) require.ErrorIs(t, str.HandleCryptoFrame(&wire.CryptoFrame{ Offset: protocol.MaxCryptoStreamOffset - 2, Data: []byte("bar"), }), &qerr.TransportError{ErrorCode: qerr.CryptoBufferExceeded}, ) } func TestCryptoStreamFinishWithQueuedData(t *testing.T) { t.Run("with data at current offset", func(t *testing.T) { str := newCryptoStream() require.NoError(t, str.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("foo")})) require.Equal(t, []byte("foo"), str.GetCryptoData()) require.NoError(t, str.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("bar"), Offset: 3})) require.ErrorIs(t, str.Finish(), &qerr.TransportError{ErrorCode: qerr.ProtocolViolation}) }) t.Run("with data at a higher offset", func(t *testing.T) { str := newCryptoStream() require.NoError(t, str.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("foobar"), Offset: 20})) require.ErrorIs(t, str.Finish(), &qerr.TransportError{ErrorCode: qerr.ProtocolViolation}) }) } func TestCryptoStreamReceiveDataAfterFinish(t *testing.T) { str := newCryptoStream() require.NoError(t, str.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("foobar")})) 
require.Equal(t, []byte("foobar"), str.GetCryptoData()) require.NoError(t, str.Finish()) // receiving a retransmission is ok require.NoError(t, str.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("bar"), Offset: 3})) // but receiving new data is not require.ErrorIs(t, str.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("baz"), Offset: 4}), &qerr.TransportError{ErrorCode: qerr.ProtocolViolation}, ) } func expectedCryptoFrameLen(offset protocol.ByteCount) protocol.ByteCount { f := &wire.CryptoFrame{Offset: offset} return f.Length(protocol.Version1) } func TestCryptoStreamWrite(t *testing.T) { str := newCryptoStream() require.False(t, str.HasData()) _, err := str.Write([]byte("foo")) require.NoError(t, err) require.True(t, str.HasData()) _, err = str.Write([]byte("bar")) require.NoError(t, err) _, err = str.Write([]byte("baz")) require.NoError(t, err) require.True(t, str.HasData()) for i := range expectedCryptoFrameLen(0) { require.Nil(t, str.PopCryptoFrame(i)) } f := str.PopCryptoFrame(expectedCryptoFrameLen(0) + 1) require.Equal(t, &wire.CryptoFrame{Data: []byte("f")}, f) require.True(t, str.HasData()) f = str.PopCryptoFrame(expectedCryptoFrameLen(1) + 3) // the three write calls were coalesced into a single frame require.Equal(t, &wire.CryptoFrame{Offset: 1, Data: []byte("oob")}, f) f = str.PopCryptoFrame(protocol.MaxByteCount) require.Equal(t, &wire.CryptoFrame{Offset: 4, Data: []byte("arbaz")}, f) require.False(t, str.HasData()) } func TestInitialCryptoStreamServer(t *testing.T) { str := newInitialCryptoStream(false) _, err := str.Write([]byte("foobar")) require.NoError(t, err) f := str.PopCryptoFrame(expectedCryptoFrameLen(0) + 3) require.Equal(t, &wire.CryptoFrame{Offset: 0, Data: []byte("foo")}, f) require.True(t, str.HasData()) // append another CRYPTO frame to the existing slice f = str.PopCryptoFrame(expectedCryptoFrameLen(3) + 3) require.Equal(t, &wire.CryptoFrame{Offset: 3, Data: []byte("bar")}, f) require.False(t, str.HasData()) } func 
reassembleCryptoData(t *testing.T, segments map[protocol.ByteCount][]byte) []byte { t.Helper() var reassembled []byte var offset protocol.ByteCount for len(segments) > 0 { b, ok := segments[offset] if !ok { break } reassembled = append(reassembled, b...) delete(segments, offset) offset = protocol.ByteCount(len(reassembled)) } require.Empty(t, segments) return reassembled } func skipIfDisableScramblingEnvSet(t *testing.T) { t.Helper() disabled, err := strconv.ParseBool(os.Getenv(disableClientHelloScramblingEnv)) if err == nil && disabled { t.Skip("ClientHello scrambling disabled via " + disableClientHelloScramblingEnv) } } func TestInitialCryptoStreamClientStatic(t *testing.T) { skipIfDisableScramblingEnvSet(t) str := newInitialCryptoStream(true) clientHello := getClientHello(t, "quic-go.net") _, err := str.Write(clientHello) require.NoError(t, err) require.True(t, str.HasData()) _, err = str.Write([]byte("foobar")) require.NoError(t, err) segments := make(map[protocol.ByteCount][]byte) f1 := str.PopCryptoFrame(protocol.MaxByteCount) require.NotNil(t, f1) segments[f1.Offset] = f1.Data require.True(t, str.HasData()) f2 := str.PopCryptoFrame(protocol.MaxByteCount) require.NotNil(t, f2) require.NotContains(t, segments, f2.Offset) segments[f2.Offset] = f2.Data require.True(t, str.HasData()) require.NotEqual(t, f2.Offset, protocol.ByteCount(len(f1.Data))) f3 := str.PopCryptoFrame(protocol.MaxByteCount) require.NotNil(t, f2) require.NotContains(t, segments, f3.Offset) segments[f3.Offset] = f3.Data require.True(t, str.HasData()) require.NotEqual(t, f3.Offset, protocol.ByteCount(len(f2.Data))) f4 := str.PopCryptoFrame(protocol.MaxByteCount) require.NotNil(t, f4) require.NotContains(t, segments, f4.Offset) segments[f4.Offset] = f4.Data require.Equal(t, []byte("foobar"), f4.Data) require.False(t, str.HasData()) require.NotEqual(t, f4.Offset, protocol.ByteCount(len(f3.Data))) reassembled := reassembleCryptoData(t, segments) require.Equal(t, append(clientHello, 
[]byte("foobar")...), reassembled) } quic-go-0.59.0/datagram_queue.go000066400000000000000000000056401513066070600164750ustar00rootroot00000000000000package quic import ( "context" "sync" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/utils/ringbuffer" "github.com/quic-go/quic-go/internal/wire" ) const ( maxDatagramSendQueueLen = 32 maxDatagramRcvQueueLen = 128 ) type datagramQueue struct { sendMx sync.Mutex sendQueue ringbuffer.RingBuffer[*wire.DatagramFrame] sent chan struct{} // used to notify Add that a datagram was dequeued rcvMx sync.Mutex rcvQueue [][]byte rcvd chan struct{} // used to notify Receive that a new datagram was received closeErr error closed chan struct{} hasData func() logger utils.Logger } func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue { return &datagramQueue{ hasData: hasData, rcvd: make(chan struct{}, 1), sent: make(chan struct{}, 1), closed: make(chan struct{}), logger: logger, } } // Add queues a new DATAGRAM frame for sending. // Up to 32 DATAGRAM frames will be queued. // Once that limit is reached, Add blocks until the queue size has reduced. func (h *datagramQueue) Add(f *wire.DatagramFrame) error { h.sendMx.Lock() for { if h.sendQueue.Len() < maxDatagramSendQueueLen { h.sendQueue.PushBack(f) h.sendMx.Unlock() h.hasData() return nil } select { case <-h.sent: // drain the queue so we don't loop immediately default: } h.sendMx.Unlock() select { case <-h.closed: return h.closeErr case <-h.sent: } h.sendMx.Lock() } } // Peek gets the next DATAGRAM frame for sending. // If actually sent out, Pop needs to be called before the next call to Peek. 
func (h *datagramQueue) Peek() *wire.DatagramFrame { h.sendMx.Lock() defer h.sendMx.Unlock() if h.sendQueue.Empty() { return nil } return h.sendQueue.PeekFront() } func (h *datagramQueue) Pop() { h.sendMx.Lock() defer h.sendMx.Unlock() _ = h.sendQueue.PopFront() select { case h.sent <- struct{}{}: default: } } // HandleDatagramFrame handles a received DATAGRAM frame. func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) { data := make([]byte, len(f.Data)) copy(data, f.Data) var queued bool h.rcvMx.Lock() if len(h.rcvQueue) < maxDatagramRcvQueueLen { h.rcvQueue = append(h.rcvQueue, data) queued = true select { case h.rcvd <- struct{}{}: default: } } h.rcvMx.Unlock() if !queued && h.logger.Debug() { h.logger.Debugf("Discarding received DATAGRAM frame (%d bytes payload)", len(f.Data)) } } // Receive gets a received DATAGRAM frame. func (h *datagramQueue) Receive(ctx context.Context) ([]byte, error) { for { h.rcvMx.Lock() if len(h.rcvQueue) > 0 { data := h.rcvQueue[0] h.rcvQueue = h.rcvQueue[1:] h.rcvMx.Unlock() return data, nil } h.rcvMx.Unlock() select { case <-h.rcvd: continue case <-h.closed: return nil, h.closeErr case <-ctx.Done(): return nil, ctx.Err() } } } func (h *datagramQueue) CloseWithError(e error) { h.closeErr = e close(h.closed) } quic-go-0.59.0/datagram_queue_test.go000066400000000000000000000107131513066070600175310ustar00rootroot00000000000000package quic import ( "context" "testing" "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestDatagramQueuePeekAndPop(t *testing.T) { var queued []struct{} queue := newDatagramQueue(func() { queued = append(queued, struct{}{}) }, utils.DefaultLogger) require.Nil(t, queue.Peek()) require.Empty(t, queued) require.NoError(t, queue.Add(&wire.DatagramFrame{Data: []byte("foo")})) require.Len(t, queued, 1) require.Equal(t, 
&wire.DatagramFrame{Data: []byte("foo")}, queue.Peek()) // calling peek again returns the same datagram require.Equal(t, &wire.DatagramFrame{Data: []byte("foo")}, queue.Peek()) queue.Pop() require.Nil(t, queue.Peek()) } func TestDatagramQueueSendQueueLength(t *testing.T) { synctest.Test(t, func(t *testing.T) { queue := newDatagramQueue(func() {}, utils.DefaultLogger) for range maxDatagramSendQueueLen { require.NoError(t, queue.Add(&wire.DatagramFrame{Data: []byte{0}})) } errChan := make(chan error, 1) go func() { errChan <- queue.Add(&wire.DatagramFrame{Data: []byte("foobar")}) }() synctest.Wait() select { case <-errChan: t.Fatal("expected to not receive error") default: } // peeking doesn't remove the datagram from the queue... require.NotNil(t, queue.Peek()) synctest.Wait() select { case <-errChan: t.Fatal("expected to not receive error") default: } // ...but popping does queue.Pop() synctest.Wait() select { case err := <-errChan: require.NoError(t, err) default: t.Fatal("timeout") } // pop all the remaining datagrams for range maxDatagramSendQueueLen - 1 { queue.Pop() } f := queue.Peek() require.NotNil(t, f) require.Equal(t, &wire.DatagramFrame{Data: []byte("foobar")}, f) }) } func TestDatagramQueueReceive(t *testing.T) { queue := newDatagramQueue(func() {}, utils.DefaultLogger) // receive frames that were received earlier queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("foo")}) queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("bar")}) data, err := queue.Receive(context.Background()) require.NoError(t, err) require.Equal(t, []byte("foo"), data) data, err = queue.Receive(context.Background()) require.NoError(t, err) require.Equal(t, []byte("bar"), data) } func TestDatagramQueueReceiveBlocking(t *testing.T) { synctest.Test(t, func(t *testing.T) { queue := newDatagramQueue(func() {}, utils.DefaultLogger) // block until a new frame is received type result struct { data []byte err error } resultChan := make(chan result, 1) go func() { data, err 
:= queue.Receive(context.Background()) resultChan <- result{data, err} }() synctest.Wait() select { case <-resultChan: t.Fatal("expected to not receive result") default: } queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("foobar")}) synctest.Wait() select { case result := <-resultChan: require.NoError(t, result.err) require.Equal(t, []byte("foobar"), result.data) default: t.Fatal("should have received a datagram frame") } // unblock when the context is canceled ctx, cancel := context.WithCancel(context.Background()) errChan := make(chan error, 1) go func() { _, err := queue.Receive(ctx) errChan <- err }() synctest.Wait() select { case <-errChan: t.Fatal("expected to not receive error") default: } cancel() synctest.Wait() select { case err := <-errChan: require.ErrorIs(t, err, context.Canceled) default: t.Fatal("should have received a context canceled error") } }) } func TestDatagramQueueClose(t *testing.T) { synctest.Test(t, func(t *testing.T) { queue := newDatagramQueue(func() {}, utils.DefaultLogger) for range maxDatagramSendQueueLen { require.NoError(t, queue.Add(&wire.DatagramFrame{Data: []byte{0}})) } errChan1 := make(chan error, 1) go func() { errChan1 <- queue.Add(&wire.DatagramFrame{Data: []byte("foobar")}) }() errChan2 := make(chan error, 1) go func() { _, err := queue.Receive(context.Background()) errChan2 <- err }() queue.CloseWithError(assert.AnError) synctest.Wait() select { case err := <-errChan1: require.ErrorIs(t, err, assert.AnError) default: t.Fatal("should have received an error") } select { case err := <-errChan2: require.ErrorIs(t, err, assert.AnError) default: t.Fatal("should have received an error") } }) } quic-go-0.59.0/errors.go000066400000000000000000000105351513066070600150240ustar00rootroot00000000000000package quic import ( "fmt" "github.com/quic-go/quic-go/internal/qerr" ) type ( // TransportError indicates an error that occurred on the QUIC transport layer. 
// Every transport error other than CONNECTION_REFUSED and APPLICATION_ERROR is // likely a bug in the implementation. TransportError = qerr.TransportError // ApplicationError is an application-defined error. ApplicationError = qerr.ApplicationError // VersionNegotiationError indicates a failure to negotiate a QUIC version. VersionNegotiationError = qerr.VersionNegotiationError // StatelessResetError indicates a stateless reset was received. // This can happen when the peer reboots, or when packets are misrouted. // See section 10.3 of RFC 9000 for details. StatelessResetError = qerr.StatelessResetError // IdleTimeoutError indicates that the connection timed out because it was inactive for too long. IdleTimeoutError = qerr.IdleTimeoutError // HandshakeTimeoutError indicates that the connection timed out before completing the handshake. HandshakeTimeoutError = qerr.HandshakeTimeoutError ) type ( // TransportErrorCode is a QUIC transport error code, see section 20 of RFC 9000. TransportErrorCode = qerr.TransportErrorCode // ApplicationErrorCode is an QUIC application error code. ApplicationErrorCode = qerr.ApplicationErrorCode // StreamErrorCode is a QUIC stream error code. The meaning of the value is defined by the application. StreamErrorCode = qerr.StreamErrorCode ) const ( // NoError is the NO_ERROR transport error code. NoError = qerr.NoError // InternalError is the INTERNAL_ERROR transport error code. InternalError = qerr.InternalError // ConnectionRefused is the CONNECTION_REFUSED transport error code. ConnectionRefused = qerr.ConnectionRefused // FlowControlError is the FLOW_CONTROL_ERROR transport error code. FlowControlError = qerr.FlowControlError // StreamLimitError is the STREAM_LIMIT_ERROR transport error code. StreamLimitError = qerr.StreamLimitError // StreamStateError is the STREAM_STATE_ERROR transport error code. StreamStateError = qerr.StreamStateError // FinalSizeError is the FINAL_SIZE_ERROR transport error code. 
FinalSizeError = qerr.FinalSizeError // FrameEncodingError is the FRAME_ENCODING_ERROR transport error code. FrameEncodingError = qerr.FrameEncodingError // TransportParameterError is the TRANSPORT_PARAMETER_ERROR transport error code. TransportParameterError = qerr.TransportParameterError // ConnectionIDLimitError is the CONNECTION_ID_LIMIT_ERROR transport error code. ConnectionIDLimitError = qerr.ConnectionIDLimitError // ProtocolViolation is the PROTOCOL_VIOLATION transport error code. ProtocolViolation = qerr.ProtocolViolation // InvalidToken is the INVALID_TOKEN transport error code. InvalidToken = qerr.InvalidToken // ApplicationErrorErrorCode is the APPLICATION_ERROR transport error code. ApplicationErrorErrorCode = qerr.ApplicationErrorErrorCode // CryptoBufferExceeded is the CRYPTO_BUFFER_EXCEEDED transport error code. CryptoBufferExceeded = qerr.CryptoBufferExceeded // KeyUpdateError is the KEY_UPDATE_ERROR transport error code. KeyUpdateError = qerr.KeyUpdateError // AEADLimitReached is the AEAD_LIMIT_REACHED transport error code. AEADLimitReached = qerr.AEADLimitReached // NoViablePathError is the NO_VIABLE_PATH_ERROR transport error code. NoViablePathError = qerr.NoViablePathError ) // A StreamError is used to signal stream cancellations. // It is returned from the Read and Write methods of the [ReceiveStream], [SendStream] and [Stream]. type StreamError struct { StreamID StreamID ErrorCode StreamErrorCode Remote bool } func (e *StreamError) Is(target error) bool { t, ok := target.(*StreamError) return ok && e.StreamID == t.StreamID && e.ErrorCode == t.ErrorCode && e.Remote == t.Remote } func (e *StreamError) Error() string { pers := "local" if e.Remote { pers = "remote" } return fmt.Sprintf("stream %d canceled by %s with error code %d", e.StreamID, pers, e.ErrorCode) } // DatagramTooLargeError is returned from Conn.SendDatagram if the payload is too large to be sent. 
type DatagramTooLargeError struct { MaxDatagramPayloadSize int64 } func (e *DatagramTooLargeError) Is(target error) bool { t, ok := target.(*DatagramTooLargeError) return ok && e.MaxDatagramPayloadSize == t.MaxDatagramPayloadSize } func (e *DatagramTooLargeError) Error() string { return "DATAGRAM frame too large" } quic-go-0.59.0/errors_test.go000066400000000000000000000022251513066070600160600ustar00rootroot00000000000000package quic import ( "errors" "testing" "github.com/stretchr/testify/require" ) func TestStreamError(t *testing.T) { require.True(t, errors.Is( &StreamError{StreamID: 1, ErrorCode: 2, Remote: true}, &StreamError{StreamID: 1, ErrorCode: 2, Remote: true}, )) require.False(t, errors.Is(&StreamError{StreamID: 1}, &StreamError{StreamID: 2})) require.False(t, errors.Is(&StreamError{StreamID: 1}, &StreamError{StreamID: 2})) require.Equal(t, "stream 1 canceled by remote with error code 2", (&StreamError{StreamID: 1, ErrorCode: 2, Remote: true}).Error(), ) require.Equal(t, "stream 42 canceled by local with error code 1337", (&StreamError{StreamID: 42, ErrorCode: 1337, Remote: false}).Error(), ) } func TestDatagramTooLargeError(t *testing.T) { require.True(t, errors.Is( &DatagramTooLargeError{MaxDatagramPayloadSize: 1024}, &DatagramTooLargeError{MaxDatagramPayloadSize: 1024}, )) require.False(t, errors.Is( &DatagramTooLargeError{MaxDatagramPayloadSize: 1024}, &DatagramTooLargeError{MaxDatagramPayloadSize: 1025}, )) require.Equal(t, "DATAGRAM frame too large", (&DatagramTooLargeError{MaxDatagramPayloadSize: 1024}).Error()) } quic-go-0.59.0/example/000077500000000000000000000000001513066070600146105ustar00rootroot00000000000000quic-go-0.59.0/example/client/000077500000000000000000000000001513066070600160665ustar00rootroot00000000000000quic-go-0.59.0/example/client/main.go000066400000000000000000000031641513066070600173450ustar00rootroot00000000000000package main import ( "bytes" "crypto/tls" "crypto/x509" "flag" "io" "log" "net/http" "os" "sync" 
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" "github.com/quic-go/quic-go/http3/qlog" "github.com/quic-go/quic-go/internal/testdata" ) func main() { quiet := flag.Bool("q", false, "don't print the data") keyLogFile := flag.String("keylog", "", "key log file") insecure := flag.Bool("insecure", false, "skip certificate verification") flag.Parse() urls := flag.Args() var keyLog io.Writer if len(*keyLogFile) > 0 { f, err := os.Create(*keyLogFile) if err != nil { log.Fatal(err) } defer f.Close() keyLog = f } pool, err := x509.SystemCertPool() if err != nil { log.Fatal(err) } testdata.AddRootCA(pool) roundTripper := &http3.Transport{ TLSClientConfig: &tls.Config{ RootCAs: pool, InsecureSkipVerify: *insecure, KeyLogWriter: keyLog, }, QUICConfig: &quic.Config{ Tracer: qlog.DefaultConnectionTracer, }, } defer roundTripper.Close() hclient := &http.Client{ Transport: roundTripper, } var wg sync.WaitGroup wg.Add(len(urls)) for _, addr := range urls { log.Printf("GET %s", addr) go func(addr string) { rsp, err := hclient.Get(addr) if err != nil { log.Fatal(err) } log.Printf("Got response for %s: %#v", addr, rsp) body := &bytes.Buffer{} _, err = io.Copy(body, rsp.Body) if err != nil { log.Fatal(err) } if *quiet { log.Printf("Response Body: %d bytes", body.Len()) } else { log.Printf("Response Body (%d bytes):\n%s", body.Len(), body.Bytes()) } wg.Done() }(addr) } wg.Wait() } quic-go-0.59.0/example/echo/000077500000000000000000000000001513066070600155265ustar00rootroot00000000000000quic-go-0.59.0/example/echo/echo.go000066400000000000000000000045711513066070600170020ustar00rootroot00000000000000package main import ( "context" "crypto/ed25519" "crypto/rand" "crypto/tls" "crypto/x509" "fmt" "io" "log" "math/big" "github.com/quic-go/quic-go" ) const addr = "localhost:4242" const message = "foobar" // We start a server echoing data on the first stream the client opens, // then connect with a client, send the message, and wait for its receipt. 
func main() { go func() { log.Fatal(echoServer()) }() if err := clientMain(); err != nil { panic(err) } } // Start a server that echos all data on the first stream opened by the client func echoServer() error { listener, err := quic.ListenAddr(addr, generateTLSConfig(), nil) if err != nil { return err } defer listener.Close() conn, err := listener.Accept(context.Background()) if err != nil { return err } stream, err := conn.AcceptStream(context.Background()) if err != nil { panic(err) } defer stream.Close() // Echo through the loggingWriter _, err = io.Copy(loggingWriter{stream}, stream) return err } func clientMain() error { tlsConf := &tls.Config{ InsecureSkipVerify: true, NextProtos: []string{"quic-echo-example"}, } conn, err := quic.DialAddr(context.Background(), addr, tlsConf, nil) if err != nil { return err } defer conn.CloseWithError(0, "") stream, err := conn.OpenStreamSync(context.Background()) if err != nil { return err } defer stream.Close() fmt.Printf("Client: Sending '%s'\n", message) if _, err := stream.Write([]byte(message)); err != nil { return err } buf := make([]byte, len(message)) if _, err := io.ReadFull(stream, buf); err != nil { return err } fmt.Printf("Client: Got '%s'\n", buf) return nil } // A wrapper for io.Writer that also logs the message. 
type loggingWriter struct{ io.Writer } func (w loggingWriter) Write(b []byte) (int, error) { fmt.Printf("Server: Got '%s'\n", string(b)) return w.Writer.Write(b) } // Setup a bare-bones TLS config for the server func generateTLSConfig() *tls.Config { _, priv, err := ed25519.GenerateKey(rand.Reader) if err != nil { panic(err) } template := x509.Certificate{SerialNumber: big.NewInt(1)} certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, priv.Public(), priv) if err != nil { panic(err) } return &tls.Config{ Certificates: []tls.Certificate{{ Certificate: [][]byte{certDER}, PrivateKey: priv, }}, NextProtos: []string{"quic-echo-example"}, } } quic-go-0.59.0/example/main.go000066400000000000000000000111711513066070600160640ustar00rootroot00000000000000package main import ( "crypto/md5" "errors" "flag" "fmt" "io" "log" "mime/multipart" "net/http" "strconv" "strings" "sync" _ "net/http/pprof" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" "github.com/quic-go/quic-go/http3/qlog" "github.com/quic-go/quic-go/internal/testdata" ) type binds []string func (b binds) String() string { return strings.Join(b, ",") } func (b *binds) Set(v string) error { *b = strings.Split(v, ",") return nil } // Size is needed by the /demo/upload handler to determine the size of the uploaded file type Size interface { Size() int64 } // See https://en.wikipedia.org/wiki/Lehmer_random_number_generator func generatePRData(l int) []byte { res := make([]byte, l) seed := uint64(1) for i := 0; i < l; i++ { seed = seed * 48271 % 2147483647 res[i] = byte(seed) } return res } func setupHandler(www string) http.Handler { mux := http.NewServeMux() if len(www) > 0 { mux.Handle("/", http.FileServer(http.Dir(www))) } else { mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { fmt.Printf("%#v\n", r) const maxSize = 1 << 30 // 1 GB num, err := strconv.ParseInt(strings.ReplaceAll(r.RequestURI, "/", ""), 10, 64) if err != nil || num <= 0 || num > maxSize { 
w.WriteHeader(400) return } w.Write(generatePRData(int(num))) }) } mux.HandleFunc("/demo/tile", func(w http.ResponseWriter, r *http.Request) { // Small 40x40 png w.Write([]byte{ 0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, 0x00, 0x00, 0x00, 0x0d, 0x49, 0x48, 0x44, 0x52, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x28, 0x01, 0x03, 0x00, 0x00, 0x00, 0xb6, 0x30, 0x2a, 0x2e, 0x00, 0x00, 0x00, 0x03, 0x50, 0x4c, 0x54, 0x45, 0x5a, 0xc3, 0x5a, 0xad, 0x38, 0xaa, 0xdb, 0x00, 0x00, 0x00, 0x0b, 0x49, 0x44, 0x41, 0x54, 0x78, 0x01, 0x63, 0x18, 0x61, 0x00, 0x00, 0x00, 0xf0, 0x00, 0x01, 0xe2, 0xb8, 0x75, 0x22, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4e, 0x44, 0xae, 0x42, 0x60, 0x82, }) }) mux.HandleFunc("/demo/tiles", func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, "") for i := 0; i < 200; i++ { fmt.Fprintf(w, ``, i) } io.WriteString(w, "") }) mux.HandleFunc("/demo/echo", func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) if err != nil { fmt.Printf("error reading body while handling /echo: %s\n", err.Error()) } w.Write(body) }) // accept file uploads and return the MD5 of the uploaded file // maximum accepted file size is 1 GB mux.HandleFunc("/demo/upload", func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { err := r.ParseMultipartForm(1 << 30) // 1 GB if err == nil { var file multipart.File file, _, err = r.FormFile("uploadfile") if err == nil { var size int64 if sizeInterface, ok := file.(Size); ok { size = sizeInterface.Size() b := make([]byte, size) file.Read(b) md5 := md5.Sum(b) fmt.Fprintf(w, "%x", md5) return } err = errors.New("couldn't get uploaded file size") } } log.Printf("Error receiving upload: %#v", err) } io.WriteString(w, `

`) }) return mux } func main() { // defer profile.Start().Stop() go func() { log.Println(http.ListenAndServe("localhost:6060", nil)) }() // runtime.SetBlockProfileRate(1) bs := binds{} flag.Var(&bs, "bind", "bind to") www := flag.String("www", "", "www data") tcp := flag.Bool("tcp", false, "also listen on TCP") key := flag.String("key", "", "TLS key (requires -cert option)") cert := flag.String("cert", "", "TLS certificate (requires -key option)") flag.Parse() if len(bs) == 0 { bs = binds{"localhost:6121"} } handler := setupHandler(*www) var wg sync.WaitGroup wg.Add(len(bs)) var certFile, keyFile string if *key != "" && *cert != "" { keyFile = *key certFile = *cert } else { certFile, keyFile = testdata.GetCertificatePaths() } for _, b := range bs { fmt.Println("listening on", b) bCap := b go func() { var err error if *tcp { err = http3.ListenAndServeTLS(bCap, certFile, keyFile, handler) } else { server := http3.Server{ Handler: handler, Addr: bCap, QUICConfig: &quic.Config{ Tracer: qlog.DefaultConnectionTracer, }, } err = server.ListenAndServeTLS(certFile, keyFile) } if err != nil { fmt.Println(err) } wg.Done() }() } wg.Wait() } quic-go-0.59.0/frame_sorter.go000066400000000000000000000160301513066070600161740ustar00rootroot00000000000000package quic import ( "errors" "sync" "github.com/quic-go/quic-go/internal/protocol" list "github.com/quic-go/quic-go/internal/utils/linkedlist" ) // byteInterval is an interval from one ByteCount to the other type byteInterval struct { Start protocol.ByteCount End protocol.ByteCount } var byteIntervalElementPool sync.Pool func init() { byteIntervalElementPool = *list.NewPool[byteInterval]() } type frameSorterEntry struct { Data []byte DoneCb func() } type frameSorter struct { queue map[protocol.ByteCount]frameSorterEntry readPos protocol.ByteCount gaps *list.List[byteInterval] } var errDuplicateStreamData = errors.New("duplicate stream data") func newFrameSorter() *frameSorter { s := frameSorter{ gaps: 
list.NewWithPool[byteInterval](&byteIntervalElementPool), queue: make(map[protocol.ByteCount]frameSorterEntry), } s.gaps.PushFront(byteInterval{Start: 0, End: protocol.MaxByteCount}) return &s } func (s *frameSorter) Push(data []byte, offset protocol.ByteCount, doneCb func()) error { err := s.push(data, offset, doneCb) if err == errDuplicateStreamData { if doneCb != nil { doneCb() } return nil } return err } func (s *frameSorter) push(data []byte, offset protocol.ByteCount, doneCb func()) error { if len(data) == 0 { return errDuplicateStreamData } start := offset end := offset + protocol.ByteCount(len(data)) if end <= s.gaps.Front().Value.Start { return errDuplicateStreamData } startGap, startsInGap := s.findStartGap(start) endGap, endsInGap := s.findEndGap(startGap, end) startGapEqualsEndGap := startGap == endGap if (startGapEqualsEndGap && end <= startGap.Value.Start) || (!startGapEqualsEndGap && startGap.Value.End >= endGap.Value.Start && end <= startGap.Value.Start) { return errDuplicateStreamData } startGapNext := startGap.Next() startGapEnd := startGap.Value.End // save it, in case startGap is modified endGapStart := endGap.Value.Start // save it, in case endGap is modified endGapEnd := endGap.Value.End // save it, in case endGap is modified var adjustedStartGapEnd bool var wasCut bool pos := start var hasReplacedAtLeastOne bool for { oldEntry, ok := s.queue[pos] if !ok { break } oldEntryLen := protocol.ByteCount(len(oldEntry.Data)) if end-pos > oldEntryLen || (hasReplacedAtLeastOne && end-pos == oldEntryLen) { // The existing frame is shorter than the new frame. Replace it. delete(s.queue, pos) pos += oldEntryLen hasReplacedAtLeastOne = true if oldEntry.DoneCb != nil { oldEntry.DoneCb() } } else { if !hasReplacedAtLeastOne { return errDuplicateStreamData } // The existing frame is longer than the new frame. // Cut the new frame such that the end aligns with the start of the existing frame. 
data = data[:pos-start] end = pos wasCut = true break } } if !startsInGap && !hasReplacedAtLeastOne { // cut the frame, such that it starts at the start of the gap data = data[startGap.Value.Start-start:] start = startGap.Value.Start wasCut = true } if start <= startGap.Value.Start { if end >= startGap.Value.End { // The frame covers the whole startGap. Delete the gap. s.gaps.Remove(startGap) } else { startGap.Value.Start = end } } else if !hasReplacedAtLeastOne { startGap.Value.End = start adjustedStartGapEnd = true } if !startGapEqualsEndGap { s.deleteConsecutive(startGapEnd) var nextGap *list.Element[byteInterval] for gap := startGapNext; gap.Value.End < endGapStart; gap = nextGap { nextGap = gap.Next() s.deleteConsecutive(gap.Value.End) s.gaps.Remove(gap) } } if !endsInGap && start != endGapEnd && end > endGapEnd { // cut the frame, such that it ends at the end of the gap data = data[:endGapEnd-start] end = endGapEnd wasCut = true } if end == endGapEnd { if !startGapEqualsEndGap { // The frame covers the whole endGap. Delete the gap. s.gaps.Remove(endGap) } } else { if startGapEqualsEndGap && adjustedStartGapEnd { // The frame split the existing gap into two. 
s.gaps.InsertAfter(byteInterval{Start: end, End: startGapEnd}, startGap) } else if !startGapEqualsEndGap { endGap.Value.Start = end } } if wasCut && len(data) < protocol.MinStreamFrameBufferSize { newData := make([]byte, len(data)) copy(newData, data) data = newData if doneCb != nil { doneCb() doneCb = nil } } if s.gaps.Len() > protocol.MaxStreamFrameSorterGaps { return errors.New("too many gaps in received data") } s.queue[start] = frameSorterEntry{Data: data, DoneCb: doneCb} return nil } func (s *frameSorter) findStartGap(offset protocol.ByteCount) (*list.Element[byteInterval], bool) { for gap := s.gaps.Front(); gap != nil; gap = gap.Next() { if offset >= gap.Value.Start && offset <= gap.Value.End { return gap, true } if offset < gap.Value.Start { return gap, false } } panic("no gap found") } func (s *frameSorter) findEndGap(startGap *list.Element[byteInterval], offset protocol.ByteCount) (*list.Element[byteInterval], bool) { for gap := startGap; gap != nil; gap = gap.Next() { if offset >= gap.Value.Start && offset < gap.Value.End { return gap, true } if offset < gap.Value.Start { return gap.Prev(), false } } panic("no gap found") } // deleteConsecutive deletes consecutive frames from the queue, starting at pos func (s *frameSorter) deleteConsecutive(pos protocol.ByteCount) { for { oldEntry, ok := s.queue[pos] if !ok { break } oldEntryLen := protocol.ByteCount(len(oldEntry.Data)) delete(s.queue, pos) if oldEntry.DoneCb != nil { oldEntry.DoneCb() } pos += oldEntryLen } } func (s *frameSorter) Pop() (protocol.ByteCount, []byte, func()) { entry, ok := s.queue[s.readPos] if !ok { return s.readPos, nil, nil } delete(s.queue, s.readPos) offset := s.readPos s.readPos += protocol.ByteCount(len(entry.Data)) if s.gaps.Front().Value.End <= s.readPos { panic("frame sorter BUG: read position higher than a gap") } return offset, entry.Data, entry.DoneCb } // HasMoreData says if there is any more data queued at *any* offset. 
func (s *frameSorter) HasMoreData() bool { return len(s.queue) > 0 } var errTooLittleData = errors.New("too little data") // Peek copies len(p) consecutive bytes starting at offset into p, without removing them. // It is only possible to peek from an offset where a frame starts. // // If there isn't enough consecutive data available, errTooLittleData is returned. func (s *frameSorter) Peek(offset protocol.ByteCount, p []byte) error { if len(p) == 0 { return nil } // first, check if we have enough consecutive data available pos := offset remaining := len(p) for remaining > 0 { entry, ok := s.queue[pos] if !ok { return errTooLittleData } entryLen := len(entry.Data) if remaining <= entryLen { break // enough data available } remaining -= entryLen pos += protocol.ByteCount(entryLen) } pos = offset var copied int for copied < len(p) { entry := s.queue[pos] // the entry is guaranteed to exist from the check above copied += copy(p[copied:], entry.Data) pos += protocol.ByteCount(len(entry.Data)) } return nil } quic-go-0.59.0/frame_sorter_test.go000066400000000000000000001236211513066070600172400ustar00rootroot00000000000000package quic import ( rand "crypto/rand" "fmt" "math" mrand "math/rand/v2" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) type callbackTracker struct { called *bool cb func() } func (t *callbackTracker) WasCalled() bool { return *t.called } func getFrameSorterTestCallback(t *testing.T) (func(), callbackTracker) { var called bool cb := func() { if called { t.Fatal("double free") } called = true } return cb, callbackTracker{ cb: cb, called: &called, } } func TestFrameSorterSimpleCases(t *testing.T) { s := newFrameSorter() _, data, doneCb := s.Pop() require.Nil(t, data) require.Nil(t, doneCb) // empty frames are ignored require.NoError(t, s.Push(nil, 0, nil)) _, data, doneCb = s.Pop() require.Nil(t, data) require.Nil(t, doneCb) cb1, t1 := getFrameSorterTestCallback(t) cb2, t2 := getFrameSorterTestCallback(t) 
require.NoError(t, s.Push([]byte("bar"), 3, cb2)) require.True(t, s.HasMoreData()) require.NoError(t, s.Push([]byte("foo"), 0, cb1)) offset, data, doneCb := s.Pop() require.Equal(t, []byte("foo"), data) require.Zero(t, offset) require.NotNil(t, doneCb) doneCb() require.True(t, t1.WasCalled()) require.False(t, t2.WasCalled()) require.True(t, s.HasMoreData()) offset, data, doneCb = s.Pop() require.Equal(t, []byte("bar"), data) require.Equal(t, protocol.ByteCount(3), offset) require.NotNil(t, doneCb) doneCb() require.True(t, t2.WasCalled()) require.False(t, s.HasMoreData()) // now receive a duplicate cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push([]byte("foo"), 0, cb3)) require.False(t, s.HasMoreData()) require.True(t, t3.WasCalled()) // now receive a later frame that overlaps with the ones we already consumed cb4, _ := getFrameSorterTestCallback(t) require.NoError(t, s.Push([]byte("barbaz"), 3, cb4)) require.True(t, s.HasMoreData()) offset, data, _ = s.Pop() require.Equal(t, protocol.ByteCount(6), offset) require.Equal(t, []byte("baz"), data) require.False(t, s.HasMoreData()) } // Usually, it's not a good idea to test the implementation details. // However, we need to make sure that the frame sorter handles gaps correctly, // in particular when overlapping stream data is received. // This also includes returning buffers that are no longer needed. 
func TestFrameSorterGapHandling(t *testing.T) { random := mrand.NewChaCha8([32]byte{'f', 'o', 'o', 'b', 'a', 'r'}) getData := func(l protocol.ByteCount) []byte { b := make([]byte, l) random.Read(b) return b } checkQueue := func(t *testing.T, s *frameSorter, m map[protocol.ByteCount][]byte) { require.Equal(t, len(m), len(s.queue)) for offset, data := range m { require.Contains(t, s.queue, offset) require.Equal(t, data, s.queue[offset].Data) } } checkGaps := func(t *testing.T, s *frameSorter, expectedGaps []byteInterval) { if s.gaps.Len() != len(expectedGaps) { fmt.Println("Gaps:") for gap := s.gaps.Front(); gap != nil; gap = gap.Next() { fmt.Printf("\t%d - %d\n", gap.Value.Start, gap.Value.End) } require.Equal(t, len(expectedGaps), s.gaps.Len()) } var i int for gap := s.gaps.Front(); gap != nil; gap = gap.Next() { require.Equal(t, expectedGaps[i], gap.Value) i++ } } // ---xxx-------------- // ++++++ // => // ---xxx++++++-------- t.Run("case 1", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(5) cb2, t2 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 6, cb2)) // 6 - 11 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f1, 6: f2, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 11, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.False(t, t2.WasCalled()) }) // ---xxx----------------- // +++++++ // => // ---xxx---+++++++-------- t.Run("case 2", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(5) cb2, t2 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 10, cb2)) // 10 -15 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f1, 10: f2, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 6, End: 10}, {Start: 15, End: protocol.MaxByteCount}, }) require.False(t, 
t1.WasCalled()) require.False(t, t2.WasCalled()) }) // ---xxx----xxxxxx------- // ++++ // => // ---xxx++++xxxxx-------- t.Run("case 3", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(4) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(5) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 6 require.NoError(t, s.Push(f3, 10, cb2)) // 10 - 15 require.NoError(t, s.Push(f2, 6, cb3)) // 6 - 10 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f1, 6: f2, 10: f3, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 15, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.False(t, t2.WasCalled()) require.False(t, t3.WasCalled()) }) // ----xxxx------- // ++++ // => // ----xxxx++----- t.Run("case 4", func(t *testing.T) { s := newFrameSorter() f1 := getData(4) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(4) cb2, t2 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 7 require.NoError(t, s.Push(f2, 5, cb2)) // 5 - 9 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f1, 7: f2[2:], }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 9, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.True(t, t2.WasCalled()) }) t.Run("case 4, for long frames", func(t *testing.T) { s := newFrameSorter() mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 2)) f1 := getData(4 * mult) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(4 * mult) cb2, t2 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3*mult, cb1)) // 3 - 7 require.NoError(t, s.Push(f2, 5*mult, cb2)) // 5 - 9 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3 * mult: f1, 7 * mult: f2[2*mult:], }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3 * mult}, {Start: 9 * mult, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.False(t, t2.WasCalled()) }) // xxxx------- 
// ++++ // => // xxxx+++----- t.Run("case 5", func(t *testing.T) { s := newFrameSorter() f1 := getData(4) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(4) cb2, t2 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 0, cb1)) // 0 - 4 require.NoError(t, s.Push(f2, 3, cb2)) // 3 - 7 checkQueue(t, s, map[protocol.ByteCount][]byte{ 0: f1, 4: f2[1:], }) checkGaps(t, s, []byteInterval{ {Start: 7, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.True(t, t2.WasCalled()) }) t.Run("case 5, for long frames", func(t *testing.T) { s := newFrameSorter() mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 2)) f1 := getData(4 * mult) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(4 * mult) cb2, t2 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 0, cb1)) // 0 - 4 require.NoError(t, s.Push(f2, 3*mult, cb2)) // 3 - 7 checkQueue(t, s, map[protocol.ByteCount][]byte{ 0: f1, 4 * mult: f2[mult:], }) checkGaps(t, s, []byteInterval{ {Start: 7 * mult, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.False(t, t2.WasCalled()) }) // ----xxxx------- // ++++ // => // --++xxxx------- t.Run("case 6", func(t *testing.T) { s := newFrameSorter() f1 := getData(4) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(4) cb2, t2 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 5, cb1)) // 5 - 9 require.NoError(t, s.Push(f2, 3, cb2)) // 3 - 7 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f2[:2], 5: f1, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 9, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.True(t, t2.WasCalled()) }) t.Run("case 6, for long frames", func(t *testing.T) { s := newFrameSorter() mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 2)) f1 := getData(4 * mult) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(4 * mult) cb2, t2 := getFrameSorterTestCallback(t) require.NoError(t, 
s.Push(f1, 5*mult, cb1)) // 5 - 9 require.NoError(t, s.Push(f2, 3*mult, cb2)) // 3 - 7 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3 * mult: f2[:2*mult], 5 * mult: f1, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3 * mult}, {Start: 9 * mult, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.False(t, t2.WasCalled()) }) // ---xxx----xxxxxx------- // ++ // => // ---xxx++--xxxxx-------- t.Run("case 7", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(2) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(5) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 6 require.NoError(t, s.Push(f3, 10, cb2)) // 10 - 15 require.NoError(t, s.Push(f2, 6, cb3)) // 6 - 8 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f1, 6: f2, 10: f3, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 8, End: 10}, {Start: 15, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.False(t, t2.WasCalled()) require.False(t, t3.WasCalled()) }) // ---xxx---------xxxxxx-- // ++ // => // ---xxx---++----xxxxx-- t.Run("case 8", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(2) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(5) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 6 require.NoError(t, s.Push(f3, 15, cb2)) // 15 - 20 require.NoError(t, s.Push(f2, 10, cb3)) // 10 - 12 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f1, 10: f2, 15: f3, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 6, End: 10}, {Start: 12, End: 15}, {Start: 20, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.False(t, t2.WasCalled()) require.False(t, t3.WasCalled()) }) // ---xxx----xxxxxx------- // ++ // => // ---xxx--++xxxxx-------- t.Run("case 9", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) 
cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(2) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(5) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 6 require.NoError(t, s.Push(f3, 10, cb2)) // 10 - 15 require.NoError(t, s.Push(f2, 8, cb3)) // 8 - 10 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f1, 8: f2, 10: f3, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 6, End: 8}, {Start: 15, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.False(t, t2.WasCalled()) require.False(t, t3.WasCalled()) }) // ---xxx----=====------- // +++++++ // => // ---xxx++++=====-------- t.Run("case 10", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(5) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(6) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 10, cb2)) // 10 - 15 require.NoError(t, s.Push(f3, 5, cb3)) // 5 - 11 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f1, 6: f3[1:5], 10: f2, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 15, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.False(t, t2.WasCalled()) require.True(t, t3.WasCalled()) }) t.Run("case 10, for long frames", func(t *testing.T) { s := newFrameSorter() mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 4)) f1 := getData(3 * mult) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(5 * mult) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(6 * mult) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3*mult, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 10*mult, cb2)) // 10 - 15 require.NoError(t, s.Push(f3, 5*mult, cb3)) // 5 - 11 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3 * mult: f1, 6 * mult: f3[mult : 5*mult], 10 * mult: f2, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3 * 
mult}, {Start: 15 * mult, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.False(t, t2.WasCalled()) require.False(t, t3.WasCalled()) }) // ---xxxx----=====------- // ++++++ // => // ---xxx++++=====-------- t.Run("case 11", func(t *testing.T) { s := newFrameSorter() f1 := getData(4) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(5) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(5) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 7 require.NoError(t, s.Push(f2, 10, cb2)) // 10 - 15 require.NoError(t, s.Push(f3, 5, cb3)) // 5 - 10 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f1, 7: f3[2:], 10: f2, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 15, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.False(t, t2.WasCalled()) require.True(t, t3.WasCalled()) }) // ---xxxx----=====------- // ++++++ // => // ---xxx++++=====-------- t.Run("case 11, for long frames", func(t *testing.T) { s := newFrameSorter() mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 3)) f1 := getData(4 * mult) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(5 * mult) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(5 * mult) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3*mult, cb1)) // 3 - 7 require.NoError(t, s.Push(f2, 10*mult, cb2)) // 10 - 15 require.NoError(t, s.Push(f3, 5*mult, cb3)) // 5 - 10 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3 * mult: f1, 7 * mult: f3[2*mult:], 10 * mult: f2, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3 * mult}, {Start: 15 * mult, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.False(t, t2.WasCalled()) require.False(t, t3.WasCalled()) }) // ----xxxx------- // +++++++ // => // ----+++++++----- t.Run("case 12", func(t *testing.T) { s := newFrameSorter() f1 := getData(4) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(7) cb2, t2 := 
getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 7 require.NoError(t, s.Push(f2, 3, cb2)) // 3 - 10 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f2, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 10, End: protocol.MaxByteCount}, }) require.True(t, t1.WasCalled()) require.False(t, t2.WasCalled()) }) // ----xxx===------- // +++++++ // => // ----+++++++----- t.Run("case 13", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(3) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(7) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 6, cb2)) // 6 - 9 require.NoError(t, s.Push(f3, 3, cb3)) // 3 - 10 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f3, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 10, End: protocol.MaxByteCount}, }) require.True(t, t1.WasCalled()) require.True(t, t2.WasCalled()) require.False(t, t3.WasCalled()) }) // ----xxx====------- // +++++ // => // ----+++====----- t.Run("case 14", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(4) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(5) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 6, cb2)) // 6 - 10 require.NoError(t, s.Push(f3, 3, cb3)) // 3 - 8 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f3[:3], 6: f2, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 10, End: protocol.MaxByteCount}, }) require.True(t, t1.WasCalled()) require.False(t, t2.WasCalled()) require.True(t, t3.WasCalled()) }) t.Run("case 14, for long frames", func(t *testing.T) { s := newFrameSorter() mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 3)) f1 := getData(3 * mult) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(4 * mult) cb2, t2 := 
getFrameSorterTestCallback(t) f3 := getData(5 * mult) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3*mult, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 6*mult, cb2)) // 6 - 10 require.NoError(t, s.Push(f3, 3*mult, cb3)) // 3 - 8 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3 * mult: f3[:3*mult], 6 * mult: f2, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3 * mult}, {Start: 10 * mult, End: protocol.MaxByteCount}, }) require.True(t, t1.WasCalled()) require.False(t, t2.WasCalled()) require.False(t, t3.WasCalled()) }) // ----xxx===------- // ++++++ // => // ----++++++----- t.Run("case 15", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(3) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(6) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 6, cb2)) // 6 - 9 require.NoError(t, s.Push(f3, 3, cb3)) // 3 - 9 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f3, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 9, End: protocol.MaxByteCount}, }) require.True(t, t1.WasCalled()) require.True(t, t2.WasCalled()) require.False(t, t3.WasCalled()) }) // ---xxxx------- // ++++ // => // ---xxxx----- t.Run("case 16", func(t *testing.T) { s := newFrameSorter() f1 := getData(4) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(4) cb2, t2 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 5, cb1)) // 5 - 9 require.NoError(t, s.Push(f2, 5, cb2)) // 5 - 9 checkQueue(t, s, map[protocol.ByteCount][]byte{ 5: f1, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 5}, {Start: 9, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.True(t, t2.WasCalled()) }) // ----xxx===------- // +++ // => // ----xxx===----- t.Run("case 17", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(3) cb2, t2 := 
getFrameSorterTestCallback(t) f3 := getData(3) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 6, cb2)) // 6 - 9 require.NoError(t, s.Push(f3, 3, cb3)) // 3 - 6 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f1, 6: f2, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 9, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.False(t, t2.WasCalled()) require.True(t, t3.WasCalled()) }) // ---xxxx------- // ++ // => // ---xxxx----- t.Run("case 18", func(t *testing.T) { s := newFrameSorter() f1 := getData(4) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(2) cb2, t2 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 5, cb1)) // 5 - 9 require.NoError(t, s.Push(f2, 5, cb2)) // 5 - 7 checkQueue(t, s, map[protocol.ByteCount][]byte{ 5: f1, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 5}, {Start: 9, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.True(t, t2.WasCalled()) }) // ---xxxxx------ // ++ // => // ---xxxxx---- t.Run("case 19", func(t *testing.T) { s := newFrameSorter() f1 := getData(5) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(2) cb2, t2 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 5, cb1)) // 5 - 10 checkQueue(t, s, map[protocol.ByteCount][]byte{ 5: f1, }) require.NoError(t, s.Push(f2, 6, cb2)) // 6 - 8 checkQueue(t, s, map[protocol.ByteCount][]byte{ 5: f1, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 5}, {Start: 10, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.True(t, t2.WasCalled()) }) // xxxxx------ // ++ // => // xxxxx------ t.Run("case 20", func(t *testing.T) { s := newFrameSorter() f1 := getData(10) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(4) cb2, t2 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 0, cb1)) // 0 - 10 require.NoError(t, s.Push(f2, 5, cb2)) // 5 - 9 checkQueue(t, s, map[protocol.ByteCount][]byte{ 0: 
f1, }) checkGaps(t, s, []byteInterval{ {Start: 10, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.True(t, t2.WasCalled()) }) // ---xxxxx--- // +++ // => // ---xxxxx--- t.Run("case 21", func(t *testing.T) { s := newFrameSorter() f1 := getData(5) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(3) cb2, t2 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 5, cb1)) // 5 - 10 require.NoError(t, s.Push(f2, 7, cb2)) // 7 - 10 checkGaps(t, s, []byteInterval{ {Start: 0, End: 5}, {Start: 10, End: protocol.MaxByteCount}, }) checkQueue(t, s, map[protocol.ByteCount][]byte{ 5: f1, }) require.False(t, t1.WasCalled()) require.True(t, t2.WasCalled()) }) // ----xxx------ // +++++ // => // --+++++---- t.Run("case 22", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(5) cb2, t2 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 5, cb1)) // 5 - 8 require.NoError(t, s.Push(f2, 3, cb2)) // 3 - 8 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f2, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 8, End: protocol.MaxByteCount}, }) require.True(t, t1.WasCalled()) require.False(t, t2.WasCalled()) }) // ----xxx===------ // ++++++++ // => // --++++++++---- t.Run("case 23", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(3) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(8) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 5, cb1)) // 5 - 8 require.NoError(t, s.Push(f2, 8, cb2)) // 8 - 11 require.NoError(t, s.Push(f3, 3, cb3)) // 3 - 11 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f3, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 11, End: protocol.MaxByteCount}, }) require.True(t, t1.WasCalled()) require.True(t, t2.WasCalled()) require.False(t, t3.WasCalled()) }) // --xxx---===--- // ++++++ // => // --xxx++++++---- t.Run("case 24", func(t 
*testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(3) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(6) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 9, cb2)) // 9 - 12 require.NoError(t, s.Push(f3, 6, cb3)) // 6 - 12 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f1, 6: f3, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 12, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.True(t, t2.WasCalled()) require.False(t, t3.WasCalled()) }) // --xxx---===---### // +++++++++ // => // --xxx+++++++++### t.Run("case 25", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(3) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(3) cb3, t3 := getFrameSorterTestCallback(t) f4 := getData(9) cb4, t4 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 9, cb2)) // 9 - 12 require.NoError(t, s.Push(f3, 15, cb3)) // 15 - 18 require.NoError(t, s.Push(f4, 6, cb4)) // 6 - 15 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f1, 6: f4, 15: f3, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 18, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.True(t, t2.WasCalled()) require.False(t, t3.WasCalled()) require.False(t, t4.WasCalled()) }) // ----xxx------ // +++++++ // => // --+++++++--- t.Run("case 26", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(10) cb2, t2 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 5, cb1)) // 5 - 8 require.NoError(t, s.Push(f2, 3, cb2)) // 3 - 13 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f2, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 13, End: protocol.MaxByteCount}, }) require.True(t, t1.WasCalled()) require.False(t, 
t2.WasCalled()) }) // ---xxx====--- // ++++ // => // --+xxx====--- t.Run("case 27", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(4) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(4) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 6, cb2)) // 6 - 10 require.NoError(t, s.Push(f3, 2, cb3)) // 2 - 6 checkQueue(t, s, map[protocol.ByteCount][]byte{ 2: f3[:1], 3: f1, 6: f2, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 2}, {Start: 10, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.False(t, t2.WasCalled()) require.True(t, t3.WasCalled()) }) t.Run("case 27, for long frames", func(t *testing.T) { s := newFrameSorter() const mult = protocol.MinStreamFrameSize f1 := getData(3 * mult) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(4 * mult) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(4 * mult) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3*mult, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 6*mult, cb2)) // 6 - 10 require.NoError(t, s.Push(f3, 2*mult, cb3)) // 2 - 6 checkQueue(t, s, map[protocol.ByteCount][]byte{ 2 * mult: f3[:mult], 3 * mult: f1, 6 * mult: f2, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 2 * mult}, {Start: 10 * mult, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.False(t, t2.WasCalled()) require.False(t, t3.WasCalled()) }) // ---xxx====--- // ++++++ // => // --+xxx====--- t.Run("case 28", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(4) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(6) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 6, cb2)) // 6 - 10 require.NoError(t, s.Push(f3, 2, cb3)) // 2 - 8 checkQueue(t, s, map[protocol.ByteCount][]byte{ 2: f3[:1], 3: f1, 
6: f2, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 2}, {Start: 10, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.False(t, t2.WasCalled()) require.True(t, t3.WasCalled()) }) t.Run("case 28, for long frames", func(t *testing.T) { s := newFrameSorter() const mult = protocol.MinStreamFrameSize f1 := getData(3 * mult) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(4 * mult) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(6 * mult) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3*mult, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 6*mult, cb2)) // 6 - 10 require.NoError(t, s.Push(f3, 2*mult, cb3)) // 2 - 8 checkQueue(t, s, map[protocol.ByteCount][]byte{ 2 * mult: f3[:mult], 3 * mult: f1, 6 * mult: f2, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 2 * mult}, {Start: 10 * mult, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.False(t, t2.WasCalled()) require.False(t, t3.WasCalled()) }) // ---xxx===----- // +++++ // => // ---xxx+++++--- t.Run("case 29", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(3) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(5) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 6, cb2)) // 6 - 9 require.NoError(t, s.Push(f3, 6, cb3)) // 6 - 11 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f1, 6: f3, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 11, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.True(t, t2.WasCalled()) require.False(t, t3.WasCalled()) }) // ---xxx===---- // ++++++ // => // ---xxx===++-- t.Run("case 30", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(3) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(6) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 
3, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 6, cb2)) // 6 - 9 require.NoError(t, s.Push(f3, 5, cb3)) // 5 - 11 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f1, 6: f2, 9: f3[4:], }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 11, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.False(t, t2.WasCalled()) require.True(t, t3.WasCalled()) }) t.Run("case 30, for long frames", func(t *testing.T) { s := newFrameSorter() mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 2)) f1 := getData(3 * mult) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(3 * mult) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(6 * mult) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3*mult, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 6*mult, cb2)) // 6 - 9 require.NoError(t, s.Push(f3, 5*mult, cb3)) // 5 - 11 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3 * mult: f1, 6 * mult: f2, 9 * mult: f3[4*mult:], }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3 * mult}, {Start: 11 * mult, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.False(t, t2.WasCalled()) require.False(t, t3.WasCalled()) }) // ---xxx---===----- // ++++++++++ // => // ---xxx++++++++--- t.Run("case 31", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(3) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(10) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 9, cb2)) // 9 - 12 require.NoError(t, s.Push(f3, 5, cb3)) // 5 - 15 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f1, 6: f3[1:], }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 15, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.True(t, t2.WasCalled()) require.True(t, t3.WasCalled()) }) t.Run("case 31, for long frames", func(t *testing.T) { s := newFrameSorter() 
mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 9)) f1 := getData(3 * mult) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(3 * mult) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(10 * mult) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3*mult, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 9*mult, cb2)) // 9 - 12 require.NoError(t, s.Push(f3, 5*mult, cb3)) // 5 - 15 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3 * mult: f1, 6 * mult: f3[mult:], }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3 * mult}, {Start: 15 * mult, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.True(t, t2.WasCalled()) require.False(t, t3.WasCalled()) }) // ---xxx---===----- // +++++++++ // => // ---+++++++++--- t.Run("case 32", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(3) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(9) cb3, t3 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 9, cb2)) // 9 - 12 require.NoError(t, s.Push(f3, 3, cb3)) // 3 - 12 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3: f3, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 12, End: protocol.MaxByteCount}, }) require.True(t, t1.WasCalled()) require.True(t, t2.WasCalled()) require.False(t, t3.WasCalled()) }) // ---xxx---===###----- // ++++++++++++ // => // ---xxx++++++++++--- t.Run("case 33", func(t *testing.T) { s := newFrameSorter() f1 := getData(3) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(3) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(3) cb3, t3 := getFrameSorterTestCallback(t) f4 := getData(12) cb4, t4 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 9, cb2)) // 9 - 12 require.NoError(t, s.Push(f3, 9, cb3)) // 12 - 15 require.NoError(t, s.Push(f4, 5, cb4)) // 5 - 17 checkQueue(t, 
s, map[protocol.ByteCount][]byte{ 3: f1, 6: f4[1:], }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3}, {Start: 17, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.True(t, t2.WasCalled()) require.True(t, t3.WasCalled()) require.True(t, t4.WasCalled()) }) t.Run("case 33, for long frames", func(t *testing.T) { s := newFrameSorter() mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 11)) f1 := getData(3 * mult) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(3 * mult) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(3 * mult) cb3, t3 := getFrameSorterTestCallback(t) f4 := getData(12 * mult) cb4, t4 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 3*mult, cb1)) // 3 - 6 require.NoError(t, s.Push(f2, 9*mult, cb2)) // 9 - 12 require.NoError(t, s.Push(f3, 9*mult, cb3)) // 12 - 15 require.NoError(t, s.Push(f4, 5*mult, cb4)) // 5 - 17 checkQueue(t, s, map[protocol.ByteCount][]byte{ 3 * mult: f1, 6 * mult: f4[mult:], }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 3 * mult}, {Start: 17 * mult, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) require.True(t, t2.WasCalled()) require.True(t, t3.WasCalled()) require.False(t, t4.WasCalled()) }) // ---xxx===---### // ++++++ // => // ---xxx++++++### t.Run("case 34", func(t *testing.T) { s := newFrameSorter() f1 := getData(5) cb1, t1 := getFrameSorterTestCallback(t) f2 := getData(5) cb2, t2 := getFrameSorterTestCallback(t) f3 := getData(10) cb3, t3 := getFrameSorterTestCallback(t) f4 := getData(5) cb4, t4 := getFrameSorterTestCallback(t) require.NoError(t, s.Push(f1, 5, cb1)) // 5 - 10 require.NoError(t, s.Push(f2, 10, cb2)) // 10 - 15 require.NoError(t, s.Push(f4, 20, cb3)) // 20 - 25 require.NoError(t, s.Push(f3, 10, cb4)) // 10 - 20 checkQueue(t, s, map[protocol.ByteCount][]byte{ 5: f1, 10: f3, 20: f4, }) checkGaps(t, s, []byteInterval{ {Start: 0, End: 5}, {Start: 25, End: protocol.MaxByteCount}, }) require.False(t, t1.WasCalled()) 
require.True(t, t2.WasCalled())
		require.False(t, t3.WasCalled())
		require.False(t, t4.WasCalled())
	})

	// ---xxx---####---
	//    ++++++++
	// =>
	// ---++++++####---
	t.Run("case 35", func(t *testing.T) {
		s := newFrameSorter()
		f1 := getData(3)
		cb1, t1 := getFrameSorterTestCallback(t)
		f2 := getData(4)
		cb2, t2 := getFrameSorterTestCallback(t)
		f3 := getData(8)
		cb3, t3 := getFrameSorterTestCallback(t)
		require.NoError(t, s.Push(f1, 3, cb1)) // 3 - 6
		require.NoError(t, s.Push(f2, 9, cb2)) // 9 - 13
		require.NoError(t, s.Push(f3, 3, cb3)) // 3 - 11
		checkGaps(t, s, []byteInterval{
			{Start: 0, End: 3},
			{Start: 13, End: protocol.MaxByteCount},
		})
		checkQueue(t, s, map[protocol.ByteCount][]byte{
			3: f3[:6],
			9: f2,
		})
		require.True(t, t1.WasCalled())
		require.False(t, t2.WasCalled())
		require.True(t, t3.WasCalled())
	})

	t.Run("case 35, for long frames", func(t *testing.T) {
		s := newFrameSorter()
		mult := protocol.ByteCount(math.Ceil(float64(protocol.MinStreamFrameSize) / 6))
		f1 := getData(3 * mult)
		cb1, t1 := getFrameSorterTestCallback(t)
		f2 := getData(4 * mult)
		cb2, t2 := getFrameSorterTestCallback(t)
		f3 := getData(8 * mult)
		cb3, t3 := getFrameSorterTestCallback(t)
		require.NoError(t, s.Push(f1, 3*mult, cb1)) // 3 - 6
		require.NoError(t, s.Push(f2, 9*mult, cb2)) // 9 - 13
		require.NoError(t, s.Push(f3, 3*mult, cb3)) // 3 - 11
		checkGaps(t, s, []byteInterval{
			{Start: 0, End: 3 * mult},
			{Start: 13 * mult, End: protocol.MaxByteCount},
		})
		checkQueue(t, s, map[protocol.ByteCount][]byte{
			3 * mult: f3[:6*mult],
			9 * mult: f2,
		})
		require.True(t, t1.WasCalled())
		require.False(t, t2.WasCalled())
		require.False(t, t3.WasCalled())
	})
}

// TestFrameSorterTooManyGaps checks that Push rejects data once the sorter already
// tracks protocol.MaxStreamFrameSorterGaps gaps. Each 6-byte frame is pushed at a
// 7-byte stride, so every frame leaves a 1-byte gap in front of the next one.
func TestFrameSorterTooManyGaps(t *testing.T) {
	s := newFrameSorter()
	for i := 0; i < protocol.MaxStreamFrameSorterGaps; i++ {
		require.NoError(t, s.Push([]byte("foobar"), protocol.ByteCount(i*7), nil))
	}
	require.Equal(t, protocol.MaxStreamFrameSorterGaps, s.gaps.Len())
	err := s.Push([]byte("foobar"), protocol.ByteCount(protocol.MaxStreamFrameSorterGaps*7)+100, nil)
	require.EqualError(t, err, "too many gaps in received data")
}

// TestFrameSorterRandomized runs testFrameSorterRandomized with short frames and
// with frames longer than protocol.MinStreamFrameSize, each plain, with duplicate
// frames, and with overlapping frames.
func TestFrameSorterRandomized(t *testing.T) {
	t.Run("short", func(t *testing.T) {
		testFrameSorterRandomized(t, 25, false, false)
	})
	t.Run("long", func(t *testing.T) {
		testFrameSorterRandomized(t, 2*protocol.MinStreamFrameSize, false, false)
	})
	t.Run("short, with duplicates", func(t *testing.T) {
		testFrameSorterRandomized(t, 25, true, false)
	})
	t.Run("long, with duplicates", func(t *testing.T) {
		testFrameSorterRandomized(t, 2*protocol.MinStreamFrameSize, true, false)
	})
	t.Run("short, with overlaps", func(t *testing.T) {
		testFrameSorterRandomized(t, 25, false, true)
	})
	t.Run("long, with overlaps", func(t *testing.T) {
		testFrameSorterRandomized(t, 2*protocol.MinStreamFrameSize, false, true)
	})
}

// testFrameSorterRandomized splits num*dataLen random bytes into num frames of
// dataLen bytes each and pushes them into a new frame sorter in random order.
// If injectDuplicates is set, some already-pushed frames are pushed again. If
// injectOverlaps is set, random (possibly empty) sub-slices of the data are
// pushed as well.
//
// It then checks three things:
//   - exactly one gap remains, from the end of the data to protocol.MaxByteCount;
//   - popping everything yields the original data at contiguous offsets;
//   - every callback handed to Push was called.
//
// All randomness is derived from a single seed that is logged. A failing run can
// be replayed by hard-coding that seed.
func testFrameSorterRandomized(t *testing.T, dataLen protocol.ByteCount, injectDuplicates, injectOverlaps bool) {
	type frame struct {
		offset protocol.ByteCount
		data   []byte
	}

	const num = 1000

	var seed [32]byte
	rand.Read(seed[:])
	t.Logf("seed: %x", seed)
	random := mrand.NewChaCha8(seed)

	data := make([]byte, num*int(dataLen))
	random.Read(data)
	// Draw the shuffle and the duplicate / overlap choices from the seeded source
	// too (instead of the global math/rand/v2 functions). The whole run is then
	// determined by the logged seed.
	rng := mrand.New(random)

	frames := make([]frame, num)
	for i := 0; i < num; i++ {
		b := make([]byte, dataLen)
		offset := i * int(dataLen)
		copy(b, data[offset:offset+int(dataLen)])
		frames[i] = frame{
			offset: protocol.ByteCount(i) * dataLen,
			data:   b,
		}
	}
	rng.Shuffle(len(frames), func(i, j int) { frames[i], frames[j] = frames[j], frames[i] })

	s := newFrameSorter()

	var callbacks []callbackTracker
	for _, f := range frames {
		cb, tr := getFrameSorterTestCallback(t)
		require.NoError(t, s.Push(f.data, f.offset, cb))
		callbacks = append(callbacks, tr)
	}
	if injectDuplicates {
		// re-push randomly chosen frames that were already pushed above
		for i := 0; i < num/10; i++ {
			cb, tr := getFrameSorterTestCallback(t)
			df := frames[rng.IntN(len(frames))]
			require.NoError(t, s.Push(df.data, df.offset, cb))
			callbacks = append(callbacks, tr)
		}
	}
	if injectOverlaps {
		// push random sub-slices [startOffset, endOffset) of the data;
		// endOffset may equal startOffset, i.e. the pushed slice may be empty
		finalOffset := num * dataLen
		for i := 0; i < num/3; i++ {
			cb, tr := getFrameSorterTestCallback(t)
			startOffset := protocol.ByteCount(rng.IntN(int(finalOffset)))
			endOffset := startOffset + protocol.ByteCount(rng.IntN(int(finalOffset-startOffset)))
			require.NoError(t, s.Push(data[startOffset:endOffset], startOffset, cb))
			callbacks = append(callbacks, tr)
		}
	}
	// all data has been pushed, so the only gap left starts at the end of the data
	require.Equal(t, 1, s.gaps.Len())
	require.Equal(t, byteInterval{Start: num * dataLen, End: protocol.MaxByteCount}, s.gaps.Front().Value)

	// read all data
	var read []byte
	for {
		offset, b, cb := s.Pop()
		if b == nil {
			break
		}
		// data must be popped in order and without holes
		require.Equal(t, protocol.ByteCount(len(read)), offset)
		read = append(read, b...)
		if cb != nil {
			cb()
		}
	}

	require.Equal(t, data, read)
	require.False(t, s.HasMoreData())
	// every callback must have fired, either from within the sorter or from the read loop above
	for _, cb := range callbacks {
		require.True(t, cb.WasCalled())
	}
}

// TestFrameSorterPeek checks Peek's behavior:
//   - An empty peek is a no-op.
//   - Peek copies len(p) bytes starting at offset when a queued frame starts at
//     offset and the bytes are contiguous, possibly spanning several frames.
//   - Otherwise Peek returns errTooLittleData. This covers offsets that fall
//     inside a frame rather than at its start, and ranges containing a gap.
func TestFrameSorterPeek(t *testing.T) {
	s := newFrameSorter()
	require.NoError(t, s.Peek(1337, []byte{})) // empty peek is a no-op
	require.ErrorIs(t, s.Peek(0, []byte{0, 1, 2, 3, 4}), errTooLittleData)

	require.NoError(t, s.Push([]byte("foobar"), 0, nil))

	// peek partial frame
	p := make([]byte, 3)
	require.NoError(t, s.Peek(0, p))
	require.Equal(t, []byte("foo"), p)

	// peek entire frame
	p = make([]byte, 6)
	require.NoError(t, s.Peek(0, p))
	require.Equal(t, []byte("foobar"), p)

	// peek more than available
	p = make([]byte, 10)
	require.ErrorIs(t, s.Peek(0, p), errTooLittleData)

	// peek at offset where no entry exists
	p = make([]byte, 3)
	require.ErrorIs(t, s.Peek(3, p), errTooLittleData)

	// peek across multiple frames
	require.NoError(t, s.Push([]byte("baz"), 6, nil))
	p = make([]byte, 9)
	require.NoError(t, s.Peek(0, p))
	require.Equal(t, []byte("foobarbaz"), p)

	// peek starting from second frame
	p = make([]byte, 3)
	require.NoError(t, s.Peek(6, p))
	require.Equal(t, []byte("baz"), p)

	// peeking across gaps doesn't work
	require.NoError(t, s.Push([]byte("qux"), 10, nil))
	p = make([]byte, 10)
	require.ErrorIs(t, s.Peek(0, p), errTooLittleData)
}
quic-go-0.59.0/framer.go000066400000000000000000000221321513066070600147600ustar00rootroot00000000000000package quic

import (
	"slices"
	"sync"
"github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/flowcontrol" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils/ringbuffer" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/quicvarint" ) const ( maxPathResponses = 256 maxControlFrames = 16 << 10 ) // This is the largest possible size of a stream-related control frame // (which is the RESET_STREAM frame). const maxStreamControlFrameSize = 25 type streamFrameGetter interface { popStreamFrame(protocol.ByteCount, protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame, bool) } type streamControlFrameGetter interface { getControlFrame(monotime.Time) (_ ackhandler.Frame, ok, hasMore bool) } type framer struct { mutex sync.Mutex activeStreams map[protocol.StreamID]streamFrameGetter streamQueue ringbuffer.RingBuffer[protocol.StreamID] streamsWithControlFrames map[protocol.StreamID]streamControlFrameGetter controlFrameMutex sync.Mutex controlFrames []wire.Frame pathResponses []*wire.PathResponseFrame connFlowController flowcontrol.ConnectionFlowController queuedTooManyControlFrames bool } func newFramer(connFlowController flowcontrol.ConnectionFlowController) *framer { return &framer{ activeStreams: make(map[protocol.StreamID]streamFrameGetter), streamsWithControlFrames: make(map[protocol.StreamID]streamControlFrameGetter), connFlowController: connFlowController, } } func (f *framer) HasData() bool { f.mutex.Lock() hasData := !f.streamQueue.Empty() f.mutex.Unlock() if hasData { return true } f.controlFrameMutex.Lock() defer f.controlFrameMutex.Unlock() return len(f.streamsWithControlFrames) > 0 || len(f.controlFrames) > 0 || len(f.pathResponses) > 0 } func (f *framer) QueueControlFrame(frame wire.Frame) { f.controlFrameMutex.Lock() defer f.controlFrameMutex.Unlock() if pr, ok := frame.(*wire.PathResponseFrame); ok { // Only queue up to maxPathResponses 
PATH_RESPONSE frames. // This limit should be high enough to never be hit in practice, // unless the peer is doing something malicious. if len(f.pathResponses) >= maxPathResponses { return } f.pathResponses = append(f.pathResponses, pr) return } // This is a hack. if len(f.controlFrames) >= maxControlFrames { f.queuedTooManyControlFrames = true return } f.controlFrames = append(f.controlFrames, frame) } func (f *framer) Append( frames []ackhandler.Frame, streamFrames []ackhandler.StreamFrame, maxLen protocol.ByteCount, now monotime.Time, v protocol.Version, ) ([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount) { f.controlFrameMutex.Lock() frames, controlFrameLen := f.appendControlFrames(frames, maxLen, now, v) maxLen -= controlFrameLen var lastFrame ackhandler.StreamFrame var streamFrameLen protocol.ByteCount f.mutex.Lock() // pop STREAM frames, until less than 128 bytes are left in the packet numActiveStreams := f.streamQueue.Len() for i := 0; i < numActiveStreams; i++ { if protocol.MinStreamFrameSize > maxLen { break } sf, blocked := f.getNextStreamFrame(maxLen, v) if sf.Frame != nil { streamFrames = append(streamFrames, sf) maxLen -= sf.Frame.Length(v) lastFrame = sf streamFrameLen += sf.Frame.Length(v) } // If the stream just became blocked on stream flow control, attempt to pack the // STREAM_DATA_BLOCKED into the same packet. if blocked != nil { l := blocked.Length(v) // In case it doesn't fit, queue it for the next packet. if maxLen < l { f.controlFrames = append(f.controlFrames, blocked) break } frames = append(frames, ackhandler.Frame{Frame: blocked}) maxLen -= l controlFrameLen += l } } // The only way to become blocked on connection-level flow control is by sending STREAM frames. if isBlocked, offset := f.connFlowController.IsNewlyBlocked(); isBlocked { blocked := &wire.DataBlockedFrame{MaximumData: offset} l := blocked.Length(v) // In case it doesn't fit, queue it for the next packet. 
if maxLen >= l { frames = append(frames, ackhandler.Frame{Frame: blocked}) controlFrameLen += l } else { f.controlFrames = append(f.controlFrames, blocked) } } f.mutex.Unlock() f.controlFrameMutex.Unlock() if lastFrame.Frame != nil { // account for the smaller size of the last STREAM frame streamFrameLen -= lastFrame.Frame.Length(v) lastFrame.Frame.DataLenPresent = false streamFrameLen += lastFrame.Frame.Length(v) } return frames, streamFrames, controlFrameLen + streamFrameLen } func (f *framer) appendControlFrames( frames []ackhandler.Frame, maxLen protocol.ByteCount, now monotime.Time, v protocol.Version, ) ([]ackhandler.Frame, protocol.ByteCount) { var length protocol.ByteCount // add a PATH_RESPONSE first, but only pack a single PATH_RESPONSE per packet if len(f.pathResponses) > 0 { frame := f.pathResponses[0] frameLen := frame.Length(v) if frameLen <= maxLen { frames = append(frames, ackhandler.Frame{Frame: frame}) length += frameLen f.pathResponses = f.pathResponses[1:] } } // add stream-related control frames for id, str := range f.streamsWithControlFrames { start: remainingLen := maxLen - length if remainingLen <= maxStreamControlFrameSize { break } fr, ok, hasMore := str.getControlFrame(now) if !hasMore { delete(f.streamsWithControlFrames, id) } if !ok { continue } frames = append(frames, fr) length += fr.Frame.Length(v) if hasMore { // It is rare that a stream has more than one control frame to queue. // We don't want to spawn another loop for just to cover that case. goto start } } for len(f.controlFrames) > 0 { frame := f.controlFrames[len(f.controlFrames)-1] frameLen := frame.Length(v) if length+frameLen > maxLen { break } frames = append(frames, ackhandler.Frame{Frame: frame}) length += frameLen f.controlFrames = f.controlFrames[:len(f.controlFrames)-1] } return frames, length } // QueuedTooManyControlFrames says if the control frame queue exceeded its maximum queue length. // This is a hack. 
// It is easier to implement than propagating an error return value in QueueControlFrame. // The correct solution would be to queue frames with their respective structs. // See https://github.com/quic-go/quic-go/issues/4271 for the queueing of stream-related control frames. func (f *framer) QueuedTooManyControlFrames() bool { return f.queuedTooManyControlFrames } func (f *framer) AddActiveStream(id protocol.StreamID, str streamFrameGetter) { f.mutex.Lock() if _, ok := f.activeStreams[id]; !ok { f.streamQueue.PushBack(id) f.activeStreams[id] = str } f.mutex.Unlock() } func (f *framer) AddStreamWithControlFrames(id protocol.StreamID, str streamControlFrameGetter) { f.controlFrameMutex.Lock() if _, ok := f.streamsWithControlFrames[id]; !ok { f.streamsWithControlFrames[id] = str } f.controlFrameMutex.Unlock() } // RemoveActiveStream is called when a stream completes. func (f *framer) RemoveActiveStream(id protocol.StreamID) { f.mutex.Lock() delete(f.activeStreams, id) // We don't delete the stream from the streamQueue, // since we'd have to iterate over the ringbuffer. // Instead, we check if the stream is still in activeStreams when appending STREAM frames. f.mutex.Unlock() } func (f *framer) getNextStreamFrame(maxLen protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame) { id := f.streamQueue.PopFront() // This should never return an error. Better check it anyway. // The stream will only be in the streamQueue, if it enqueued itself there. str, ok := f.activeStreams[id] // The stream might have been removed after being enqueued. if !ok { return ackhandler.StreamFrame{}, nil } // For the last STREAM frame, we'll remove the DataLen field later. // Therefore, we can pretend to have more bytes available when popping // the STREAM frame (which will always have the DataLen set). 
maxLen += protocol.ByteCount(quicvarint.Len(uint64(maxLen))) frame, blocked, hasMoreData := str.popStreamFrame(maxLen, v) if hasMoreData { // put the stream back in the queue (at the end) f.streamQueue.PushBack(id) } else { // no more data to send. Stream is not active delete(f.activeStreams, id) } // Note that the frame.Frame can be nil: // * if the stream was canceled after it said it had data // * the remaining size doesn't allow us to add another STREAM frame return frame, blocked } func (f *framer) Handle0RTTRejection() { f.mutex.Lock() defer f.mutex.Unlock() f.controlFrameMutex.Lock() defer f.controlFrameMutex.Unlock() f.streamQueue.Clear() for id := range f.activeStreams { delete(f.activeStreams, id) } var j int for i, frame := range f.controlFrames { switch frame.(type) { case *wire.MaxDataFrame, *wire.MaxStreamDataFrame, *wire.MaxStreamsFrame, *wire.DataBlockedFrame, *wire.StreamDataBlockedFrame, *wire.StreamsBlockedFrame: continue default: f.controlFrames[j] = f.controlFrames[i] j++ } } f.controlFrames = slices.Delete(f.controlFrames, j, len(f.controlFrames)) } quic-go-0.59.0/framer_test.go000066400000000000000000000506041513066070600160240ustar00rootroot00000000000000package quic import ( "bytes" "encoding/binary" "math/rand/v2" "testing" "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/flowcontrol" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) func TestFramerControlFrames(t *testing.T) { pc := &wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 6, 7, 8}} msf := &wire.MaxStreamsFrame{MaxStreamNum: 0x1337} framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil)) require.False(t, framer.HasData()) framer.QueueControlFrame(pc) require.True(t, framer.HasData()) framer.QueueControlFrame(msf) frames, streamFrames, length := framer.Append( 
[]ackhandler.Frame{{Frame: &wire.PingFrame{}}}, nil, protocol.MaxByteCount, monotime.Now(), protocol.Version1, ) require.Len(t, frames, 3) require.Empty(t, streamFrames) require.Contains(t, frames, ackhandler.Frame{Frame: &wire.PingFrame{}}) require.Contains(t, frames, ackhandler.Frame{Frame: pc}) require.Contains(t, frames, ackhandler.Frame{Frame: msf}) require.Equal(t, length, pc.Length(protocol.Version1)+msf.Length(protocol.Version1)) require.False(t, framer.HasData()) } func TestFramerControlFrameSizing(t *testing.T) { const maxSize = protocol.ByteCount(1000) bf := &wire.DataBlockedFrame{MaximumData: 0x1337} bfLen := bf.Length(protocol.Version1) framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil)) numFrames := int(maxSize / bfLen) // max number of frames that fit into maxSize for i := 0; i < numFrames+1; i++ { framer.QueueControlFrame(bf) } frames, _, length := framer.Append(nil, nil, maxSize, monotime.Now(), protocol.Version1) require.Len(t, frames, numFrames) require.Greater(t, length, maxSize-bfLen) // now make sure that the last frame is also added frames, _, length = framer.Append(nil, nil, maxSize, monotime.Now(), protocol.Version1) require.Len(t, frames, 1) require.Equal(t, length, bfLen) } func TestFramerStreamControlFrames(t *testing.T) { const streamID = protocol.StreamID(10) ping := &wire.PingFrame{} mdf1 := &wire.MaxStreamDataFrame{StreamID: streamID, MaximumStreamData: 1337} mdf2 := &wire.MaxStreamDataFrame{StreamID: streamID, MaximumStreamData: 1338} framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil)) framer.QueueControlFrame(ping) str := NewMockStreamControlFrameGetter(gomock.NewController(t)) framer.AddStreamWithControlFrames(streamID, str) now := monotime.Now() str.EXPECT().getControlFrame(now).Return(ackhandler.Frame{Frame: mdf1}, true, true) str.EXPECT().getControlFrame(now).Return(ackhandler.Frame{Frame: mdf2}, true, false) frames, streamFrames, l := framer.Append(nil, nil, 
protocol.MaxByteCount, now, protocol.Version1) require.Len(t, frames, 3) require.Empty(t, streamFrames) require.Equal(t, mdf1, frames[0].Frame) require.Equal(t, mdf2, frames[1].Frame) require.Equal(t, ping, frames[2].Frame) require.Equal(t, ping.Length(protocol.Version1)+mdf1.Length(protocol.Version1)+mdf2.Length(protocol.Version1), l) } // If there are less than 25 bytes left, no more stream-related control frames are enqueued. // This avoids dequeueing a frame from the stream that would be too large to fit into the packet. func TestFramerStreamControlFramesSizing(t *testing.T) { mdf1 := &wire.MaxStreamDataFrame{MaximumStreamData: 1337} str := NewMockStreamControlFrameGetter(gomock.NewController(t)) framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil)) framer.AddStreamWithControlFrames(10, str) str.EXPECT().getControlFrame(gomock.Any()).Return(ackhandler.Frame{Frame: mdf1}, true, true).AnyTimes() frames, _, l := framer.Append(nil, nil, 100, monotime.Now(), protocol.Version1) require.Equal(t, protocol.ByteCount(len(frames))*mdf1.Length(protocol.Version1), l) require.Greater(t, l, protocol.ByteCount(100-maxStreamControlFrameSize)) require.LessOrEqual(t, l, protocol.ByteCount(100)) } func TestFramerStreamDataBlocked(t *testing.T) { t.Run("small STREAM frame", func(t *testing.T) { testFramerStreamDataBlocked(t, true) }) t.Run("large STREAM frame", func(t *testing.T) { testFramerStreamDataBlocked(t, false) }) } // If the stream becomes blocked on stream flow control, we attempt to pack the STREAM_DATA_BLOCKED // into the same packet. // However, there's the pathological case, where the STREAM frame and the STREAM_DATA_BLOCKED frame // don't fit into the same packet. In that case, the STREAM_DATA_BLOCKED frame is queued and sent // in the next packet. 
func testFramerStreamDataBlocked(t *testing.T, fits bool) { const streamID = 5 str := NewMockStreamFrameGetter(gomock.NewController(t)) framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil)) framer.AddActiveStream(streamID, str) str.EXPECT().popStreamFrame(gomock.Any(), gomock.Any()).DoAndReturn( func(size protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame, bool) { data := []byte("foobar") if !fits { // Leave 3 bytes in the packet. // This is not enough to fit in the STREAM_DATA_BLOCKED frame. data = make([]byte, size-3) } f := &wire.StreamFrame{StreamID: streamID, DataLenPresent: true, Data: data} blocked := &wire.StreamDataBlockedFrame{StreamID: streamID, MaximumStreamData: f.DataLen()} if !fits { require.Greater(t, blocked.Length(protocol.Version1), protocol.ByteCount(3)) } return ackhandler.StreamFrame{Frame: f}, blocked, false }, ) const maxSize protocol.ByteCount = 1000 frames, streamFrames, l := framer.Append(nil, nil, maxSize, monotime.Now(), protocol.Version1) require.Len(t, streamFrames, 1) dataLen := streamFrames[0].Frame.DataLen() if fits { require.Len(t, frames, 1) require.Equal(t, &wire.StreamDataBlockedFrame{StreamID: streamID, MaximumStreamData: dataLen}, frames[0].Frame) } else { require.Equal(t, streamFrames[0].Frame.Length(protocol.Version1), l) require.Empty(t, frames) frames, streamFrames, l2 := framer.Append(nil, nil, maxSize, monotime.Now(), protocol.Version1) require.Greater(t, l+l2, maxSize) require.Empty(t, streamFrames) require.Len(t, frames, 1) require.Equal(t, &wire.StreamDataBlockedFrame{StreamID: streamID, MaximumStreamData: dataLen}, frames[0].Frame) } } func TestFramerDataBlocked(t *testing.T) { t.Run("small STREAM frame", func(t *testing.T) { testFramerDataBlocked(t, true) }) t.Run("large STREAM frame", func(t *testing.T) { testFramerDataBlocked(t, false) }) } // If the stream becomes blocked on connection flow control, we attempt to pack the // DATA_BLOCKED 
frame into the same packet. // However, there's the pathological case, where the STREAM frame and the DATA_BLOCKED frame // don't fit into the same packet. In that case, the DATA_BLOCKED frame is queued and sent // in the next packet. func testFramerDataBlocked(t *testing.T, fits bool) { const streamID = 5 const offset = 100 fc := flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil) fc.UpdateSendWindow(offset) fc.AddBytesSent(offset) str := NewMockStreamFrameGetter(gomock.NewController(t)) framer := newFramer(fc) framer.AddActiveStream(streamID, str) str.EXPECT().popStreamFrame(gomock.Any(), gomock.Any()).DoAndReturn( func(size protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame, bool) { data := []byte("foobar") if !fits { // Leave 2 bytes in the packet. // This is not enough to fit in the DATA_BLOCKED frame. data = make([]byte, size-2) } f := &wire.StreamFrame{StreamID: streamID, DataLenPresent: true, Data: data} return ackhandler.StreamFrame{Frame: f}, nil, false }, ) const maxSize protocol.ByteCount = 1000 frames, streamFrames, l := framer.Append(nil, nil, maxSize, monotime.Now(), protocol.Version1) require.Len(t, streamFrames, 1) if fits { require.Len(t, frames, 1) require.Equal(t, &wire.DataBlockedFrame{MaximumData: offset}, frames[0].Frame) } else { require.Equal(t, streamFrames[0].Frame.Length(protocol.Version1), l) require.Empty(t, frames) frames, streamFrames, l2 := framer.Append(nil, nil, maxSize, monotime.Now(), protocol.Version1) require.Greater(t, l+l2, maxSize) require.Empty(t, streamFrames) require.Len(t, frames, 1) require.Equal(t, &wire.DataBlockedFrame{MaximumData: offset}, frames[0].Frame) } } func TestFramerDetectsFrameDoS(t *testing.T) { framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil)) for i := 0; i < maxControlFrames-1; i++ { framer.QueueControlFrame(&wire.PingFrame{}) framer.QueueControlFrame(&wire.PingFrame{}) require.False(t, 
framer.QueuedTooManyControlFrames()) frames, _, _ := framer.Append([]ackhandler.Frame{}, nil, 1, monotime.Now(), protocol.Version1) require.Len(t, frames, 1) require.Len(t, framer.controlFrames, i+1) } framer.QueueControlFrame(&wire.PingFrame{}) require.False(t, framer.QueuedTooManyControlFrames()) require.Len(t, framer.controlFrames, maxControlFrames) framer.QueueControlFrame(&wire.PingFrame{}) require.True(t, framer.QueuedTooManyControlFrames()) require.Len(t, framer.controlFrames, maxControlFrames) } func TestFramerDetectsFramePathResponseDoS(t *testing.T) { framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil)) var pathResponses []*wire.PathResponseFrame for i := 0; i < 2*maxPathResponses; i++ { var f wire.PathResponseFrame binary.BigEndian.PutUint64(f.Data[:], rand.Uint64()) pathResponses = append(pathResponses, &f) framer.QueueControlFrame(&f) } for i := 0; i < maxPathResponses; i++ { require.True(t, framer.HasData()) frames, _, length := framer.Append(nil, nil, protocol.MaxByteCount, monotime.Now(), protocol.Version1) require.Len(t, frames, 1) require.Equal(t, pathResponses[i], frames[0].Frame) require.Equal(t, pathResponses[i].Length(protocol.Version1), length) } require.False(t, framer.HasData()) frames, _, length := framer.Append(nil, nil, protocol.MaxByteCount, monotime.Now(), protocol.Version1) require.Empty(t, frames) require.Zero(t, length) } func TestFramerPacksSinglePathResponsePerPacket(t *testing.T) { framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil)) f1 := &wire.PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} f2 := &wire.PathResponseFrame{Data: [8]byte{2, 3, 4, 5, 6, 7, 8, 9}} cf1 := &wire.DataBlockedFrame{MaximumData: 1337} cf2 := &wire.HandshakeDoneFrame{} framer.QueueControlFrame(f1) framer.QueueControlFrame(f2) framer.QueueControlFrame(cf1) framer.QueueControlFrame(cf2) // the first packet should contain a single PATH_RESPONSE frame, but all the other control frames frames, _, 
_ := framer.Append(nil, nil, protocol.MaxByteCount, monotime.Now(), protocol.Version1) require.Len(t, frames, 3) require.Equal(t, f1, frames[0].Frame) require.Contains(t, []wire.Frame{frames[1].Frame, frames[2].Frame}, cf1) require.Contains(t, []wire.Frame{frames[1].Frame, frames[2].Frame}, cf2) // the second packet should contain the other PATH_RESPONSE frame require.True(t, framer.HasData()) frames, _, _ = framer.Append(nil, nil, protocol.MaxByteCount, monotime.Now(), protocol.Version1) require.Len(t, frames, 1) require.Equal(t, f2, frames[0].Frame) require.False(t, framer.HasData()) } func TestFramerAppendStreamFrames(t *testing.T) { const ( str1ID = protocol.StreamID(42) str2ID = protocol.StreamID(43) ) f1 := &wire.StreamFrame{StreamID: str1ID, Data: []byte("foo"), DataLenPresent: true} f2 := &wire.StreamFrame{StreamID: str2ID, Data: []byte("bar"), DataLenPresent: true} totalLen := f1.Length(protocol.Version1) + f2.Length(protocol.Version1) framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil)) require.False(t, framer.HasData()) // no frames added yet controlFrames, fs, length := framer.Append(nil, nil, protocol.MaxByteCount, monotime.Now(), protocol.Version1) require.Empty(t, controlFrames) require.Empty(t, fs) require.Zero(t, length) // add two streams mockCtrl := gomock.NewController(t) str1 := NewMockStreamFrameGetter(mockCtrl) str1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f1}, nil, true) str2 := NewMockStreamFrameGetter(mockCtrl) str2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f2}, nil, false) framer.AddActiveStream(str1ID, str1) framer.AddActiveStream(str1ID, str1) // duplicate calls are ok (they're no-ops) framer.AddActiveStream(str2ID, str2) require.True(t, framer.HasData()) // Even though the first stream claimed to have more data, // we only dequeue a single STREAM frame per call of AppendStreamFrames. 
f0 := ackhandler.StreamFrame{Frame: &wire.StreamFrame{StreamID: 9999}} controlFrames, fs, length = framer.Append([]ackhandler.Frame{}, []ackhandler.StreamFrame{f0}, protocol.MaxByteCount, monotime.Now(), protocol.Version1) require.Empty(t, controlFrames) require.Len(t, fs, 3) require.Equal(t, f0, fs[0]) require.Equal(t, str1ID, fs[1].Frame.StreamID) require.Equal(t, []byte("foo"), fs[1].Frame.Data) // since two STREAM frames are sent, the DataLenPresent flag is set on the first frame require.True(t, fs[1].Frame.DataLenPresent) require.Equal(t, str2ID, fs[2].Frame.StreamID) require.Equal(t, []byte("bar"), fs[2].Frame.Data) // the last frame doesn't have the DataLenPresent flag set require.False(t, fs[2].Frame.DataLenPresent) require.Equal(t, fs[1].Frame.Length(protocol.Version1)+fs[2].Frame.Length(protocol.Version1), length) require.Less(t, length, totalLen) // unsetting DataLenPresent on the last frame reduces the length require.True(t, framer.HasData()) // the stream claimed to have more data... // ... 
but it actually doesn't str1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{}, nil, false) _, fs, length = framer.Append(nil, nil, protocol.MaxByteCount, monotime.Now(), protocol.Version1) require.Empty(t, fs) require.Zero(t, length) require.False(t, framer.HasData()) } func TestFramerRemoveActiveStream(t *testing.T) { const id = protocol.StreamID(42) framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil)) require.False(t, framer.HasData()) framer.AddActiveStream(id, NewMockStreamFrameGetter(gomock.NewController(t))) require.True(t, framer.HasData()) framer.RemoveActiveStream(id) // no calls will be issued to the mock stream // we can't assert on framer.HasData here, since it's not removed from the ringbuffer _, frames, _ := framer.Append(nil, nil, protocol.MaxByteCount, monotime.Now(), protocol.Version1) require.Empty(t, frames) require.False(t, framer.HasData()) } func TestFramerMinStreamFrameSize(t *testing.T) { const id = protocol.StreamID(42) framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil)) str := NewMockStreamFrameGetter(gomock.NewController(t)) framer.AddActiveStream(id, str) require.True(t, framer.HasData()) // don't pop frames smaller than the minimum STREAM frame size _, frames, _ := framer.Append(nil, nil, protocol.MinStreamFrameSize-1, monotime.Now(), protocol.Version1) require.Empty(t, frames) // pop frames of the minimum size str.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).DoAndReturn( func(size protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame, bool) { f := &wire.StreamFrame{StreamID: id, DataLenPresent: true} f.Data = make([]byte, f.MaxDataLen(protocol.MinStreamFrameSize, v)) return ackhandler.StreamFrame{Frame: f}, nil, false }, ) _, frames, _ = framer.Append(nil, nil, protocol.MinStreamFrameSize, monotime.Now(), protocol.Version1) require.Len(t, frames, 1) // unsetting DataLenPresent on the 
last frame reduced the size slightly beyond the minimum size require.Equal(t, protocol.MinStreamFrameSize-2, frames[0].Frame.Length(protocol.Version1)) } func TestFramerMinStreamFrameSizeMultipleStreamFrames(t *testing.T) { const id = protocol.StreamID(42) framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil)) str := NewMockStreamFrameGetter(gomock.NewController(t)) framer.AddActiveStream(id, str) // pop a frame such that the remaining size is one byte less than the minimum STREAM frame size f := &wire.StreamFrame{ StreamID: id, Data: bytes.Repeat([]byte("f"), int(500-protocol.MinStreamFrameSize)), DataLenPresent: true, } str.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f}, nil, false) framer.AddActiveStream(id, str) _, fs, length := framer.Append(nil, nil, 500, monotime.Now(), protocol.Version1) require.Len(t, fs, 1) require.Equal(t, f, fs[0].Frame) require.Equal(t, f.Length(protocol.Version1), length) } func TestFramerFillPacketOneStream(t *testing.T) { const id = protocol.StreamID(42) str := NewMockStreamFrameGetter(gomock.NewController(t)) framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil)) for i := protocol.MinStreamFrameSize; i < 2000; i++ { str.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).DoAndReturn( func(size protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame, bool) { f := &wire.StreamFrame{ StreamID: id, DataLenPresent: true, } f.Data = make([]byte, f.MaxDataLen(size, v)) require.Equal(t, size, f.Length(protocol.Version1)) return ackhandler.StreamFrame{Frame: f}, nil, false }, ) framer.AddActiveStream(id, str) _, frames, _ := framer.Append(nil, nil, i, monotime.Now(), protocol.Version1) require.Len(t, frames, 1) require.False(t, frames[0].Frame.DataLenPresent) // make sure the entire space was filled up require.Equal(t, i, frames[0].Frame.Length(protocol.Version1)) } } func 
TestFramerFillPacketMultipleStreams(t *testing.T) { const ( id1 = protocol.StreamID(1000) id2 = protocol.StreamID(11) ) mockCtrl := gomock.NewController(t) stream1 := NewMockStreamFrameGetter(mockCtrl) stream2 := NewMockStreamFrameGetter(mockCtrl) framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil)) for i := 2 * protocol.MinStreamFrameSize; i < 2000; i++ { stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).DoAndReturn( func(size protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame, bool) { f := &wire.StreamFrame{StreamID: id1, DataLenPresent: true} f.Data = make([]byte, f.MaxDataLen(protocol.MinStreamFrameSize, v)) return ackhandler.StreamFrame{Frame: f}, nil, false }, ) stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).DoAndReturn( func(size protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame, bool) { f := &wire.StreamFrame{StreamID: id2, DataLenPresent: true} f.Data = make([]byte, f.MaxDataLen(size, v)) require.Equal(t, size, f.Length(protocol.Version1)) return ackhandler.StreamFrame{Frame: f}, nil, false }, ) framer.AddActiveStream(id1, stream1) framer.AddActiveStream(id2, stream2) _, frames, _ := framer.Append(nil, nil, i, monotime.Now(), protocol.Version1) require.Len(t, frames, 2) require.True(t, frames[0].Frame.DataLenPresent) require.False(t, frames[1].Frame.DataLenPresent) require.Equal(t, i, frames[0].Frame.Length(protocol.Version1)+frames[1].Frame.Length(protocol.Version1)) } } func TestFramer0RTTRejection(t *testing.T) { ncid := &wire.NewConnectionIDFrame{ SequenceNumber: 10, ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), } ping := &wire.PingFrame{} pc := &wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 6, 7, 8}} framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil)) framer.QueueControlFrame(ncid) framer.QueueControlFrame(&wire.DataBlockedFrame{MaximumData: 
1337}) framer.QueueControlFrame(&wire.StreamDataBlockedFrame{StreamID: 42, MaximumStreamData: 1337}) framer.QueueControlFrame(ping) framer.QueueControlFrame(&wire.StreamsBlockedFrame{StreamLimit: 13}) framer.QueueControlFrame(pc) framer.AddActiveStream(10, NewMockStreamFrameGetter(gomock.NewController(t))) framer.Handle0RTTRejection() controlFrames, streamFrames, _ := framer.Append(nil, nil, protocol.MaxByteCount, monotime.Now(), protocol.Version1) require.Empty(t, streamFrames) require.Len(t, controlFrames, 3) require.Contains(t, controlFrames, ackhandler.Frame{Frame: pc}) require.Contains(t, controlFrames, ackhandler.Frame{Frame: ping}) require.Contains(t, controlFrames, ackhandler.Frame{Frame: ncid}) } quic-go-0.59.0/fuzzing/000077500000000000000000000000001513066070600146515ustar00rootroot00000000000000quic-go-0.59.0/fuzzing/frames/000077500000000000000000000000001513066070600161265ustar00rootroot00000000000000quic-go-0.59.0/fuzzing/frames/cmd/000077500000000000000000000000001513066070600166715ustar00rootroot00000000000000quic-go-0.59.0/fuzzing/frames/cmd/corpus.go000066400000000000000000000173731513066070600205460ustar00rootroot00000000000000package main import ( "crypto/rand" "log" mrand "math/rand/v2" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/fuzzing/internal/helper" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" ) const version = protocol.Version1 func getRandomData(l int) []byte { b := make([]byte, l) rand.Read(b) return b } func getRandomNumber() uint64 { switch 1 << uint8(mrand.IntN(3)) { case 1: return mrand.Uint64N(64) case 2: return mrand.Uint64N(16384) case 4: return mrand.Uint64N(1073741824) case 8: return mrand.Uint64N(4611686018427387904) default: panic("unexpected length") } } func getRandomNumberLowerOrEqual(target uint64) uint64 { if target == 0 { return 0 } return mrand.Uint64N(target) } // returns a *maximum* number of num ACK ranges func getAckRanges(num int) []wire.AckRange { 
prevSmallest := mrand.Uint64N(4611686018427387904) ranges := make([]wire.AckRange, 0, num) for range num { if prevSmallest <= 2 { break } largest := getRandomNumberLowerOrEqual(prevSmallest - 2) smallest := getRandomNumberLowerOrEqual(largest) ranges = append(ranges, wire.AckRange{ Smallest: protocol.PacketNumber(smallest), Largest: protocol.PacketNumber(largest), }) prevSmallest = smallest } return ranges } func getFrames() []wire.Frame { frames := []wire.Frame{ &wire.StreamFrame{ // STREAM frame at 0 offset, with FIN bit StreamID: protocol.StreamID(getRandomNumber()), Fin: true, }, &wire.StreamFrame{ // STREAM frame at 0 offset, with data and FIN bit StreamID: protocol.StreamID(getRandomNumber()), Fin: true, Data: getRandomData(100), }, &wire.StreamFrame{ // STREAM frame at non-zero offset, with data StreamID: protocol.StreamID(getRandomNumber()), Offset: protocol.ByteCount(getRandomNumber()), Data: getRandomData(50), }, &wire.StreamFrame{ // STREAM frame at non-zero offset, with data and FIN bit StreamID: protocol.StreamID(getRandomNumber()), Offset: protocol.ByteCount(getRandomNumber()), Data: getRandomData(50), Fin: true, }, &wire.StreamFrame{ // STREAM frame at non-zero offset, with data and FIN bit. Long enough to use the buffer. 
StreamID: protocol.StreamID(getRandomNumber()), Offset: protocol.ByteCount(getRandomNumber()), Data: getRandomData(2 * protocol.MinStreamFrameBufferSize), Fin: true, }, &wire.StreamFrame{ // STREAM frame at maximum offset, with FIN bit StreamID: protocol.StreamID(getRandomNumber()), Offset: protocol.MaxByteCount - 5, Data: getRandomData(5), Fin: true, }, &wire.StreamFrame{ // STREAM frame with data at maximum offset StreamID: protocol.StreamID(getRandomNumber()), Offset: protocol.MaxByteCount, Data: getRandomData(10), }, &wire.AckFrame{ AckRanges: getAckRanges(1), DelayTime: time.Duration(getRandomNumber()), }, &wire.AckFrame{ AckRanges: getAckRanges(5), DelayTime: time.Duration(getRandomNumber()), }, &wire.AckFrame{ AckRanges: getAckRanges(300), DelayTime: time.Duration(getRandomNumber()), }, &wire.AckFrame{ AckRanges: getAckRanges(3), DelayTime: time.Duration(getRandomNumber()), ECT0: getRandomNumber(), ECT1: getRandomNumber(), ECNCE: getRandomNumber(), }, &wire.PingFrame{}, &wire.ResetStreamFrame{ StreamID: protocol.StreamID(getRandomNumber()), ErrorCode: quic.StreamErrorCode(getRandomNumber()), FinalSize: protocol.ByteCount(getRandomNumber()), }, &wire.ResetStreamFrame{ // at maximum offset StreamID: protocol.StreamID(getRandomNumber()), ErrorCode: quic.StreamErrorCode(getRandomNumber()), FinalSize: protocol.MaxByteCount, }, &wire.StopSendingFrame{ StreamID: protocol.StreamID(getRandomNumber()), ErrorCode: quic.StreamErrorCode(getRandomNumber()), }, &wire.CryptoFrame{ Data: getRandomData(100), }, &wire.CryptoFrame{ Offset: protocol.ByteCount(getRandomNumber()), Data: getRandomData(50), }, &wire.NewTokenFrame{ Token: getRandomData(10), }, &wire.MaxDataFrame{ MaximumData: protocol.ByteCount(getRandomNumber()), }, &wire.MaxDataFrame{ MaximumData: protocol.MaxByteCount, }, &wire.MaxStreamDataFrame{ StreamID: protocol.StreamID(getRandomNumber()), MaximumStreamData: protocol.ByteCount(getRandomNumber()), }, &wire.MaxStreamDataFrame{ StreamID: 
protocol.StreamID(getRandomNumber()), MaximumStreamData: protocol.MaxByteCount, }, &wire.MaxStreamsFrame{ Type: protocol.StreamTypeUni, MaxStreamNum: protocol.StreamNum(getRandomNumber()), }, &wire.MaxStreamsFrame{ Type: protocol.StreamTypeBidi, MaxStreamNum: protocol.StreamNum(getRandomNumber()), }, &wire.DataBlockedFrame{ MaximumData: protocol.ByteCount(getRandomNumber()), }, &wire.DataBlockedFrame{ MaximumData: protocol.MaxByteCount, }, &wire.StreamDataBlockedFrame{ StreamID: protocol.StreamID(getRandomNumber()), MaximumStreamData: protocol.ByteCount(getRandomNumber()), }, &wire.StreamDataBlockedFrame{ StreamID: protocol.StreamID(getRandomNumber()), MaximumStreamData: protocol.MaxByteCount, }, &wire.StreamsBlockedFrame{ Type: protocol.StreamTypeUni, StreamLimit: protocol.StreamNum(getRandomNumber()), }, &wire.StreamsBlockedFrame{ Type: protocol.StreamTypeBidi, StreamLimit: protocol.StreamNum(getRandomNumber()), }, &wire.RetireConnectionIDFrame{ SequenceNumber: getRandomNumber(), }, &wire.ConnectionCloseFrame{ // QUIC error with empty reason IsApplicationError: false, ErrorCode: getRandomNumber(), ReasonPhrase: "", }, &wire.ConnectionCloseFrame{ // QUIC error with reason IsApplicationError: false, // TODO: add frame type ErrorCode: getRandomNumber(), ReasonPhrase: string(getRandomData(100)), }, &wire.ConnectionCloseFrame{ // application error with empty reason IsApplicationError: true, ErrorCode: getRandomNumber(), ReasonPhrase: "", }, &wire.ConnectionCloseFrame{ // application error with reason IsApplicationError: true, ErrorCode: getRandomNumber(), ReasonPhrase: string(getRandomData(100)), }, } seq1 := getRandomNumber() seq2 := getRandomNumber() var token1, token2 protocol.StatelessResetToken copy(token1[:], getRandomData(16)) copy(token2[:], getRandomData(16)) frames = append(frames, []wire.Frame{ &wire.NewConnectionIDFrame{ SequenceNumber: seq1, RetirePriorTo: seq1 / 2, ConnectionID: protocol.ParseConnectionID(getRandomData(4)), StatelessResetToken: token1, 
}, &wire.NewConnectionIDFrame{ SequenceNumber: seq2, RetirePriorTo: seq2, ConnectionID: protocol.ParseConnectionID(getRandomData(17)), StatelessResetToken: token2, }, }...) var data1 [8]byte copy(data1[:], getRandomData(8)) frames = append(frames, &wire.PathChallengeFrame{ Data: data1, }) var data2 [8]byte copy(data2[:], getRandomData(8)) frames = append(frames, &wire.PathResponseFrame{ Data: data2, }) return frames } func main() { for _, f := range getFrames() { b, err := f.Append(nil, version) if err != nil { log.Fatal(err) } if err := helper.WriteCorpusFileWithPrefix("corpus", b, 1); err != nil { log.Fatal(err) } } for range 30 { frames := getFrames() var b []byte for range mrand.IntN(30) + 2 { if mrand.IntN(10) == 0 { // write a PADDING frame b = append(b, 0) } f := frames[mrand.IntN(len(frames))] var err error b, err = f.Append(b, version) if err != nil { log.Fatal(err) } if mrand.IntN(10) == 0 { // write a PADDING frame b = append(b, 0) } } if err := helper.WriteCorpusFileWithPrefix("corpus", b, 1); err != nil { log.Fatal(err) } } } quic-go-0.59.0/fuzzing/frames/fuzz.go000066400000000000000000000106421513066070600174560ustar00rootroot00000000000000package frames import ( "fmt" "io" "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" ) const version = protocol.Version1 // PrefixLen is the number of bytes used for configuration const PrefixLen = 1 func toEncLevel(v uint8) protocol.EncryptionLevel { switch v % 3 { default: return protocol.EncryptionInitial case 1: return protocol.EncryptionHandshake case 2: return protocol.Encryption1RTT } } // Fuzz fuzzes the QUIC frames. 
// //go:generate go run ./cmd/corpus.go func Fuzz(data []byte) int { if len(data) < PrefixLen { return 0 } encLevel := toEncLevel(data[0]) data = data[PrefixLen:] parser := wire.NewFrameParser(true, true, true) parser.SetAckDelayExponent(protocol.DefaultAckDelayExponent) var numFrames int var b []byte for len(data) > 0 { initialLen := len(data) frameType, l, err := parser.ParseType(data, encLevel) if err != nil { if err == io.EOF { // the last frame was a PADDING frame break } break } data = data[l:] numFrames++ var f wire.Frame switch { case frameType.IsStreamFrameType(): f, l, err = parser.ParseStreamFrame(frameType, data, version) case frameType == wire.FrameTypeAck || frameType == wire.FrameTypeAckECN: f, l, err = parser.ParseAckFrame(frameType, data, encLevel, version) case frameType == wire.FrameTypeDatagramNoLength || frameType == wire.FrameTypeDatagramWithLength: f, l, err = parser.ParseDatagramFrame(frameType, data, version) default: f, l, err = parser.ParseLessCommonFrame(frameType, data, version) } if err != nil { break } data = data[l:] wire.IsProbingFrame(f) ackhandler.IsFrameAckEliciting(f) // We accept empty STREAM frames, but we don't write them. 
if sf, ok := f.(*wire.StreamFrame); ok { if sf.DataLen() == 0 { sf.PutBack() continue } } validateFrame(f) startLen := len(b) parsedLen := initialLen - len(data) b, err = f.Append(b, version) if err != nil { panic(fmt.Sprintf("error writing frame %#v: %s", f, err)) } frameLen := protocol.ByteCount(len(b) - startLen) if f.Length(version) != frameLen { panic(fmt.Sprintf("inconsistent frame length for %#v: expected %d, got %d", f, frameLen, f.Length(version))) } if sf, ok := f.(*wire.StreamFrame); ok { sf.PutBack() } if frameLen > protocol.ByteCount(parsedLen) { panic(fmt.Sprintf("serialized length (%d) is longer than parsed length (%d)", len(b), parsedLen)) } } if numFrames == 0 { return 0 } return 1 } func validateFrame(frame wire.Frame) { switch f := frame.(type) { case *wire.StreamFrame: if protocol.ByteCount(len(f.Data)) != f.DataLen() { panic("STREAM frame: inconsistent data length") } case *wire.AckFrame: if f.DelayTime < 0 { panic(fmt.Sprintf("invalid ACK delay_time: %s", f.DelayTime)) } if f.LargestAcked() < f.LowestAcked() { panic("ACK: largest acknowledged is smaller than lowest acknowledged") } for _, r := range f.AckRanges { if r.Largest < 0 || r.Smallest < 0 { panic("ACK range contains a negative packet number") } } if !f.AcksPacket(f.LargestAcked()) { panic("ACK frame claims that largest acknowledged is not acknowledged") } if !f.AcksPacket(f.LowestAcked()) { panic("ACK frame claims that lowest acknowledged is not acknowledged") } _ = f.AcksPacket(100) _ = f.AcksPacket((f.LargestAcked() + f.LowestAcked()) / 2) case *wire.NewConnectionIDFrame: if f.ConnectionID.Len() < 1 || f.ConnectionID.Len() > 20 { panic(fmt.Sprintf("invalid NEW_CONNECTION_ID frame length: %s", f.ConnectionID)) } case *wire.NewTokenFrame: if len(f.Token) == 0 { panic("NEW_TOKEN frame with an empty token") } case *wire.MaxStreamsFrame: if f.MaxStreamNum > protocol.MaxStreamCount { panic("MAX_STREAMS frame with an invalid Maximum Streams value") } case *wire.StreamsBlockedFrame: if 
f.StreamLimit > protocol.MaxStreamCount { panic("STREAMS_BLOCKED frame with an invalid Maximum Streams value") } case *wire.ConnectionCloseFrame: if f.IsApplicationError && f.FrameType != 0 { panic("CONNECTION_CLOSE for an application error containing a frame type") } case *wire.ResetStreamFrame: if f.FinalSize < f.ReliableSize { panic("RESET_STREAM frame with a FinalSize smaller than the ReliableSize") } case *wire.AckFrequencyFrame: if f.RequestMaxAckDelay < 0 { panic("ACK_FREQUENCY frame with a negative RequestMaxAckDelay") } } } quic-go-0.59.0/fuzzing/handshake/000077500000000000000000000000001513066070600165775ustar00rootroot00000000000000quic-go-0.59.0/fuzzing/handshake/cmd/000077500000000000000000000000001513066070600173425ustar00rootroot00000000000000quic-go-0.59.0/fuzzing/handshake/cmd/corpus.go000066400000000000000000000064101513066070600212050ustar00rootroot00000000000000package main import ( "context" "crypto/tls" "log" "net" fuzzhandshake "github.com/quic-go/quic-go/fuzzing/handshake" "github.com/quic-go/quic-go/fuzzing/internal/helper" "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/testdata" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" ) const alpn = "fuzz" func main() { client := handshake.NewCryptoSetupClient( protocol.ConnectionID{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &tls.Config{ MinVersion: tls.VersionTLS13, ServerName: "localhost", NextProtos: []string{alpn}, RootCAs: testdata.GetRootCA(), ClientSessionCache: tls.NewLRUClientSessionCache(1), }, false, &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("client"), protocol.Version1, ) config := testdata.GetTLSConfig() config.NextProtos = []string{alpn} server := handshake.NewCryptoSetupServer( protocol.ConnectionID{}, &net.UDPAddr{IP: net.IPv6loopback, Port: 1234}, &net.UDPAddr{IP: net.IPv6loopback, Port: 4321}, 
&wire.TransportParameters{ActiveConnectionIDLimit: 2}, config, false, &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), protocol.Version1, ) if err := client.StartHandshake(context.Background()); err != nil { log.Fatal(err) } if err := server.StartHandshake(context.Background()); err != nil { log.Fatal(err) } var clientHandshakeComplete, serverHandshakeComplete bool var messages [][]byte for { clientLoop: for { ev := client.NextEvent() //nolint:exhaustive // only need to process a few events switch ev.Kind { case handshake.EventNoEvent: break clientLoop case handshake.EventWriteInitialData: messages = append(messages, ev.Data) if err := server.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil { log.Fatal(err) } case handshake.EventWriteHandshakeData: messages = append(messages, ev.Data) if err := server.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil { log.Fatal(err) } case handshake.EventHandshakeComplete: clientHandshakeComplete = true } } serverLoop: for { ev := server.NextEvent() //nolint:exhaustive // only need to process a few events switch ev.Kind { case handshake.EventNoEvent: break serverLoop case handshake.EventWriteInitialData: messages = append(messages, ev.Data) if err := client.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil { log.Fatal(err) } case handshake.EventWriteHandshakeData: messages = append(messages, ev.Data) if err := client.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil { log.Fatal(err) } case handshake.EventHandshakeComplete: serverHandshakeComplete = true } } if serverHandshakeComplete && clientHandshakeComplete { break } } ticket, err := server.GetSessionTicket() if err != nil { log.Fatal(err) } if ticket == nil { log.Fatal("expected a session ticket") } messages = append(messages, ticket) for _, m := range messages { if err := helper.WriteCorpusFileWithPrefix("corpus", m, fuzzhandshake.PrefixLen); err != nil { log.Fatal(err) } } } 
quic-go-0.59.0/fuzzing/handshake/fuzz.go000066400000000000000000000270351513066070600201330ustar00rootroot00000000000000package handshake

import (
	"context"
	"crypto/ed25519"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"io"
	"log"
	"math"
	mrand "math/rand/v2"
	"net"

	"github.com/quic-go/quic-go/fuzzing/internal/helper"
	"github.com/quic-go/quic-go/internal/handshake"
	"github.com/quic-go/quic-go/internal/protocol"
	"github.com/quic-go/quic-go/internal/qtls"
	"github.com/quic-go/quic-go/internal/utils"
	"github.com/quic-go/quic-go/internal/wire"
)

var (
	// cert is the certificate presented by the fuzzed server; certPool trusts it.
	cert     *tls.Certificate
	certPool *x509.CertPool
	// clientCert and clientCertPool are generated from a second, independent key.
	clientCert     *tls.Certificate
	clientCertPool *x509.CertPool

	// sessionTicketKey is a fixed key so that tickets are stable across runs.
	sessionTicketKey = [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
)

func init() {
	// generate creates a fresh Ed25519 key and a self-signed certificate for it.
	// Any failure is fatal: the fuzzer cannot run without certificates.
	generate := func() (*tls.Certificate, *x509.CertPool) {
		_, key, err := ed25519.GenerateKey(rand.Reader)
		if err != nil {
			log.Fatal(err)
		}
		c, pool, err := helper.GenerateCertificate(key)
		if err != nil {
			log.Fatal(err)
		}
		return c, pool
	}
	cert, certPool = generate()
	clientCert, clientCertPool = generate()
}

type messageType uint8

// TLS handshake message types.
const (
	typeClientHello         messageType = 1
	typeServerHello         messageType = 2
	typeNewSessionTicket    messageType = 4
	typeEncryptedExtensions messageType = 8
	typeCertificate         messageType = 11
	typeCertificateRequest  messageType = 13
	typeCertificateVerify   messageType = 15
	typeFinished            messageType = 20
)

// messageTypeNames holds the printable name of every known message type.
var messageTypeNames = map[messageType]string{
	typeClientHello:         "ClientHello",
	typeServerHello:         "ServerHello",
	typeNewSessionTicket:    "NewSessionTicket",
	typeEncryptedExtensions: "EncryptedExtensions",
	typeCertificate:         "Certificate",
	typeCertificateRequest:  "CertificateRequest",
	typeCertificateVerify:   "CertificateVerify",
	typeFinished:            "Finished",
}

// String returns the name of a known message type, or a description that
// includes the numeric value for any other type.
func (m messageType) String() string {
	if name, ok := messageTypeNames[m]; ok {
		return name
	}
	return fmt.Sprintf("unknown message type: %d", m)
}

// getClientAuth turns a 3-bit selector into a tls.ClientAuthType.
// Selectors 0-3 choose one of the client-authentication modes; any larger
// value disables client certificates.
func getClientAuth(v uint8) tls.ClientAuthType {
	modes := [...]tls.ClientAuthType{
		tls.RequestClientCert,
		tls.RequireAnyClientCert,
		tls.VerifyClientCertIfGiven,
		tls.RequireAndVerifyClientCert,
	}
	if int(v) >= len(modes) {
		return tls.NoClientCert
	}
	return modes[v]
}

const (
	alpn      = "fuzzing"
	alpnWrong = "wrong"
)

// toEncryptionLevel picks Initial, Handshake or 1-RTT depending on n mod 3.
func toEncryptionLevel(n uint8) protocol.EncryptionLevel {
	levels := [3]protocol.EncryptionLevel{
		protocol.EncryptionInitial,
		protocol.EncryptionHandshake,
		protocol.Encryption1RTT,
	}
	return levels[n%3]
}

// getTransportParameters deterministically derives transport parameters from
// seed: ActiveConnectionIDLimit is 2, and the four initial flow-control limits
// are pseudo-random values reduced modulo maxVarInt (2^62-1).
func getTransportParameters(seed uint8) *wire.TransportParameters {
	const maxVarInt = math.MaxUint64 / 4
	rng := mrand.New(mrand.NewPCG(uint64(seed), uint64(seed)))
	// Composite literal elements are evaluated left to right, so the values are
	// drawn in the same order as the fields appear below.
	next := func() protocol.ByteCount {
		return protocol.ByteCount(rng.Uint64() % maxVarInt)
	}
	return &wire.TransportParameters{
		ActiveConnectionIDLimit:        2,
		InitialMaxData:                 next(),
		InitialMaxStreamDataBidiLocal:  next(),
		InitialMaxStreamDataBidiRemote: next(),
		InitialMaxStreamDataUni:        next(),
	}
}

// PrefixLen is the number of leading input bytes consumed as configuration:
// for each of the two handshake runs, confLen run-configuration bytes plus
// one message-configuration byte.
const (
	PrefixLen = 12
	confLen   = 5
)

// Fuzz fuzzes the TLS 1.3 handshake used by QUIC.
// //go:generate go run ./cmd/corpus.go func Fuzz(data []byte) int { if len(data) < PrefixLen { return -1 } dataLen := len(data) var runConfig1, runConfig2 [confLen]byte copy(runConfig1[:], data) data = data[confLen:] messageConfig1 := data[0] data = data[1:] copy(runConfig2[:], data) data = data[confLen:] messageConfig2 := data[0] data = data[1:] if dataLen != len(data)+PrefixLen { panic("incorrect configuration") } clientConf := &tls.Config{ MinVersion: tls.VersionTLS13, ServerName: "localhost", NextProtos: []string{alpn}, RootCAs: certPool, } useSessionTicketCache := helper.NthBit(runConfig1[0], 2) if useSessionTicketCache { clientConf.ClientSessionCache = tls.NewLRUClientSessionCache(5) } if val := runHandshake(runConfig1, messageConfig1, clientConf, data); val != 1 { return val } return runHandshake(runConfig2, messageConfig2, clientConf, data) } func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.Config, data []byte) int { serverConf := &tls.Config{ MinVersion: tls.VersionTLS13, Certificates: []tls.Certificate{*cert}, NextProtos: []string{alpn}, SessionTicketKey: sessionTicketKey, } // This sets the cipher suite for both client and server. // The way crypto/tls is designed doesn't allow us to set different cipher suites for client and server. 
resetCipherSuite := func() {} switch (runConfig[0] >> 6) % 4 { case 0: resetCipherSuite = qtls.SetCipherSuite(tls.TLS_AES_128_GCM_SHA256) case 1: resetCipherSuite = qtls.SetCipherSuite(tls.TLS_AES_256_GCM_SHA384) case 3: resetCipherSuite = qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256) default: } defer resetCipherSuite() enable0RTTClient := helper.NthBit(runConfig[0], 0) enable0RTTServer := helper.NthBit(runConfig[0], 1) sendPostHandshakeMessageToClient := helper.NthBit(runConfig[0], 3) sendPostHandshakeMessageToServer := helper.NthBit(runConfig[0], 4) sendSessionTicket := helper.NthBit(runConfig[0], 5) serverConf.ClientAuth = getClientAuth(runConfig[1] & 0b00000111) serverConf.SessionTicketsDisabled = helper.NthBit(runConfig[1], 3) if helper.NthBit(runConfig[2], 0) { clientConf.RootCAs = x509.NewCertPool() } if helper.NthBit(runConfig[2], 1) { serverConf.ClientCAs = clientCertPool } else { serverConf.ClientCAs = x509.NewCertPool() } if helper.NthBit(runConfig[2], 2) { serverConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { if helper.NthBit(runConfig[2], 3) { return nil, errors.New("getting client config failed") } if helper.NthBit(runConfig[2], 4) { return nil, nil } return serverConf, nil } } if helper.NthBit(runConfig[2], 5) { serverConf.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) { if helper.NthBit(runConfig[2], 6) { return nil, errors.New("getting certificate failed") } if helper.NthBit(runConfig[2], 7) { return nil, nil } return clientCert, nil // this certificate will be invalid } } if helper.NthBit(runConfig[3], 0) { serverConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { if helper.NthBit(runConfig[3], 1) { return errors.New("certificate verification failed") } return nil } } if helper.NthBit(runConfig[3], 2) { clientConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { if helper.NthBit(runConfig[3], 3) { 
return errors.New("certificate verification failed") } return nil } } if helper.NthBit(runConfig[3], 4) { serverConf.NextProtos = []string{alpnWrong} } if helper.NthBit(runConfig[3], 5) { serverConf.NextProtos = []string{alpnWrong, alpn} } if helper.NthBit(runConfig[3], 6) { serverConf.KeyLogWriter = io.Discard } if helper.NthBit(runConfig[3], 7) { clientConf.KeyLogWriter = io.Discard } clientTP := getTransportParameters(runConfig[4] & 0x3) if helper.NthBit(runConfig[4], 3) { clientTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5 } serverTP := getTransportParameters(runConfig[4] & 0b00011000) if helper.NthBit(runConfig[4], 3) { serverTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5 } messageToReplace := messageConfig % 32 messageToReplaceEncLevel := toEncryptionLevel(messageConfig >> 6) if len(data) == 0 { return -1 } client := handshake.NewCryptoSetupClient( protocol.ConnectionID{}, clientTP, clientConf, enable0RTTClient, &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("client"), protocol.Version1, ) if err := client.StartHandshake(context.Background()); err != nil { log.Fatal(err) } defer client.Close() server := handshake.NewCryptoSetupServer( protocol.ConnectionID{}, &net.UDPAddr{IP: net.IPv6loopback, Port: 1234}, &net.UDPAddr{IP: net.IPv6loopback, Port: 4321}, serverTP, serverConf, enable0RTTServer, &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), protocol.Version1, ) if err := server.StartHandshake(context.Background()); err != nil { log.Fatal(err) } defer server.Close() var clientHandshakeComplete, serverHandshakeComplete bool for { var processedEvent bool clientLoop: for { ev := client.NextEvent() //nolint:exhaustive // only need to process a few events switch ev.Kind { case handshake.EventNoEvent: if !processedEvent && !clientHandshakeComplete { // handshake stuck return 1 } break clientLoop case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData: msg := ev.Data encLevel := protocol.EncryptionInitial if ev.Kind == 
handshake.EventWriteHandshakeData { encLevel = protocol.EncryptionHandshake } if msg[0] == messageToReplace { fmt.Printf("replacing %s message to the server with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel) msg = data encLevel = messageToReplaceEncLevel } if err := server.HandleMessage(msg, encLevel); err != nil { return 1 } case handshake.EventHandshakeComplete: clientHandshakeComplete = true } processedEvent = true } processedEvent = false serverLoop: for { ev := server.NextEvent() //nolint:exhaustive // only need to process a few events switch ev.Kind { case handshake.EventNoEvent: if !processedEvent && !serverHandshakeComplete { // handshake stuck return 1 } break serverLoop case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData: encLevel := protocol.EncryptionInitial if ev.Kind == handshake.EventWriteHandshakeData { encLevel = protocol.EncryptionHandshake } msg := ev.Data if msg[0] == messageToReplace { fmt.Printf("replacing %s message to the client with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel) msg = data encLevel = messageToReplaceEncLevel } if err := client.HandleMessage(msg, encLevel); err != nil { return 1 } case handshake.EventHandshakeComplete: serverHandshakeComplete = true } processedEvent = true } if serverHandshakeComplete && clientHandshakeComplete { break } } _ = client.ConnectionState() _ = server.ConnectionState() sealer, err := client.Get1RTTSealer() if err != nil { panic("expected to get a 1-RTT sealer") } opener, err := server.Get1RTTOpener() if err != nil { panic("expected to get a 1-RTT opener") } const msg = "Lorem ipsum dolor sit amet, consectetur adipiscing elit." 
encrypted := sealer.Seal(nil, []byte(msg), 1337, []byte("foobar")) decrypted, err := opener.Open(nil, encrypted, 0, 1337, protocol.KeyPhaseZero, []byte("foobar")) if err != nil { panic(fmt.Sprintf("Decrypting message failed: %s", err.Error())) } if string(decrypted) != msg { panic("wrong message") } if sendSessionTicket && !serverConf.SessionTicketsDisabled { ticket, err := server.GetSessionTicket() if err != nil { panic(err) } if ticket == nil { panic("empty ticket") } client.HandleMessage(ticket, protocol.Encryption1RTT) } if sendPostHandshakeMessageToClient { fmt.Println("sending post handshake message to the client at", messageToReplaceEncLevel) client.HandleMessage(data, messageToReplaceEncLevel) } if sendPostHandshakeMessageToServer { fmt.Println("sending post handshake message to the server at", messageToReplaceEncLevel) server.HandleMessage(data, messageToReplaceEncLevel) } return 1 } quic-go-0.59.0/fuzzing/header/000077500000000000000000000000001513066070600161015ustar00rootroot00000000000000quic-go-0.59.0/fuzzing/header/cmd/000077500000000000000000000000001513066070600166445ustar00rootroot00000000000000quic-go-0.59.0/fuzzing/header/cmd/corpus.go000066400000000000000000000105561513066070600205150ustar00rootroot00000000000000package main import ( "crypto/rand" "log" mrand "math/rand/v2" "github.com/quic-go/quic-go/fuzzing/header" "github.com/quic-go/quic-go/fuzzing/internal/helper" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" ) const version = protocol.Version1 func getRandomData(l int) []byte { b := make([]byte, l) rand.Read(b) return b } func getVNP(src, dest protocol.ArbitraryLenConnectionID, numVersions int) []byte { versions := make([]protocol.Version, numVersions) for i := 0; i < numVersions; i++ { versions[i] = protocol.Version(mrand.Uint32()) } return wire.ComposeVersionNegotiation(src, dest, versions) } func main() { headers := []wire.Header{ { // Initial without token SrcConnectionID: 
protocol.ParseConnectionID(getRandomData(3)), DestConnectionID: protocol.ParseConnectionID(getRandomData(8)), Type: protocol.PacketTypeInitial, Length: protocol.ByteCount(mrand.IntN(1000)), Version: version, }, { // Initial without token, with zero-length src conn id DestConnectionID: protocol.ParseConnectionID(getRandomData(8)), Type: protocol.PacketTypeInitial, Length: protocol.ByteCount(mrand.IntN(1000)), Version: version, }, { // Initial with Token SrcConnectionID: protocol.ParseConnectionID(getRandomData(10)), DestConnectionID: protocol.ParseConnectionID(getRandomData(19)), Type: protocol.PacketTypeInitial, Length: protocol.ByteCount(mrand.IntN(1000)), Version: version, Token: getRandomData(25), }, { // Handshake packet SrcConnectionID: protocol.ParseConnectionID(getRandomData(5)), DestConnectionID: protocol.ParseConnectionID(getRandomData(10)), Type: protocol.PacketTypeHandshake, Length: protocol.ByteCount(mrand.IntN(1000)), Version: version, }, { // Handshake packet, with zero-length src conn id DestConnectionID: protocol.ParseConnectionID(getRandomData(12)), Type: protocol.PacketTypeHandshake, Length: protocol.ByteCount(mrand.IntN(1000)), Version: version, }, { // 0-RTT packet SrcConnectionID: protocol.ParseConnectionID(getRandomData(8)), DestConnectionID: protocol.ParseConnectionID(getRandomData(9)), Type: protocol.PacketType0RTT, Length: protocol.ByteCount(mrand.IntN(1000)), Version: version, }, { // Retry Packet, with empty orig dest conn id SrcConnectionID: protocol.ParseConnectionID(getRandomData(8)), DestConnectionID: protocol.ParseConnectionID(getRandomData(9)), Type: protocol.PacketTypeRetry, Token: getRandomData(1000), Version: version, }, } for _, h := range headers { extHdr := &wire.ExtendedHeader{ Header: h, PacketNumberLen: protocol.PacketNumberLen(mrand.IntN(4) + 1), PacketNumber: protocol.PacketNumber(mrand.Uint64()), } b, err := extHdr.Append(nil, version) if err != nil { log.Fatal(err) } if h.Type == protocol.PacketTypeRetry { b = append(b, 
[]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}...) } if h.Length > 0 { b = append(b, make([]byte, h.Length)...) } if err := helper.WriteCorpusFileWithPrefix("corpus", b, header.PrefixLen); err != nil { log.Fatal(err) } } // short header b, err := wire.AppendShortHeader(nil, protocol.ParseConnectionID(getRandomData(8)), 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne) if err != nil { log.Fatal(err) } if err := helper.WriteCorpusFileWithPrefix("corpus", b, header.PrefixLen); err != nil { log.Fatal(err) } vnps := [][]byte{ getVNP( protocol.ArbitraryLenConnectionID(getRandomData(8)), protocol.ArbitraryLenConnectionID(getRandomData(10)), 4, ), getVNP( protocol.ArbitraryLenConnectionID(getRandomData(10)), protocol.ArbitraryLenConnectionID(getRandomData(5)), 0, ), getVNP( protocol.ArbitraryLenConnectionID(getRandomData(3)), protocol.ArbitraryLenConnectionID(getRandomData(19)), 100, ), getVNP( protocol.ArbitraryLenConnectionID(getRandomData(3)), nil, 20, ), getVNP( nil, protocol.ArbitraryLenConnectionID(getRandomData(10)), 5, ), } for _, vnp := range vnps { if err := helper.WriteCorpusFileWithPrefix("corpus", vnp, header.PrefixLen); err != nil { log.Fatal(err) } } } quic-go-0.59.0/fuzzing/header/fuzz.go000066400000000000000000000050561513066070600174340ustar00rootroot00000000000000package header import ( "bytes" "fmt" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" ) const version = protocol.Version1 // PrefixLen is the number of bytes used for configuration const PrefixLen = 1 // Fuzz fuzzes the QUIC header. 
// //go:generate go run ./cmd/corpus.go func Fuzz(data []byte) int { if len(data) < PrefixLen { return 0 } connIDLen := int(data[0] % 21) data = data[PrefixLen:] if wire.IsVersionNegotiationPacket(data) { return fuzzVNP(data) } connID, err := wire.ParseConnectionID(data, connIDLen) if err != nil { return 0 } if !wire.IsLongHeaderPacket(data[0]) { wire.ParseShortHeader(data, connIDLen) return 1 } is0RTTPacket := wire.Is0RTTPacket(data) hdr, _, _, err := wire.ParsePacket(data) if err != nil { return 0 } if hdr.DestConnectionID != connID { panic(fmt.Sprintf("Expected connection IDs to match: %s vs %s", hdr.DestConnectionID, connID)) } if (hdr.Type == protocol.PacketType0RTT) != is0RTTPacket { panic("inconsistent 0-RTT packet detection") } var extHdr *wire.ExtendedHeader // Parse the extended header, if this is not a Retry packet. if hdr.Type == protocol.PacketTypeRetry { extHdr = &wire.ExtendedHeader{Header: *hdr} } else { var err error extHdr, err = hdr.ParseExtended(data) if err != nil { return 0 } } // We always use a 2-byte encoding for the Length field in Long Header packets. // Serializing the header will fail when using a higher value. if hdr.Length > 16383 { return 1 } b, err := extHdr.Append(nil, version) if err != nil { // We are able to parse packets with connection IDs longer than 20 bytes, // but in QUIC version 1, we don't write headers with longer connection IDs. if hdr.DestConnectionID.Len() <= protocol.MaxConnIDLen && hdr.SrcConnectionID.Len() <= protocol.MaxConnIDLen { panic(err) } return 0 } // GetLength is not implemented for Retry packets if hdr.Type != protocol.PacketTypeRetry { if expLen := extHdr.GetLength(version); expLen != protocol.ByteCount(len(b)) { panic(fmt.Sprintf("inconsistent header length: %#v. 
Expected %d, got %d", extHdr, expLen, len(b))) } } return 1 } func fuzzVNP(data []byte) int { connID, err := wire.ParseConnectionID(data, 0) if err != nil { return 0 } dest, src, versions, err := wire.ParseVersionNegotiationPacket(data) if err != nil { return 0 } if !bytes.Equal(dest, connID.Bytes()) { panic("connection IDs don't match") } if len(versions) == 0 { panic("no versions") } wire.ComposeVersionNegotiation(src, dest, versions) return 1 } quic-go-0.59.0/fuzzing/internal/000077500000000000000000000000001513066070600164655ustar00rootroot00000000000000quic-go-0.59.0/fuzzing/internal/helper/000077500000000000000000000000001513066070600177445ustar00rootroot00000000000000quic-go-0.59.0/fuzzing/internal/helper/helper.go000066400000000000000000000043631513066070600215600ustar00rootroot00000000000000package helper import ( "crypto" "crypto/rand" "crypto/sha1" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/hex" "math/big" "os" "path/filepath" "time" ) // NthBit gets the n-th bit of a byte (counting starts at 0). func NthBit(val uint8, n int) bool { if n < 0 || n > 7 { panic("invalid value for n") } return val>>n&0x1 == 1 } // WriteCorpusFile writes data to a corpus file in directory path. // The filename is calculated from the SHA1 sum of the file contents. func WriteCorpusFile(path string, data []byte) error { // create the directory, if it doesn't exist yet if _, err := os.Stat(path); os.IsNotExist(err) { if err := os.MkdirAll(path, os.ModePerm); err != nil { return err } } hash := sha1.Sum(data) return os.WriteFile(filepath.Join(path, hex.EncodeToString(hash[:])), data, 0o644) } // WriteCorpusFileWithPrefix writes data to a corpus file in directory path. // In many fuzzers, the first n bytes are used to control. // This function prepends n zero-bytes to the data. 
func WriteCorpusFileWithPrefix(path string, data []byte, n int) error { return WriteCorpusFile(path, append(make([]byte, n), data...)) } // GenerateCertificate generates a self-signed certificate. // It returns the certificate and a x509.CertPool containing that certificate. func GenerateCertificate(priv crypto.Signer) (*tls.Certificate, *x509.CertPool, error) { template := x509.Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{Organization: []string{"quic-go fuzzer"}}, NotBefore: time.Now().Add(-24 * time.Hour), NotAfter: time.Now().Add(30 * 24 * time.Hour), KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, DNSNames: []string{"localhost"}, BasicConstraintsValid: true, } derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, priv.Public(), priv) if err != nil { return nil, nil, err } cert, err := x509.ParseCertificate(derBytes) if err != nil { return nil, nil, err } certPool := x509.NewCertPool() certPool.AddCert(cert) return &tls.Certificate{ Certificate: [][]byte{derBytes}, PrivateKey: priv, }, certPool, nil } quic-go-0.59.0/fuzzing/internal/helper/helper_test.go000066400000000000000000000030261513066070600226120ustar00rootroot00000000000000package helper import ( "os" "path/filepath" "testing" "github.com/stretchr/testify/require" ) func TestWriteCorpusFile(t *testing.T) { const data = "lorem ipsum" const expectedShaSum = "bfb7759a67daeb65410490b4d98bb9da7d1ea2ce" dir := t.TempDir() require.NoError(t, WriteCorpusFile(dir, []byte(data))) path := filepath.Join(dir, expectedShaSum) require.FileExists(t, path) b, err := os.ReadFile(path) require.NoError(t, err) require.Equal(t, data, string(b)) } func TestWriteCorpusFileWithPrefix(t *testing.T) { const data = "lorem ipsum" const expectedShaSum = "523f5cab80fab0c7889dbf50dd310ab8c8879f9c" const prefixLen = 7 dir := t.TempDir() require.NoError(t, WriteCorpusFileWithPrefix(dir, []byte(data), prefixLen)) 
path := filepath.Join(dir, expectedShaSum) require.FileExists(t, path) b, err := os.ReadFile(path) require.NoError(t, err) require.Equal(t, make([]byte, prefixLen), b[:prefixLen]) require.Equal(t, data, string(b[prefixLen:])) } func TestCreateDirectoryIfNotExists(t *testing.T) { dir := t.TempDir() subdir := filepath.Join(dir, "corpus") require.NoDirExists(t, subdir) require.NoError(t, WriteCorpusFile(subdir, []byte("lorem ipsum"))) require.DirExists(t, subdir) } func TestNthBit(t *testing.T) { const val = 0b10010001 require.True(t, NthBit(val, 0)) require.False(t, NthBit(val, 1)) require.False(t, NthBit(val, 2)) require.False(t, NthBit(val, 3)) require.True(t, NthBit(val, 4)) require.False(t, NthBit(val, 5)) require.False(t, NthBit(val, 6)) require.True(t, NthBit(val, 7)) } quic-go-0.59.0/fuzzing/tokens/000077500000000000000000000000001513066070600161545ustar00rootroot00000000000000quic-go-0.59.0/fuzzing/tokens/fuzz.go000066400000000000000000000060041513066070600175010ustar00rootroot00000000000000package tokens import ( "encoding/binary" "net" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/protocol" ) func Fuzz(data []byte) int { if len(data) < 32 { return -1 } var key quic.TokenGeneratorKey copy(key[:], data[:32]) data = data[32:] tg := handshake.NewTokenGenerator(key) if len(data) < 1 { return -1 } s := data[0] % 3 data = data[1:] switch s { case 0: tg.DecodeToken(data) return 1 case 1: return newToken(tg, data) case 2: return newRetryToken(tg, data) } return -1 } func newToken(tg *handshake.TokenGenerator, data []byte) int { if len(data) < 1 { return -1 } usesUDPAddr := data[0]%2 == 0 data = data[1:] if len(data) < 18 { return -1 } var addr net.Addr if usesUDPAddr { addr = &net.UDPAddr{ Port: int(binary.BigEndian.Uint16(data[:2])), IP: net.IP(data[2:18]), } } else { addr = &net.TCPAddr{ Port: int(binary.BigEndian.Uint16(data[:2])), IP: net.IP(data[2:18]), } } data = data[18:] if len(data) 
< 1 { return -1 } start := time.Now() encrypted, err := tg.NewToken(addr, time.Duration(data[0])*time.Millisecond) if err != nil { panic(err) } token, err := tg.DecodeToken(encrypted) if err != nil { panic(err) } if token.IsRetryToken { panic("didn't encode a Retry token") } if token.SentTime.Before(start) || token.SentTime.After(time.Now()) { panic("incorrect send time") } if token.OriginalDestConnectionID.Len() > 0 || token.RetrySrcConnectionID.Len() > 0 { panic("didn't expect connection IDs") } return 1 } func newRetryToken(tg *handshake.TokenGenerator, data []byte) int { if len(data) < 2 { return -1 } origDestConnIDLen := int(data[0] % 21) retrySrcConnIDLen := int(data[1] % 21) data = data[2:] if len(data) < origDestConnIDLen { return -1 } origDestConnID := protocol.ParseConnectionID(data[:origDestConnIDLen]) data = data[origDestConnIDLen:] if len(data) < retrySrcConnIDLen { return -1 } retrySrcConnID := protocol.ParseConnectionID(data[:retrySrcConnIDLen]) data = data[retrySrcConnIDLen:] if len(data) < 1 { return -1 } usesUDPAddr := data[0]%2 == 0 data = data[1:] if len(data) != 18 { return -1 } start := time.Now() var addr net.Addr if usesUDPAddr { addr = &net.UDPAddr{ Port: int(binary.BigEndian.Uint16(data[:2])), IP: net.IP(data[2:]), } } else { addr = &net.TCPAddr{ Port: int(binary.BigEndian.Uint16(data[:2])), IP: net.IP(data[2:]), } } encrypted, err := tg.NewRetryToken(addr, origDestConnID, retrySrcConnID) if err != nil { panic(err) } token, err := tg.DecodeToken(encrypted) if err != nil { panic(err) } if !token.IsRetryToken { panic("expected a Retry token") } if token.SentTime.Before(start) || token.SentTime.After(time.Now()) { panic("incorrect send time") } if token.OriginalDestConnectionID != origDestConnID { panic("orig dest conn ID doesn't match") } if token.RetrySrcConnectionID != retrySrcConnID { panic("retry src conn ID doesn't match") } return 1 } 
quic-go-0.59.0/fuzzing/transportparameters/000077500000000000000000000000001513066070600207715ustar00rootroot00000000000000quic-go-0.59.0/fuzzing/transportparameters/cmd/000077500000000000000000000000001513066070600215345ustar00rootroot00000000000000quic-go-0.59.0/fuzzing/transportparameters/cmd/corpus.go000066400000000000000000000055061513066070600234040ustar00rootroot00000000000000package main import ( "crypto/rand" "log" "math" mrand "math/rand/v2" "net/netip" "time" "github.com/quic-go/quic-go/fuzzing/internal/helper" "github.com/quic-go/quic-go/fuzzing/transportparameters" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" ) func getRandomData(l int) []byte { b := make([]byte, l) rand.Read(b) return b } func getRandomValue() uint64 { maxVals := []uint64{math.MaxUint8 / 4, math.MaxUint16 / 4, math.MaxUint32 / 4, math.MaxUint64 / 4} return mrand.Uint64N(maxVals[mrand.IntN(4)]) } func main() { for i := 0; i < 30; i++ { tp := &wire.TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.ByteCount(getRandomValue()), InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()), InitialMaxStreamDataUni: protocol.ByteCount(getRandomValue()), InitialMaxData: protocol.ByteCount(getRandomValue()), MaxAckDelay: time.Duration(getRandomValue()), AckDelayExponent: uint8(getRandomValue()), DisableActiveMigration: getRandomValue()%2 == 0, MaxUDPPayloadSize: protocol.ByteCount(getRandomValue()), MaxUniStreamNum: protocol.StreamNum(getRandomValue()), MaxBidiStreamNum: protocol.StreamNum(getRandomValue()), MaxIdleTimeout: time.Duration(getRandomValue()), ActiveConnectionIDLimit: getRandomValue() + 2, } if mrand.IntN(2) == 0 { tp.OriginalDestinationConnectionID = protocol.ParseConnectionID(getRandomData(mrand.IntN(21))) } if mrand.IntN(2) == 0 { tp.InitialSourceConnectionID = protocol.ParseConnectionID(getRandomData(mrand.IntN(21))) } if mrand.IntN(2) == 0 { connID := protocol.ParseConnectionID(getRandomData(mrand.IntN(21))) 
tp.RetrySourceConnectionID = &connID } if mrand.IntN(2) == 0 { var token protocol.StatelessResetToken rand.Read(token[:]) tp.StatelessResetToken = &token } if mrand.IntN(2) == 0 { var token protocol.StatelessResetToken rand.Read(token[:]) var ip4 [4]byte rand.Read(ip4[:]) var ip6 [16]byte rand.Read(ip6[:]) tp.PreferredAddress = &wire.PreferredAddress{ IPv4: netip.AddrPortFrom(netip.AddrFrom4(ip4), uint16(mrand.Int())), IPv6: netip.AddrPortFrom(netip.AddrFrom16(ip6), uint16(mrand.Int())), ConnectionID: protocol.ParseConnectionID(getRandomData(mrand.IntN(21))), StatelessResetToken: token, } } var data []byte if mrand.Int()%2 == 0 { pers := protocol.PerspectiveServer if mrand.Int()%2 == 0 { pers = protocol.PerspectiveClient } data = tp.Marshal(pers) } else { data = tp.MarshalForSessionTicket(nil) } if err := helper.WriteCorpusFileWithPrefix("corpus", data, transportparameters.PrefixLen); err != nil { log.Fatal(err) } } } quic-go-0.59.0/fuzzing/transportparameters/fuzz.go000066400000000000000000000063101513066070600223160ustar00rootroot00000000000000package transportparameters import ( "errors" "fmt" "github.com/quic-go/quic-go/fuzzing/internal/helper" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" ) // PrefixLen is the number of bytes used for configuration const PrefixLen = 1 // Fuzz fuzzes the QUIC transport parameters. 
// //go:generate go run ./cmd/corpus.go func Fuzz(data []byte) int { if len(data) <= PrefixLen { return 0 } if helper.NthBit(data[0], 0) { return fuzzTransportParametersForSessionTicket(data[PrefixLen:]) } return fuzzTransportParameters(data[PrefixLen:], helper.NthBit(data[0], 1)) } func fuzzTransportParameters(data []byte, sentByServer bool) int { sentBy := protocol.PerspectiveClient if sentByServer { sentBy = protocol.PerspectiveServer } tp := &wire.TransportParameters{} if err := tp.Unmarshal(data, sentBy); err != nil { return 0 } _ = tp.String() if err := validateTransportParameters(tp, sentBy); err != nil { panic(err) } tp2 := &wire.TransportParameters{} if err := tp2.Unmarshal(tp.Marshal(sentBy), sentBy); err != nil { fmt.Printf("%#v\n", tp) panic(err) } if err := validateTransportParameters(tp2, sentBy); err != nil { panic(err) } return 1 } func fuzzTransportParametersForSessionTicket(data []byte) int { tp := &wire.TransportParameters{} if err := tp.UnmarshalFromSessionTicket(data); err != nil { return 0 } b := tp.MarshalForSessionTicket(nil) tp2 := &wire.TransportParameters{} if err := tp2.UnmarshalFromSessionTicket(b); err != nil { panic(err) } return 1 } func validateTransportParameters(tp *wire.TransportParameters, sentBy protocol.Perspective) error { if sentBy == protocol.PerspectiveClient && tp.StatelessResetToken != nil { return errors.New("client's transport parameters contained stateless reset token") } if tp.MaxIdleTimeout < 0 { return fmt.Errorf("negative max_idle_timeout: %s", tp.MaxIdleTimeout) } if tp.AckDelayExponent > 20 { return fmt.Errorf("invalid ack_delay_exponent: %d", tp.AckDelayExponent) } if tp.MaxUDPPayloadSize < 1200 { return fmt.Errorf("invalid max_udp_payload_size: %d", tp.MaxUDPPayloadSize) } if tp.ActiveConnectionIDLimit < 2 { return fmt.Errorf("invalid active_connection_id_limit: %d", tp.ActiveConnectionIDLimit) } if tp.OriginalDestinationConnectionID.Len() > 20 { return fmt.Errorf("invalid original_destination_connection_id 
length: %s", tp.InitialSourceConnectionID) } if tp.InitialSourceConnectionID.Len() > 20 { return fmt.Errorf("invalid initial_source_connection_id length: %s", tp.InitialSourceConnectionID) } if tp.RetrySourceConnectionID != nil && tp.RetrySourceConnectionID.Len() > 20 { return fmt.Errorf("invalid retry_source_connection_id length: %s", tp.RetrySourceConnectionID) } if tp.PreferredAddress != nil && tp.PreferredAddress.ConnectionID.Len() > 20 { return fmt.Errorf("invalid preferred_address connection ID length: %s", tp.PreferredAddress.ConnectionID) } if tp.MinAckDelay != nil { if *tp.MinAckDelay < 0 { return fmt.Errorf("negative min_ack_delay: %s", *tp.MinAckDelay) } if *tp.MinAckDelay > tp.MaxAckDelay { return fmt.Errorf("min_ack_delay (%s) is greater than max_ack_delay (%s)", *tp.MinAckDelay, tp.MaxAckDelay) } } return nil } quic-go-0.59.0/go.mod000066400000000000000000000015151513066070600142650ustar00rootroot00000000000000module github.com/quic-go/quic-go go 1.24 require ( github.com/quic-go/qpack v0.6.0 github.com/stretchr/testify v1.11.1 go.uber.org/mock v0.5.2 golang.org/x/crypto v0.41.0 golang.org/x/net v0.43.0 golang.org/x/sync v0.16.0 golang.org/x/sys v0.35.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/jordanlewis/gcassert v0.0.0-20250430164644-389ef753e22e // indirect github.com/kr/pretty v0.3.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.10.0 // indirect golang.org/x/mod v0.27.0 // indirect golang.org/x/text v0.28.0 // indirect golang.org/x/tools v0.36.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) tool ( github.com/jordanlewis/gcassert/cmd/gcassert go.uber.org/mock/mockgen ) quic-go-0.59.0/go.sum000066400000000000000000000210331513066070600143070ustar00rootroot00000000000000github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/jordanlewis/gcassert v0.0.0-20250430164644-389ef753e22e h1:a+PGEeXb+exwBS3NboqXHyxarD9kaboBbrSp+7GuBuc= github.com/jordanlewis/gcassert v0.0.0-20250430164644-389ef753e22e/go.mod h1:ZybsQk6DWyN5t7An1MuPm1gtSZ1xDaTXS9ZjIOxvQrk= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko= go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ= golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod 
h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term 
v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= golang.org/x/xerrors 
v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= quic-go-0.59.0/http3/000077500000000000000000000000001513066070600142175ustar00rootroot00000000000000quic-go-0.59.0/http3/README.md000066400000000000000000000012421513066070600154750ustar00rootroot00000000000000# HTTP/3 [![Documentation](https://img.shields.io/badge/docs-quic--go.net-red?style=flat)](https://quic-go.net/docs/) [![PkgGoDev](https://pkg.go.dev/badge/github.com/quic-go/quic-go/http3)](https://pkg.go.dev/github.com/quic-go/quic-go/http3) This package implements HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)) and HTTP Datagrams ([RFC 9297](https://datatracker.ietf.org/doc/html/rfc9297)). It aims to provide feature parity with the standard library's HTTP/1.1 and HTTP/2 implementation. Detailed documentation can be found on [quic-go.net](https://quic-go.net/docs/). quic-go-0.59.0/http3/body.go000066400000000000000000000066471513066070600155200ustar00rootroot00000000000000package http3 import ( "context" "errors" "io" "sync" "github.com/quic-go/quic-go" ) // Settingser allows waiting for and retrieving the peer's HTTP/3 settings. type Settingser interface { // ReceivedSettings returns a channel that is closed once the peer's SETTINGS frame was received. 
// Settings can be obtained from the Settings method after the channel was closed. ReceivedSettings() <-chan struct{} // Settings returns the settings received on this connection. // It is only valid to call this function after the channel returned by ReceivedSettings was closed. Settings() *Settings } var errTooMuchData = errors.New("peer sent too much data") // The body is used in the requestBody (for a http.Request) and the responseBody (for a http.Response). type body struct { str *Stream remainingContentLength int64 violatedContentLength bool hasContentLength bool } func newBody(str *Stream, contentLength int64) *body { b := &body{str: str} if contentLength >= 0 { b.hasContentLength = true b.remainingContentLength = contentLength } return b } func (r *body) StreamID() quic.StreamID { return r.str.StreamID() } func (r *body) checkContentLengthViolation() error { if !r.hasContentLength { return nil } if r.remainingContentLength < 0 || r.remainingContentLength == 0 && r.str.hasMoreData() { if !r.violatedContentLength { r.str.CancelRead(quic.StreamErrorCode(ErrCodeMessageError)) r.str.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) r.violatedContentLength = true } return errTooMuchData } return nil } func (r *body) Read(b []byte) (int, error) { if err := r.checkContentLengthViolation(); err != nil { return 0, err } if r.hasContentLength { b = b[:min(int64(len(b)), r.remainingContentLength)] } n, err := r.str.Read(b) r.remainingContentLength -= int64(n) if err := r.checkContentLengthViolation(); err != nil { return n, err } return n, maybeReplaceError(err) } func (r *body) Close() error { r.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)) return nil } type requestBody struct { body connCtx context.Context rcvdSettings <-chan struct{} getSettings func() *Settings } var _ io.ReadCloser = &requestBody{} func newRequestBody(str *Stream, contentLength int64, connCtx context.Context, rcvdSettings <-chan struct{}, getSettings func() *Settings) 
*requestBody { return &requestBody{ body: *newBody(str, contentLength), connCtx: connCtx, rcvdSettings: rcvdSettings, getSettings: getSettings, } } type hijackableBody struct { body body // only set for the http.Response // The channel is closed when the user is done with this response: // either when Read() errors, or when Close() is called. reqDone chan<- struct{} reqDoneOnce sync.Once } var _ io.ReadCloser = &hijackableBody{} func newResponseBody(str *Stream, contentLength int64, done chan<- struct{}) *hijackableBody { return &hijackableBody{ body: *newBody(str, contentLength), reqDone: done, } } func (r *hijackableBody) Read(b []byte) (int, error) { n, err := r.body.Read(b) if err != nil { r.requestDone() } return n, maybeReplaceError(err) } func (r *hijackableBody) requestDone() { if r.reqDone != nil { r.reqDoneOnce.Do(func() { close(r.reqDone) }) } } func (r *hijackableBody) Close() error { r.requestDone() // If the EOF was read, CancelRead() is a no-op. r.body.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)) return nil } quic-go-0.59.0/http3/body_test.go000066400000000000000000000075031513066070600165470ustar00rootroot00000000000000package http3 import ( "bytes" "io" "testing" "time" "github.com/quic-go/quic-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) func TestResponseBodyReading(t *testing.T) { mockCtrl := gomock.NewController(t) var buf bytes.Buffer buf.Write(getDataFrame([]byte("foobar"))) str := NewMockDatagramStream(mockCtrl) str.EXPECT().StreamID().Return(quic.StreamID(42)).AnyTimes() str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() reqDone := make(chan struct{}) rb := newResponseBody( newStream(str, nil, nil, func(io.Reader, *headersFrame) error { return nil }, nil), -1, reqDone, ) data, err := io.ReadAll(rb) require.NoError(t, err) require.Equal(t, []byte("foobar"), data) } func TestResponseBodyReadError(t *testing.T) { mockCtrl := gomock.NewController(t) str 
:= NewMockDatagramStream(mockCtrl) str.EXPECT().StreamID().Return(quic.StreamID(42)).AnyTimes() str.EXPECT().Read(gomock.Any()).Return(0, assert.AnError).Times(2) reqDone := make(chan struct{}) rb := newResponseBody( newStream(str, nil, nil, func(io.Reader, *headersFrame) error { return nil }, nil), -1, reqDone, ) _, err := rb.Read([]byte{0}) require.ErrorIs(t, err, assert.AnError) // repeated calls to Read should return the same error _, err = rb.Read([]byte{0}) require.ErrorIs(t, err, assert.AnError) select { case <-reqDone: default: t.Fatal("reqDone should be closed") } } func TestResponseBodyClose(t *testing.T) { mockCtrl := gomock.NewController(t) str := NewMockDatagramStream(mockCtrl) str.EXPECT().StreamID().Return(quic.StreamID(42)).AnyTimes() str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Times(2) reqDone := make(chan struct{}) rb := newResponseBody( newStream(str, nil, nil, func(io.Reader, *headersFrame) error { return nil }, nil), -1, reqDone, ) require.NoError(t, rb.Close()) select { case <-reqDone: default: t.Fatal("reqDone should be closed") } // multiple calls to Close should be a no-op require.NoError(t, rb.Close()) } func TestResponseBodyConcurrentClose(t *testing.T) { mockCtrl := gomock.NewController(t) str := NewMockDatagramStream(mockCtrl) str.EXPECT().StreamID().Return(quic.StreamID(42)).AnyTimes() str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).MaxTimes(3) reqDone := make(chan struct{}) rb := newResponseBody( newStream(str, nil, nil, func(io.Reader, *headersFrame) error { return nil }, nil), -1, reqDone, ) for range 3 { go rb.Close() } select { case <-reqDone: case <-time.After(time.Second): t.Fatal("reqDone should be closed") } } func TestResponseBodyLengthLimiting(t *testing.T) { t.Run("along frame boundary", func(t *testing.T) { testResponseBodyLengthLimiting(t, true) }) t.Run("in the middle of a frame", func(t *testing.T) { testResponseBodyLengthLimiting(t, false) }) } func 
testResponseBodyLengthLimiting(t *testing.T, alongFrameBoundary bool) { var buf bytes.Buffer buf.Write(getDataFrame([]byte("foo"))) buf.Write(getDataFrame([]byte("bar"))) l := int64(4) if alongFrameBoundary { l = 3 } mockCtrl := gomock.NewController(t) str := NewMockDatagramStream(mockCtrl) str.EXPECT().StreamID().Return(quic.StreamID(42)).AnyTimes() str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeMessageError)) str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() rb := newResponseBody( newStream(str, nil, nil, func(io.Reader, *headersFrame) error { return nil }, nil), l, make(chan struct{}), ) data, err := io.ReadAll(rb) require.Equal(t, []byte("foobar")[:l], data) require.ErrorIs(t, err, errTooMuchData) // check that repeated calls to Read also return the right error n, err := rb.Read([]byte{0}) require.Zero(t, n) require.ErrorIs(t, err, errTooMuchData) } quic-go-0.59.0/http3/capsule.go000066400000000000000000000031451513066070600162050ustar00rootroot00000000000000package http3 import ( "io" "github.com/quic-go/quic-go/quicvarint" ) // CapsuleType is the type of the capsule type CapsuleType uint64 // CapsuleProtocolHeader is the header value used to advertise support for the capsule protocol const CapsuleProtocolHeader = "Capsule-Protocol" type exactReader struct { R io.LimitedReader } func (r *exactReader) Read(b []byte) (int, error) { n, err := r.R.Read(b) if err == io.EOF && r.R.N > 0 { return n, io.ErrUnexpectedEOF } return n, err } // ParseCapsule parses the header of a Capsule. // It returns an io.Reader that can be used to read the Capsule value. // The Capsule value must be read entirely (i.e. until the io.EOF) before using r again. func ParseCapsule(r quicvarint.Reader) (CapsuleType, io.Reader, error) { cbr := countingByteReader{Reader: r} ct, err := quicvarint.Read(&cbr) if err != nil { // If an io.EOF is returned without consuming any bytes, return it unmodified. 
// Otherwise, return an io.ErrUnexpectedEOF. if err == io.EOF && cbr.NumRead > 0 { return 0, nil, io.ErrUnexpectedEOF } return 0, nil, err } l, err := quicvarint.Read(r) if err != nil { if err == io.EOF { return 0, nil, io.ErrUnexpectedEOF } return 0, nil, err } return CapsuleType(ct), &exactReader{R: io.LimitedReader{R: r, N: int64(l)}}, nil } // WriteCapsule writes a capsule func WriteCapsule(w quicvarint.Writer, ct CapsuleType, value []byte) error { b := make([]byte, 0, 16) b = quicvarint.Append(b, uint64(ct)) b = quicvarint.Append(b, uint64(len(value))) if _, err := w.Write(b); err != nil { return err } _, err := w.Write(value) return err } quic-go-0.59.0/http3/capsule_test.go000066400000000000000000000051371513066070600172470ustar00rootroot00000000000000package http3 import ( "bytes" "io" "testing" "github.com/quic-go/quic-go/quicvarint" "github.com/stretchr/testify/require" ) func TestCapsuleParsing(t *testing.T) { b := quicvarint.Append(nil, 1337) b = quicvarint.Append(b, 6) b = append(b, []byte("foobar")...) ct, r, err := ParseCapsule(bytes.NewReader(b)) require.NoError(t, err) require.Equal(t, CapsuleType(1337), ct) buf := make([]byte, 3) n, err := r.Read(buf) require.NoError(t, err) require.Equal(t, 3, n) require.Equal(t, []byte("foo"), buf) data, err := io.ReadAll(r) // reads until EOF require.NoError(t, err) require.Equal(t, []byte("bar"), data) } func TestEmptyCapsuleParsing(t *testing.T) { b := quicvarint.Append(nil, 1337) b = quicvarint.Append(b, 0) // Capsule content is empty. ct, r, err := ParseCapsule(bytes.NewReader(b)) require.NoError(t, err) require.Equal(t, CapsuleType(1337), ct) data, err := io.ReadAll(r) // reads until EOF require.NoError(t, err) require.Equal(t, []byte{}, data) } // test EOF vs ErrUnexpectedEOF func TestCapsuleTruncation(t *testing.T) { t.Run("with content", func(t *testing.T) { b := quicvarint.Append(nil, 1337) b = quicvarint.Append(b, 6) b = append(b, []byte("foobar")...) 
testCapsuleTruncation(t, b) }) t.Run("empty content", func(t *testing.T) { b := quicvarint.Append(nil, 1337) b = quicvarint.Append(b, 0) testCapsuleTruncation(t, b) }) } func testCapsuleTruncation(t *testing.T, b []byte) { for i := range b { ct, r, err := ParseCapsule(bytes.NewReader(b[:i])) if err != nil { if i == 0 { require.ErrorIs(t, err, io.EOF) } else { require.ErrorIs(t, err, io.ErrUnexpectedEOF) } continue } require.Equal(t, CapsuleType(1337), ct) _, err = io.ReadAll(r) require.ErrorIs(t, err, io.ErrUnexpectedEOF) } } func TestCapsuleWriting(t *testing.T) { var buf bytes.Buffer require.NoError(t, WriteCapsule(&buf, 1337, []byte("foobar"))) ct, r, err := ParseCapsule(&buf) require.NoError(t, err) require.Equal(t, CapsuleType(1337), ct) val, err := io.ReadAll(r) require.NoError(t, err) require.Equal(t, "foobar", string(val)) } func TestCapsuleWriteEmpty(t *testing.T) { var buf bytes.Buffer require.NoError(t, WriteCapsule(&buf, 1337, []byte{})) require.NoError(t, WriteCapsule(&buf, 1337, []byte{})) ct, r, err := ParseCapsule(&buf) require.NoError(t, err) require.Equal(t, CapsuleType(1337), ct) val, err := io.ReadAll(r) require.NoError(t, err) require.Empty(t, val) ct, r, err = ParseCapsule(&buf) require.NoError(t, err) require.Equal(t, CapsuleType(1337), ct) val, err = io.ReadAll(r) require.NoError(t, err) require.Empty(t, val) } quic-go-0.59.0/http3/client.go000066400000000000000000000360551513066070600160350ustar00rootroot00000000000000package http3 import ( "context" "errors" "fmt" "io" "log/slog" "net/http" "net/http/httptrace" "net/textproto" "sync" "time" "github.com/quic-go/qpack" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3/qlog" "github.com/quic-go/quic-go/qlogwriter" ) const ( // MethodGet0RTT allows a GET request to be sent using 0-RTT. // Note that 0-RTT doesn't provide replay protection and should only be used for idempotent requests. MethodGet0RTT = "GET_0RTT" // MethodHead0RTT allows a HEAD request to be sent using 0-RTT. 
// Note that 0-RTT doesn't provide replay protection and should only be used for idempotent requests. MethodHead0RTT = "HEAD_0RTT" ) const ( defaultUserAgent = "quic-go HTTP/3" defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB ) var errGoAway = errors.New("connection in graceful shutdown") type errConnUnusable struct{ e error } func (e *errConnUnusable) Unwrap() error { return e.e } func (e *errConnUnusable) Error() string { return fmt.Sprintf("http3: conn unusable: %s", e.e.Error()) } const max1xxResponses = 5 // arbitrary bound on number of informational responses var defaultQuicConfig = &quic.Config{ MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams KeepAlivePeriod: 10 * time.Second, } // ClientConn is an HTTP/3 client doing requests to a single remote server. type ClientConn struct { conn *quic.Conn rawConn *rawConn decoder *qpack.Decoder // Additional HTTP/3 settings. // It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams). additionalSettings map[uint64]uint64 // maxResponseHeaderBytes specifies a limit on how many response bytes are // allowed in the server's response header. maxResponseHeaderBytes int // disableCompression, if true, prevents the Transport from requesting compression with an // "Accept-Encoding: gzip" request header when the Request contains no existing Accept-Encoding value. // If the Transport requests gzip on its own and gets a gzipped response, it's transparently // decoded in the Response.Body. // However, if the user explicitly requested gzip it is not automatically uncompressed. 
disableCompression bool streamMx sync.Mutex maxStreamID quic.StreamID // set once a GOAWAY frame is received lastStreamID quic.StreamID // the highest stream ID that was opened qlogger qlogwriter.Recorder logger *slog.Logger requestWriter *requestWriter } var _ http.RoundTripper = &ClientConn{} func newClientConn( conn *quic.Conn, enableDatagrams bool, additionalSettings map[uint64]uint64, maxResponseHeaderBytes int, disableCompression bool, logger *slog.Logger, ) *ClientConn { var qlogger qlogwriter.Recorder if qlogTrace := conn.QlogTrace(); qlogTrace != nil && qlogTrace.SupportsSchemas(qlog.EventSchema) { qlogger = qlogTrace.AddProducer() } c := &ClientConn{ conn: conn, additionalSettings: additionalSettings, disableCompression: disableCompression, maxStreamID: invalidStreamID, lastStreamID: invalidStreamID, logger: logger, qlogger: qlogger, decoder: qpack.NewDecoder(), } if maxResponseHeaderBytes <= 0 { c.maxResponseHeaderBytes = defaultMaxResponseHeaderBytes } else { c.maxResponseHeaderBytes = maxResponseHeaderBytes } c.requestWriter = newRequestWriter() c.rawConn = newRawConn( conn, enableDatagrams, c.onStreamsEmpty, c.handleControlStream, qlogger, c.logger, ) // send the SETTINGs frame, using 0-RTT data, if possible go func() { _, err := c.rawConn.openControlStream(&settingsFrame{ Datagram: enableDatagrams, Other: additionalSettings, MaxFieldSectionSize: int64(c.maxResponseHeaderBytes), }) if err != nil { if c.logger != nil { c.logger.Debug("setting up connection failed", "error", err) } c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "") return } }() return c } // OpenRequestStream opens a new request stream on the HTTP/3 connection. 
func (c *ClientConn) OpenRequestStream(ctx context.Context) (*RequestStream, error) {
	// Uses the connection-wide request writer and settings. No reqDone channel (nil)
	// is handed to the RequestStream for streams opened through this public API.
	return c.openRequestStream(ctx, c.requestWriter, nil, c.disableCompression, c.maxResponseHeaderBytes)
}

// openRequestStream opens a new bidirectional QUIC stream and wraps it in a RequestStream.
//
// Once a GOAWAY frame was received (handleControlStream stores its stream ID in
// c.maxStreamID), only streams with a lower stream ID may be opened. The limit is
// checked twice:
//   - before blocking in OpenStreamSync, against the ID the next stream is expected
//     to get (client-initiated bidirectional stream IDs are 4 apart), and
//   - after OpenStreamSync returned, against the ID the stream actually got, since a
//     GOAWAY frame may have arrived (or another stream may have been opened) meanwhile.
//
// In both cases errGoAway is returned. In the second case the freshly opened stream
// is canceled in both directions with ErrCodeRequestCanceled.
//
// requestWriter, reqDone, disableCompression and maxHeaderBytes are passed through to
// newRequestStream. maxHeaderBytes is also passed to decodeTrailers (presumably as a
// size limit for trailers — confirm against decodeTrailers).
func (c *ClientConn) openRequestStream(
	ctx context.Context,
	requestWriter *requestWriter,
	reqDone chan<- struct{},
	disableCompression bool,
	maxHeaderBytes int,
) (*RequestStream, error) {
	c.streamMx.Lock()
	maxStreamID := c.maxStreamID
	var nextStreamID quic.StreamID
	if c.lastStreamID == invalidStreamID {
		nextStreamID = 0
	} else {
		nextStreamID = c.lastStreamID + 4
	}
	c.streamMx.Unlock()
	// Streams with stream ID equal to or greater than the stream ID carried in the GOAWAY frame
	// will be rejected, see section 5.2 of RFC 9114.
	if maxStreamID != invalidStreamID && nextStreamID >= maxStreamID {
		return nil, errGoAway
	}

	str, err := c.conn.OpenStreamSync(ctx)
	if err != nil {
		return nil, err
	}
	c.streamMx.Lock()
	// take the maximum here, as multiple OpenStreamSync calls might have returned concurrently
	if c.lastStreamID == invalidStreamID {
		c.lastStreamID = str.StreamID()
	} else {
		c.lastStreamID = max(c.lastStreamID, str.StreamID())
	}
	// check again, in case a (or another) GOAWAY frame was received
	maxStreamID = c.maxStreamID
	c.streamMx.Unlock()

	if maxStreamID != invalidStreamID && str.StreamID() >= maxStreamID {
		str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
		str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
		return nil, errGoAway
	}

	hstr := c.rawConn.TrackStream(str)
	// rsp is shared with the trailer callback below: decoded trailers are stored in rsp.Trailer.
	rsp := &http.Response{}
	trace := httptrace.ContextClientTrace(ctx)
	return newRequestStream(
		newStream(hstr, c.rawConn, trace, func(r io.Reader, hf *headersFrame) error {
			hdr, err := decodeTrailers(r, hf, maxHeaderBytes, c.decoder, c.qlogger, str.StreamID())
			if err != nil {
				return err
			}
			rsp.Trailer = hdr
			return nil
		}, c.qlogger),
		requestWriter,
		reqDone,
		c.decoder,
		disableCompression,
		maxHeaderBytes,
		rsp,
	), nil
}

// handleUnidirectionalStream hands an incoming unidirectional stream over to the rawConn.
// NOTE(review): the false argument presumably marks this as the client side of the
// connection — confirm against rawConn.handleUnidirectionalStream.
func (c *ClientConn) handleUnidirectionalStream(str *quic.ReceiveStream) {
	c.rawConn.handleUnidirectionalStream(str, false)
}

// handleControlStream reads frames from the server's control stream using fp.
// It is registered as a callback with rawConn in newClientConn. The str parameter
// is not used here.
//
// It closes the connection with:
//   - ErrCodeClosedCriticalStream if the control stream ends (io.EOF) or is reset,
//   - ErrCodeFrameError for any other parse error,
//   - ErrCodeFrameUnexpected for any frame other than GOAWAY,
//   - ErrCodeIDError if a GOAWAY frame carries a stream ID that is not a client-initiated
//     bidirectional stream ID, or one higher than a previously received GOAWAY frame.
//
// A valid GOAWAY frame lowers c.maxStreamID, which openRequestStream honors. If no
// streams are active at that point, the connection is closed with ErrCodeNoError.
// Otherwise the loop keeps reading, since the server may send further GOAWAY frames.
// The function only returns after the connection was closed.
func (c *ClientConn) handleControlStream(str *quic.ReceiveStream, fp *frameParser) {
	for {
		f, err := fp.ParseNext(c.qlogger)
		if err != nil {
			var serr *quic.StreamError
			if err == io.EOF || errors.As(err, &serr) {
				c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeClosedCriticalStream), "")
				return
			}
			c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
			return
		}
		// GOAWAY is the only frame allowed at this point:
		// * unexpected frames are ignored by the frame parser
		// * we don't support any extension that might add support for more frames
		goaway, ok := f.(*goAwayFrame)
		if !ok {
			c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "")
			return
		}
		if goaway.StreamID%4 != 0 { // client-initiated, bidirectional streams
			c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "")
			return
		}
		c.streamMx.Lock()
		// the server is not allowed to increase the Stream ID in subsequent GOAWAY frames
		if c.maxStreamID != invalidStreamID && goaway.StreamID > c.maxStreamID {
			c.streamMx.Unlock()
			c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "")
			return
		}
		c.maxStreamID = goaway.StreamID
		c.streamMx.Unlock()

		hasActiveStreams := c.rawConn.hasActiveStreams()
		// immediately close the connection if there are currently no active requests
		if !hasActiveStreams {
			c.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
			return
		}
	}
}

// onStreamsEmpty is registered as a callback with rawConn in newClientConn.
// NOTE(review): presumably invoked when the last tracked stream has finished — confirm.
// If a GOAWAY frame was received (c.maxStreamID is set), it completes the graceful
// shutdown by closing the connection with ErrCodeNoError.
func (c *ClientConn) onStreamsEmpty() {
	c.streamMx.Lock()
	defer c.streamMx.Unlock()

	// The server is performing a graceful shutdown.
	if c.maxStreamID != invalidStreamID {
		c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
	}
}

// RoundTrip executes a single HTTP request and returns its response.
// It implements http.RoundTripper. If the request fails and the request's context
// is done, the context's error is returned instead of the underlying error.
func (c *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
	rsp, err := c.roundTrip(req)
	if err != nil && req.Context().Err() != nil {
		// if the context was canceled, return the context cancellation error
		err = req.Context().Err()
	}
	return rsp, err
}

// roundTrip does the work for RoundTrip.
//
// Requests using MethodGet0RTT / MethodHead0RTT are sent as GET / HEAD, using a
// shallow copy of req, without waiting for the handshake to complete. All other
// requests first wait for the handshake, or for the request's context to be done.
// Extended CONNECT requests additionally wait for the server's SETTINGS frame and
// fail if the server didn't enable Extended CONNECT. Errors from opening the request
// stream are wrapped in errConnUnusable.
func (c *ClientConn) roundTrip(req *http.Request) (*http.Response, error) {
	// Immediately send out this request, if this is a 0-RTT request.
	switch req.Method {
	case MethodGet0RTT:
		// don't modify the original request
		reqCopy := *req
		req = &reqCopy
		req.Method = http.MethodGet
	case MethodHead0RTT:
		// don't modify the original request
		reqCopy := *req
		req = &reqCopy
		req.Method = http.MethodHead
	default:
		// wait for the handshake to complete
		select {
		case <-c.conn.HandshakeComplete():
		case <-req.Context().Done():
			return nil, req.Context().Err()
		}
	}

	// It is only possible to send an Extended CONNECT request once the SETTINGS were received.
	// See section 3 of RFC 8441.
	if isExtendedConnectRequest(req) {
		connCtx := c.conn.Context()
		// wait for the server's SETTINGS frame to arrive
		select {
		case <-c.rawConn.ReceivedSettings():
		case <-connCtx.Done():
			return nil, context.Cause(connCtx)
		}
		if !c.rawConn.Settings().EnableExtendedConnect {
			return nil, errors.New("http3: server didn't enable Extended CONNECT")
		}
	}

	reqDone := make(chan struct{})
	str, err := c.openRequestStream(
		req.Context(),
		c.requestWriter,
		reqDone,
		c.disableCompression,
		c.maxResponseHeaderBytes,
	)
	if err != nil {
		return nil, &errConnUnusable{e: err}
	}

	// Request Cancellation:
	// This goroutine keeps running even after RoundTrip returns.
	// It is shut down when the application is done processing the body (reqDone is
	// closed), or when the request's context is done, in which case the stream is
	// canceled in both directions.
	done := make(chan struct{})
	go func() {
		defer close(done)
		select {
		case <-req.Context().Done():
			str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
			str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
		case <-reqDone:
		}
	}()

	rsp, err := c.doRequest(req, str)
	if err != nil { // if any error occurred
		// No response is returned to the caller on error, so stop the cancellation
		// goroutine here and wait for it to exit before returning.
		close(reqDone)
		<-done
		return nil, maybeReplaceError(err)
	}
	return rsp, maybeReplaceError(err)
}

// ReceivedSettings returns a channel that is closed once the server's HTTP/3 settings were received.
// Settings can be obtained from the Settings method after the channel was closed.
func (c *ClientConn) ReceivedSettings() <-chan struct{} {
	return c.rawConn.ReceivedSettings()
}

// Settings returns the HTTP/3 settings for this connection.
// It is only valid to call this function after the channel returned by ReceivedSettings was closed.
func (c *ClientConn) Settings() *Settings {
	return c.rawConn.Settings()
}

// CloseWithError closes the connection with the given error code and message.
// It is invalid to call this function after the connection was closed.
func (c *ClientConn) CloseWithError(code quic.ApplicationErrorCode, msg string) error {
	return c.conn.CloseWithError(code, msg)
}

// Context returns a context that is cancelled when the connection is closed.
func (c *ClientConn) Context() context.Context {
	return c.conn.Context()
}

// cancelingReader reads from the io.Reader.
// It cancels writing on the stream if any error other than io.EOF occurs.
type cancelingReader struct { r io.Reader str *RequestStream } func (r *cancelingReader) Read(b []byte) (int, error) { n, err := r.r.Read(b) if err != nil && err != io.EOF { r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) } return n, err } func (c *ClientConn) sendRequestBody(str *RequestStream, body io.ReadCloser, contentLength int64) error { defer body.Close() buf := make([]byte, bodyCopyBufferSize) sr := &cancelingReader{str: str, r: body} if contentLength == -1 { _, err := io.CopyBuffer(str, sr, buf) return err } // make sure we don't send more bytes than the content length n, err := io.CopyBuffer(str, io.LimitReader(sr, contentLength), buf) if err != nil { return err } var extra int64 extra, err = io.CopyBuffer(io.Discard, sr, buf) n += extra if n > contentLength { str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) return fmt.Errorf("http: ContentLength=%d with Body length %d", contentLength, n) } return err } func (c *ClientConn) doRequest(req *http.Request, str *RequestStream) (*http.Response, error) { trace := httptrace.ContextClientTrace(req.Context()) var sendingReqFailed bool if err := str.sendRequestHeader(req); err != nil { traceWroteRequest(trace, err) if c.logger != nil { c.logger.Debug("error writing request", "error", err) } sendingReqFailed = true } if !sendingReqFailed { if req.Body == nil { traceWroteRequest(trace, nil) str.Close() } else { // send the request body asynchronously go func() { defer str.Close() contentLength := int64(-1) // According to the documentation for http.Request.ContentLength, // a value of 0 with a non-nil Body is also treated as unknown content length. 
if req.ContentLength > 0 { contentLength = req.ContentLength } err := c.sendRequestBody(str, req.Body, contentLength) traceWroteRequest(trace, err) if err != nil { if c.logger != nil { c.logger.Debug("error writing request", "error", err) } return } if len(req.Trailer) > 0 { if err := str.sendRequestTrailer(req); err != nil { if c.logger != nil { c.logger.Debug("error writing trailers", "error", err) } } } }() } } // copy from net/http: support 1xx responses var num1xx int // number of informational 1xx headers received var res *http.Response for { var err error res, err = str.ReadResponse() if err != nil { return nil, err } resCode := res.StatusCode is1xx := 100 <= resCode && resCode <= 199 // treat 101 as a terminal status, see https://github.com/golang/go/issues/26161 is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols if is1xxNonTerminal { num1xx++ if num1xx > max1xxResponses { str.CancelRead(quic.StreamErrorCode(ErrCodeExcessiveLoad)) str.CancelWrite(quic.StreamErrorCode(ErrCodeExcessiveLoad)) return nil, errors.New("http3: too many 1xx informational responses") } traceGot1xxResponse(trace, resCode, textproto.MIMEHeader(res.Header)) if resCode == http.StatusContinue { traceGot100Continue(trace) } continue } break } connState := c.conn.ConnectionState().TLS res.TLS = &connState res.Request = req return res, nil } // RawClientConn is a low-level HTTP/3 client connection. // It allows the application to take control of the stream accept loops, // giving the application the ability to handle streams originating from the server. type RawClientConn struct { *ClientConn } // HandleUnidirectionalStream handles an incoming unidirectional stream. func (c *RawClientConn) HandleUnidirectionalStream(str *quic.ReceiveStream) { c.rawConn.handleUnidirectionalStream(str, false) } // HandleBidirectionalStream handles an incoming bidirectional stream. 
func (c *ClientConn) HandleBidirectionalStream(str *quic.Stream) { // According to RFC 9114, the server is not allowed to open bidirectional streams. c.rawConn.CloseWithError( quic.ApplicationErrorCode(ErrCodeStreamCreationError), fmt.Sprintf("server opened bidirectional stream %d", str.StreamID()), ) } quic-go-0.59.0/http3/client_test.go000066400000000000000000000606511513066070600170730ustar00rootroot00000000000000package http3 import ( "bytes" "compress/gzip" "context" "io" mrand "math/rand/v2" "net/http" "net/http/httptest" "testing" "time" "github.com/quic-go/qpack" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/quic-go/quicvarint" "github.com/quic-go/quic-go/testutils/events" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) func TestClientSettings(t *testing.T) { t.Run("enable datagrams", func(t *testing.T) { testClientSettings(t, true, nil) }) t.Run("additional settings", func(t *testing.T) { testClientSettings(t, false, map[uint64]uint64{13: 37}) }) } func testClientSettings(t *testing.T, enableDatagrams bool, other map[uint64]uint64) { tr := &Transport{ EnableDatagrams: enableDatagrams, AdditionalSettings: other, } var eventRecorder events.Recorder clientConn, serverConn := newConnPair(t, withClientRecorder(&eventRecorder)) tr.NewClientConn(clientConn) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() str, err := serverConn.AcceptUniStream(ctx) require.NoError(t, err) str.SetReadDeadline(time.Now().Add(time.Second)) typ, err := quicvarint.Read(quicvarint.NewReader(str)) require.NoError(t, err) require.EqualValues(t, streamTypeControlStream, typ) fp := (&frameParser{r: str}) f, err := fp.ParseNext(nil) require.NoError(t, err) require.IsType(t, &settingsFrame{}, f) settingsFrame := f.(*settingsFrame) require.Equal(t, settingsFrame.Datagram, enableDatagrams) require.Equal(t, settingsFrame.Other, other) var datagramValue *bool 
if enableDatagrams { datagramValue = pointer(true) } require.Equal(t, []qlogwriter.Event{ qlog.FrameCreated{ StreamID: str.StreamID(), Raw: qlog.RawInfo{Length: 10}, Frame: qlog.Frame{ Frame: qlog.SettingsFrame{ MaxFieldSectionSize: defaultMaxResponseHeaderBytes, Datagram: datagramValue, Other: other, }, }, }, }, filterQlogEventsForFrame(eventRecorder.Events(qlog.FrameCreated{}), qlog.SettingsFrame{}), ) } func encodeResponse(t *testing.T, status int) []byte { t.Helper() mockCtrl := gomock.NewController(t) buf := &bytes.Buffer{} rstr := NewMockDatagramStream(mockCtrl) rstr.EXPECT().StreamID().Return(quic.StreamID(42)).AnyTimes() rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() rw := newResponseWriter(newStream(rstr, nil, nil, func(io.Reader, *headersFrame) error { return nil }, nil), nil, false, nil) rw.WriteHeader(status) rw.Flush() return buf.Bytes() } func TestClientRequest(t *testing.T) { t.Run("GET", func(t *testing.T) { rsp := testClientRequest(t, false, http.MethodGet, encodeResponse(t, http.StatusTeapot)) require.Equal(t, http.StatusTeapot, rsp.StatusCode) require.Equal(t, "HTTP/3.0", rsp.Proto) require.Equal(t, 3, rsp.ProtoMajor) require.NotNil(t, rsp.Request) }) t.Run("GET 0-RTT", func(t *testing.T) { rsp := testClientRequest(t, true, http.MethodGet, encodeResponse(t, http.StatusOK)) require.Equal(t, http.StatusOK, rsp.StatusCode) }) t.Run("HEAD", func(t *testing.T) { rsp := testClientRequest(t, false, http.MethodHead, encodeResponse(t, http.StatusTeapot)) require.Equal(t, http.StatusTeapot, rsp.StatusCode) }) t.Run("HEAD 0-RTT", func(t *testing.T) { rsp := testClientRequest(t, true, http.MethodHead, encodeResponse(t, http.StatusOK)) require.Equal(t, http.StatusOK, rsp.StatusCode) }) } func testClientRequest(t *testing.T, use0RTT bool, method string, rspBytes []byte) *http.Response { clientConn, serverConn := newConnPair(t) reqMethod := method if use0RTT { switch method { case http.MethodGet: reqMethod = MethodGet0RTT case http.MethodHead: 
reqMethod = MethodHead0RTT } } req, err := http.NewRequest(reqMethod, "http://quic-go.net", nil) require.NoError(t, err) type result struct { rsp *http.Response err error } resultChan := make(chan result, 1) go func() { cc := (&Transport{}).NewClientConn(clientConn) rsp, err := cc.RoundTrip(req) resultChan <- result{rsp: rsp, err: err} }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() str, err := serverConn.AcceptStream(ctx) require.NoError(t, err) str.SetReadDeadline(time.Now().Add(time.Second)) hfs := decodeHeader(t, str) require.Equal(t, []string{method}, hfs[":method"]) _, err = str.Write(rspBytes) require.NoError(t, err) var res result select { case res = <-resultChan: require.NoError(t, res.err) case <-time.After(time.Second): t.Fatal("timeout") } // make sure the http.Request.Method value was not modified if use0RTT { switch reqMethod { case MethodGet0RTT: require.Equal(t, req.Method, MethodGet0RTT) case MethodHead0RTT: require.Equal(t, req.Method, MethodHead0RTT) } } return res.rsp } func randomString(length int) string { const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" b := make([]byte, length) for i := range b { n := mrand.IntN(len(alphabet)) b[i] = alphabet[n] } return string(b) } func TestClientRequestError(t *testing.T) { clientConn, serverConn := newConnPair(t) req, err := http.NewRequest(http.MethodGet, "http://quic-go.net", nil) require.NoError(t, err) for range 1000 { req.Header.Add(randomString(50), randomString(50)) } type result struct { rsp *http.Response err error } resultChan := make(chan result, 1) go func() { cc := (&Transport{}).NewClientConn(clientConn) rsp, err := cc.RoundTrip(req) resultChan <- result{rsp: rsp, err: err} }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() str, err := serverConn.AcceptStream(ctx) require.NoError(t, err) str.CancelRead(quic.StreamErrorCode(ErrCodeExcessiveLoad)) _, err = str.Write(encodeResponse(t, 
http.StatusTeapot)) require.NoError(t, err) var res result select { case res = <-resultChan: require.NoError(t, res.err) require.Equal(t, http.StatusTeapot, res.rsp.StatusCode) case <-time.After(time.Second): t.Fatal("timeout") } } func TestClientResponseValidation(t *testing.T) { t.Run("HEADERS frame too large", func(t *testing.T) { require.ErrorContains(t, testClientResponseValidation(t, &Transport{MaxResponseHeaderBytes: 1337}, (&headersFrame{Length: 1338}).Append(nil), quic.StreamErrorCode(ErrCodeFrameError), ), "http3: HEADERS frame too large", ) }) t.Run("invalid headers", func(t *testing.T) { headerBuf := &bytes.Buffer{} enc := qpack.NewEncoder(headerBuf) // not a valid response pseudo header require.NoError(t, enc.WriteField(qpack.HeaderField{Name: ":method", Value: "GET"})) require.NoError(t, enc.Close()) b := (&headersFrame{Length: uint64(headerBuf.Len())}).Append(nil) b = append(b, headerBuf.Bytes()...) require.ErrorContains(t, testClientResponseValidation(t, &Transport{}, b, quic.StreamErrorCode(ErrCodeMessageError)), "invalid response pseudo header", ) }) } func testClientResponseValidation(t *testing.T, tr *Transport, rsp []byte, expectedReset quic.StreamErrorCode) error { clientConn, serverConn := newConnPair(t) cc := tr.NewClientConn(clientConn) errChan := make(chan error) go func() { _, err := cc.RoundTrip(httptest.NewRequest(http.MethodGet, "http://quic-go.net", nil)) errChan <- err }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() str, err := serverConn.AcceptStream(ctx) require.NoError(t, err) _, err = str.Write(rsp) require.NoError(t, err) select { case err := <-errChan: expectStreamWriteReset(t, str, expectedReset) // The client closes the stream after sending the request, // so we need to wait for the RESET_STREAM frame to be received. 
time.Sleep(scaleDuration(10 * time.Millisecond)) expectStreamReadReset(t, str, expectedReset) return err case <-time.After(time.Second): t.Fatal("timeout") } panic("unreachable") } func TestClientRequestLengthLimit(t *testing.T) { clientConn, serverConn := newConnPair(t) cc := (&Transport{}).NewClientConn(clientConn) errChan := make(chan error) body := bytes.NewBufferString("request body") go func() { req := httptest.NewRequest(http.MethodPost, "http://quic-go.net", body) req.ContentLength = 8 _, err := cc.RoundTrip(req) errChan <- err }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() str, err := serverConn.AcceptStream(ctx) require.NoError(t, err) _, err = io.ReadAll(str) var strErr *quic.StreamError require.ErrorAs(t, err, &strErr) require.Equal(t, quic.StreamErrorCode(ErrCodeRequestCanceled), strErr.ErrorCode) _, err = str.Write(encodeResponse(t, http.StatusTeapot)) require.NoError(t, err) select { case err := <-errChan: require.NoError(t, err) case <-time.After(time.Second): t.Fatal("timeout") } } func TestClientExtendedConnect(t *testing.T) { t.Run("enabled", func(t *testing.T) { testClientExtendedConnect(t, true) }) t.Run("disabled", func(t *testing.T) { testClientExtendedConnect(t, false) }) } func testClientExtendedConnect(t *testing.T, enabled bool) { clientConn, serverConn := newConnPair(t) cc := (&Transport{}).NewClientConn(clientConn) req, err := http.NewRequest(http.MethodConnect, "http://quic-go.net", nil) require.NoError(t, err) req.Proto = "connect" errChan := make(chan error) go func() { _, err := cc.RoundTrip(req) errChan <- err }() select { case <-errChan: t.Fatal("RoundTrip should have blocked until SETTINGS were received") case <-time.After(scaleDuration(10 * time.Millisecond)): } // now send the SETTINGS settingsStr, err := serverConn.OpenUniStream() require.NoError(t, err) settingsStr.SetWriteDeadline(time.Now().Add(time.Second)) settingsFrame := &settingsFrame{ExtendedConnect: enabled} _, err = 
settingsStr.Write(settingsFrame.Append(quicvarint.Append(nil, streamTypeControlStream))) require.NoError(t, err) select { case <-cc.ReceivedSettings(): case <-time.After(time.Second): t.Fatal("timeout waiting for settings") } settings := cc.Settings() require.Equal(t, enabled, settings.EnableExtendedConnect) if enabled { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() str, err := serverConn.AcceptStream(ctx) require.NoError(t, err) str.CancelRead(1337) str.CancelWrite(1337) } select { case err := <-errChan: if enabled { require.ErrorIs(t, err, &Error{Remote: true, ErrorCode: 1337}) } else { require.EqualError(t, err, "http3: server didn't enable Extended CONNECT") } case <-time.After(time.Second): t.Fatal("timeout") } } func TestClient1xxHandling(t *testing.T) { t.Run("a few early hints", func(t *testing.T) { testClient1xxHandling(t, max1xxResponses, http.StatusOK, false) }) t.Run("too many early hints", func(t *testing.T) { testClient1xxHandling(t, max1xxResponses+1, http.StatusOK, true) }) t.Run("EarlyHints followed by StatusSwitchingProtocols", func(t *testing.T) { testClient1xxHandling(t, 1, http.StatusSwitchingProtocols, false) }) } func testClient1xxHandling(t *testing.T, numEarlyHints int, terminalStatus int, tooMany bool) { var rspBuf bytes.Buffer rstr := NewMockDatagramStream(gomock.NewController(t)) rstr.EXPECT().StreamID().Return(quic.StreamID(42)).AnyTimes() rstr.EXPECT().Write(gomock.Any()).Do(rspBuf.Write).AnyTimes() rw := newResponseWriter(newStream(rstr, nil, nil, func(io.Reader, *headersFrame) error { return nil }, nil), nil, false, nil) rw.header.Add("Link", "foo") rw.header.Add("Link", "bar") for range numEarlyHints { rw.WriteHeader(http.StatusEarlyHints) } rw.WriteHeader(terminalStatus) rw.Flush() rspBytes := rspBuf.Bytes() clientConn, serverConn := newConnPair(t) type result struct { rsp *http.Response err error } resultChan := make(chan result, 1) go func() { cc := (&Transport{}).NewClientConn(clientConn) 
rsp, err := cc.RoundTrip(httptest.NewRequest(http.MethodGet, "http://quic-go.net", nil)) resultChan <- result{rsp: rsp, err: err} }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() str, err := serverConn.AcceptStream(ctx) require.NoError(t, err) // request headers hfs := decodeHeader(t, str) require.Equal(t, hfs[":method"], []string{http.MethodGet}) _, err = str.Write(rspBytes) require.NoError(t, err) var rsp *http.Response select { case res := <-resultChan: if tooMany { require.EqualError(t, res.err, "http3: too many 1xx informational responses") return } require.NoError(t, res.err) rsp = res.rsp case <-time.After(time.Second): t.Fatal("timeout") } require.Equal(t, []string{"foo", "bar"}, rsp.Header["Link"]) require.Equal(t, terminalStatus, rsp.StatusCode) } func TestClientGzip(t *testing.T) { var buf bytes.Buffer w := gzip.NewWriter(&buf) w.Write([]byte("foobar")) w.Close() gzippedFoobar := buf.Bytes() t.Run("gzipped", func(t *testing.T) { testClientGzip(t, gzippedFoobar, []byte("foobar"), false, true) }) t.Run("not gzipped", func(t *testing.T) { testClientGzip(t, []byte("foobar"), []byte("foobar"), false, false) }) t.Run("disable compression", func(t *testing.T) { testClientGzip(t, gzippedFoobar, gzippedFoobar, true, true) }) } func testClientGzip(t *testing.T, data []byte, expectedRsp []byte, transportDisableCompression bool, responseAddContentEncoding bool, ) { var rspBuf bytes.Buffer rstr := NewMockDatagramStream(gomock.NewController(t)) rstr.EXPECT().StreamID().Return(quic.StreamID(42)).AnyTimes() rstr.EXPECT().Write(gomock.Any()).Do(rspBuf.Write).AnyTimes() rw := newResponseWriter(newStream(rstr, nil, nil, func(io.Reader, *headersFrame) error { return nil }, nil), nil, false, nil) rw.WriteHeader(http.StatusOK) if responseAddContentEncoding { rw.header.Add("Content-Encoding", "gzip") } rw.Write(data) rw.Flush() clientConn, serverConn := newConnPair(t) type result struct { rsp *http.Response err error } resultChan := 
make(chan result) go func() { cc := (&Transport{DisableCompression: transportDisableCompression}).NewClientConn(clientConn) rsp, err := cc.RoundTrip(httptest.NewRequest(http.MethodGet, "http://quic-go.net", nil)) resultChan <- result{rsp: rsp, err: err} }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() str, err := serverConn.AcceptStream(ctx) require.NoError(t, err) // request headers str.SetReadDeadline(time.Now().Add(time.Second)) hfs := decodeHeader(t, str) if transportDisableCompression { require.NotContains(t, hfs, "accept-encoding") } else { require.Equal(t, hfs["accept-encoding"], []string{"gzip"}) } _, err = str.Write(rspBuf.Bytes()) require.NoError(t, err) require.NoError(t, str.Close()) var rsp *http.Response select { case res := <-resultChan: require.NoError(t, res.err) rsp = res.rsp case <-time.After(time.Second): t.Fatal("timeout") } require.Equal(t, http.StatusOK, rsp.StatusCode) body, err := io.ReadAll(rsp.Body) require.NoError(t, err) require.Equal(t, expectedRsp, body) } func TestClientRequestCancellation(t *testing.T) { clientConn, serverConn := newConnPair(t) requestCtx, requestCancel := context.WithCancel(context.Background()) req, err := http.NewRequestWithContext(requestCtx, http.MethodGet, "http://quic-go.net", nil) require.NoError(t, err) type result struct { rsp *http.Response err error } resultChan := make(chan result) go func() { cc := (&Transport{}).NewClientConn(clientConn) rsp, err := cc.RoundTrip(req) resultChan <- result{rsp: rsp, err: err} }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() str, err := serverConn.AcceptStream(ctx) require.NoError(t, err) _, err = str.Write(encodeResponse(t, http.StatusTeapot)) require.NoError(t, err) select { case res := <-resultChan: require.NoError(t, res.err) require.Equal(t, http.StatusTeapot, res.rsp.StatusCode) case <-time.After(time.Second): t.Fatal("timeout") } requestCancel() expectStreamWriteReset(t, str, 
quic.StreamErrorCode(ErrCodeRequestCanceled)) } func TestClientConnGoAway(t *testing.T) { t.Run("no active streams", func(t *testing.T) { testClientConnGoAway(t, false) }) t.Run("active stream", func(t *testing.T) { testClientConnGoAway(t, true) }) } func testClientConnGoAway(t *testing.T, withStream bool) { var clientEventRecorder events.Recorder clientConn, serverConn := newConnPair(t, withClientRecorder(&clientEventRecorder)) cc := (&Transport{}).NewClientConn(clientConn) var str *RequestStream if withStream { s, err := cc.OpenRequestStream(context.Background()) require.NoError(t, err) str = s } // server sends control stream with SETTINGS and GOAWAY b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{}).Append(b) b = (&goAwayFrame{StreamID: 8}).Append(b) controlStr, err := serverConn.OpenUniStream() require.NoError(t, err) _, err = controlStr.Write(b) require.NoError(t, err) // the connection should be closed after the stream is closed if withStream { select { case <-serverConn.Context().Done(): t.Fatal("connection closed") case <-time.After(scaleDuration(10 * time.Millisecond)): } // the stream ID in the GOAWAY frame is 8, so it's possible to open stream 4 str2, err := cc.OpenRequestStream(context.Background()) require.NoError(t, err) str2.Close() str2.CancelRead(1337) // it's not possible to open stream 8 _, err = cc.OpenRequestStream(context.Background()) require.ErrorIs(t, err, errGoAway) str.Close() str.CancelRead(1337) } select { case <-serverConn.Context().Done(): require.ErrorIs(t, context.Cause(serverConn.Context()), &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(ErrCodeNoError)}, ) case <-time.After(time.Second): t.Fatal("timeout waiting for close") } expectedLen, expectedPayloadLen := expectedFrameLength(t, &goAwayFrame{StreamID: 8}) require.Equal(t, []qlogwriter.Event{ qlog.FrameParsed{ StreamID: controlStr.StreamID(), Raw: qlog.RawInfo{PayloadLength: expectedPayloadLen, Length: expectedLen}, Frame: 
qlog.Frame{Frame: qlog.GoAwayFrame{StreamID: 8}}, }, }, filterQlogEventsForFrame(clientEventRecorder.Events(qlog.FrameParsed{}), qlog.GoAwayFrame{StreamID: 8}), ) } func TestClientConnGoConcurrent(t *testing.T) { clientConn, serverConn := newConnPair(t, withServerBidiStreamLimit(1)) // allows streams 0 cc := (&Transport{}).NewClientConn(clientConn) // peer sends control stream with SETTINGS, but not GOAWAY yet b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{}).Append(b) controlStr, err := serverConn.OpenUniStream() require.NoError(t, err) _, err = controlStr.Write(b) require.NoError(t, err) select { case <-serverConn.Context().Done(): t.Fatal("connection closed") case <-time.After(scaleDuration(10 * time.Millisecond)): } // of these 2 OpenStreamSync calls, one will succeed, the other one will block errChan := make(chan error, 3) for range 2 { go func() { str, err := cc.OpenRequestStream(context.Background()) if err == nil { str.Close() } errChan <- err }() } // wait until all Goroutines have started time.Sleep(scaleDuration(10 * time.Millisecond)) select { case err := <-errChan: require.NoError(t, err) case <-time.After(time.Second): t.Fatal("timeout") } // the second stream is still blocked select { case <-errChan: t.Fatal("second OpenStreamSync should have blocked") case <-time.After(scaleDuration(10 * time.Millisecond)): } // send the GOAWAY frame b = (&goAwayFrame{StreamID: 4}).Append(nil) _, err = controlStr.Write(b) require.NoError(t, err) // accepting and closing the stream allows the client to open another stream ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() sstr, err := serverConn.AcceptStream(ctx) require.NoError(t, err) sstr.Close() sstr.CancelRead(1337) // The second stream is opened by the client, // and immediately closed with a H3_REQUEST_CANCELED error. 
select { case err := <-errChan: require.ErrorIs(t, err, errGoAway) case <-time.After(scaleDuration(10 * time.Millisecond)): t.Fatal("timeout") } sstr, err = serverConn.AcceptStream(ctx) require.NoError(t, err) _, err = sstr.Read([]byte{0}) require.ErrorIs(t, err, &quic.StreamError{StreamID: 4, ErrorCode: quic.StreamErrorCode(ErrCodeRequestCanceled), Remote: true}) } func TestClientConnGoAwayFailures(t *testing.T) { t.Run("invalid frame", func(t *testing.T) { b := (&settingsFrame{}).Append(nil) // 1337 is invalid value for the Extended CONNECT setting b = (&settingsFrame{Other: map[uint64]uint64{settingExtendedConnect: 1337}}).Append(b) testClientConnGoAwayFailures(t, b, nil, ErrCodeFrameError) }) t.Run("not a GOAWAY", func(t *testing.T) { b := (&settingsFrame{}).Append(nil) // GOAWAY is the only allowed frame type after SETTINGS b = (&headersFrame{}).Append(b) testClientConnGoAwayFailures(t, b, nil, ErrCodeFrameUnexpected) }) t.Run("stream closed before GOAWAY", func(t *testing.T) { testClientConnGoAwayFailures(t, (&settingsFrame{}).Append(nil), io.EOF, ErrCodeClosedCriticalStream) }) t.Run("stream reset before GOAWAY", func(t *testing.T) { testClientConnGoAwayFailures(t, (&settingsFrame{}).Append(nil), &quic.StreamError{Remote: true, ErrorCode: 42}, ErrCodeClosedCriticalStream, ) }) t.Run("invalid stream ID", func(t *testing.T) { data := (&settingsFrame{}).Append(nil) data = (&goAwayFrame{StreamID: 1}).Append(data) testClientConnGoAwayFailures(t, data, nil, ErrCodeIDError) }) t.Run("increased stream ID", func(t *testing.T) { localConn, peerConn := newConnPair(t) cc := (&Transport{}).NewClientConn(localConn) // need an active stream so the connection doesn't close after the first GOAWAY _, err := cc.OpenRequestStream(context.Background()) require.NoError(t, err) controlStr, err := peerConn.OpenUniStream() require.NoError(t, err) b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{}).Append(b) b = (&goAwayFrame{StreamID: 4}).Append(b) b = 
(&goAwayFrame{StreamID: 8}).Append(b) _, err = controlStr.Write(b) require.NoError(t, err) select { case <-peerConn.Context().Done(): require.ErrorIs(t, context.Cause(peerConn.Context()), &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(ErrCodeIDError)}, ) case <-time.After(time.Second): t.Fatal("timeout waiting for close") } }) } func testClientConnGoAwayFailures(t *testing.T, data []byte, readErr error, expectedErr ErrCode) { localConn, peerConn := newConnPair(t) (&Transport{}).NewClientConn(localConn) controlStr, err := peerConn.OpenUniStream() require.NoError(t, err) _, err = controlStr.Write(quicvarint.Append(nil, streamTypeControlStream)) require.NoError(t, err) switch readErr { case nil: _, err = controlStr.Write(data) require.NoError(t, err) case io.EOF: _, err = controlStr.Write(data) require.NoError(t, err) require.NoError(t, controlStr.Close()) default: // make sure the stream type is received time.Sleep(scaleDuration(10 * time.Millisecond)) controlStr.CancelWrite(1337) } select { case <-peerConn.Context().Done(): require.ErrorIs(t, context.Cause(peerConn.Context()), &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(expectedErr)}, ) case <-time.After(time.Second): t.Fatal("timeout waiting for close") } } func TestClientConnHandleBidirectionalStream(t *testing.T) { clientConn, serverConn := newConnPair(t) cc := (&Transport{}).NewClientConn(clientConn) str, err := clientConn.OpenStream() require.NoError(t, err) cc.HandleBidirectionalStream(str) select { case <-serverConn.Context().Done(): require.ErrorIs(t, context.Cause(serverConn.Context()), &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(ErrCodeStreamCreationError)}, ) case <-time.After(time.Second): t.Fatal("timeout waiting for connection close") } } func TestRawClientConnHandleUnidirectionalStream(t *testing.T) { clientConn, serverConn := newConnPair(t) cc := (&Transport{}).NewRawClientConn(clientConn) b := 
quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{}).Append(b) str, err := serverConn.OpenUniStream() require.NoError(t, err) _, err = str.Write(b) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() uniStr, err := clientConn.AcceptUniStream(ctx) require.NoError(t, err) done := make(chan struct{}) go func() { defer close(done) cc.HandleUnidirectionalStream(uniStr) }() select { case <-cc.ReceivedSettings(): case <-time.After(time.Second): t.Fatal("timeout waiting for settings") } require.NotNil(t, cc.Settings()) } quic-go-0.59.0/http3/conn.go000066400000000000000000000221571513066070600155120ustar00rootroot00000000000000package http3 import ( "context" "errors" "fmt" "io" "log/slog" "maps" "net" "sync" "sync/atomic" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/quic-go/quicvarint" ) const maxQuarterStreamID = 1<<60 - 1 // invalidStreamID is a stream ID that is invalid. The first valid stream ID in QUIC is 0. const invalidStreamID = quic.StreamID(-1) // rawConn is an HTTP/3 connection. // It provides HTTP/3 specific functionality by wrapping a quic.Conn, // in particular handling of unidirectional HTTP/3 streams, SETTINGS and datagrams. 
type rawConn struct { conn *quic.Conn logger *slog.Logger enableDatagrams bool streamMx sync.Mutex streams map[quic.StreamID]*stateTrackingStream rcvdControlStr atomic.Bool rcvdQPACKEncoderStr atomic.Bool rcvdQPACKDecoderStr atomic.Bool controlStrHandler func(*quic.ReceiveStream, *frameParser) // is called *after* the SETTINGS frame was parsed onStreamsEmpty func() settings *Settings receivedSettings chan struct{} qlogger qlogwriter.Recorder qloggerWG sync.WaitGroup // tracks goroutines that may produce qlog events } func newRawConn( quicConn *quic.Conn, enableDatagrams bool, onStreamsEmpty func(), controlStrHandler func(*quic.ReceiveStream, *frameParser), qlogger qlogwriter.Recorder, logger *slog.Logger, ) *rawConn { c := &rawConn{ conn: quicConn, logger: logger, enableDatagrams: enableDatagrams, receivedSettings: make(chan struct{}), streams: make(map[quic.StreamID]*stateTrackingStream), qlogger: qlogger, onStreamsEmpty: onStreamsEmpty, controlStrHandler: controlStrHandler, } if qlogger != nil { context.AfterFunc(quicConn.Context(), c.closeQlogger) } return c } func (c *rawConn) OpenUniStream() (*quic.SendStream, error) { return c.conn.OpenUniStream() } // openControlStream opens the control stream and sends the SETTINGS frame. // It returns the control stream (needed by the server for sending GOAWAY later). 
func (c *rawConn) openControlStream(settings *settingsFrame) (*quic.SendStream, error) { c.qloggerWG.Add(1) defer c.qloggerWG.Done() str, err := c.conn.OpenUniStream() if err != nil { return nil, err } b := make([]byte, 0, 64) b = quicvarint.Append(b, streamTypeControlStream) b = settings.Append(b) if c.qlogger != nil { sf := qlog.SettingsFrame{ MaxFieldSectionSize: settings.MaxFieldSectionSize, Other: maps.Clone(settings.Other), } if settings.Datagram { sf.Datagram = pointer(true) } if settings.ExtendedConnect { sf.ExtendedConnect = pointer(true) } c.qlogger.RecordEvent(qlog.FrameCreated{ StreamID: str.StreamID(), Raw: qlog.RawInfo{Length: len(b)}, Frame: qlog.Frame{Frame: sf}, }) } if _, err := str.Write(b); err != nil { return nil, err } return str, nil } func (c *rawConn) TrackStream(str *quic.Stream) *stateTrackingStream { hstr := newStateTrackingStream(str, c, func(b []byte) error { return c.sendDatagram(str.StreamID(), b) }) c.streamMx.Lock() c.streams[str.StreamID()] = hstr c.qloggerWG.Add(1) c.streamMx.Unlock() return hstr } func (c *rawConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } func (c *rawConn) ConnectionState() quic.ConnectionState { return c.conn.ConnectionState() } func (c *rawConn) clearStream(id quic.StreamID) { c.streamMx.Lock() defer c.streamMx.Unlock() if _, ok := c.streams[id]; ok { delete(c.streams, id) c.qloggerWG.Done() } if len(c.streams) == 0 { c.onStreamsEmpty() } } func (c *rawConn) hasActiveStreams() bool { c.streamMx.Lock() defer c.streamMx.Unlock() return len(c.streams) > 0 } func (c *rawConn) CloseWithError(code quic.ApplicationErrorCode, msg string) error { return c.conn.CloseWithError(code, msg) } func (c *rawConn) handleUnidirectionalStream(str *quic.ReceiveStream, isServer bool) { c.qloggerWG.Add(1) defer c.qloggerWG.Done() streamType, err := quicvarint.Read(quicvarint.NewReader(str)) if err != nil { if c.logger != nil { c.logger.Debug("reading stream type on stream failed", "stream ID", str.StreamID(), "error", 
err) } return } // We're only interested in the control stream here. switch streamType { case streamTypeControlStream: case streamTypeQPACKEncoderStream: if isFirst := c.rcvdQPACKEncoderStr.CompareAndSwap(false, true); !isFirst { c.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK encoder stream") } // Our QPACK implementation doesn't use the dynamic table yet. return case streamTypeQPACKDecoderStream: if isFirst := c.rcvdQPACKDecoderStr.CompareAndSwap(false, true); !isFirst { c.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK decoder stream") } // Our QPACK implementation doesn't use the dynamic table yet. return case streamTypePushStream: if isServer { // only the server can push c.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "") } else { // we never increased the Push ID, so we don't expect any push streams c.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "") } return default: str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) return } // Only a single control stream is allowed. 
if isFirstControlStr := c.rcvdControlStr.CompareAndSwap(false, true); !isFirstControlStr { c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream") return } c.handleControlStream(str) } func (c *rawConn) handleControlStream(str *quic.ReceiveStream) { fp := &frameParser{closeConn: c.conn.CloseWithError, r: str, streamID: str.StreamID()} f, err := fp.ParseNext(c.qlogger) if err != nil { var serr *quic.StreamError if err == io.EOF || errors.As(err, &serr) { c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeClosedCriticalStream), "") return } c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "") return } sf, ok := f.(*settingsFrame) if !ok { c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "") return } c.settings = &Settings{ EnableDatagrams: sf.Datagram, EnableExtendedConnect: sf.ExtendedConnect, Other: sf.Other, } close(c.receivedSettings) if sf.Datagram { // If datagram support was enabled on our side as well as on the server side, // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). if c.enableDatagrams && !c.ConnectionState().SupportsDatagrams.Remote { c.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support") return } c.qloggerWG.Add(1) go func() { defer c.qloggerWG.Done() if err := c.receiveDatagrams(); err != nil { if c.logger != nil { c.logger.Debug("receiving datagrams failed", "error", err) } } }() } if c.controlStrHandler != nil { c.controlStrHandler(str, fp) } } func (c *rawConn) sendDatagram(streamID quic.StreamID, b []byte) error { // TODO: this creates a lot of garbage and an additional copy data := make([]byte, 0, len(b)+8) quarterStreamID := uint64(streamID / 4) data = quicvarint.Append(data, uint64(streamID/4)) data = append(data, b...) 
if c.qlogger != nil { c.qlogger.RecordEvent(qlog.DatagramCreated{ QuaterStreamID: quarterStreamID, Raw: qlog.RawInfo{ Length: len(data), PayloadLength: len(b), }, }) } return c.conn.SendDatagram(data) } func (c *rawConn) receiveDatagrams() error { for { b, err := c.conn.ReceiveDatagram(context.Background()) if err != nil { return err } quarterStreamID, n, err := quicvarint.Parse(b) if err != nil { c.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "") return fmt.Errorf("could not read quarter stream id: %w", err) } if c.qlogger != nil { c.qlogger.RecordEvent(qlog.DatagramParsed{ QuaterStreamID: quarterStreamID, Raw: qlog.RawInfo{ Length: len(b), PayloadLength: len(b) - n, }, }) } if quarterStreamID > maxQuarterStreamID { c.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "") return fmt.Errorf("invalid quarter stream id: %w", err) } streamID := quic.StreamID(4 * quarterStreamID) c.streamMx.Lock() dg, ok := c.streams[streamID] c.streamMx.Unlock() if !ok { continue } dg.enqueueDatagram(b[n:]) } } // ReceivedSettings returns a channel that is closed once the peer's SETTINGS frame was received. // Settings can be optained from the Settings method after the channel was closed. func (c *rawConn) ReceivedSettings() <-chan struct{} { return c.receivedSettings } // Settings returns the settings received on this connection. // It is only valid to call this function after the channel returned by ReceivedSettings was closed. func (c *rawConn) Settings() *Settings { return c.settings } // closeQlogger waits for all goroutines that may produce qlog events to finish, // then closes the qlogger. 
func (c *rawConn) closeQlogger() { if c.qlogger == nil { return } c.qloggerWG.Wait() c.qlogger.Close() } quic-go-0.59.0/http3/conn_test.go000066400000000000000000000344651513066070600165560ustar00rootroot00000000000000package http3 import ( "bytes" "context" "io" "testing" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/quic-go/quicvarint" "github.com/quic-go/quic-go/testutils/events" "github.com/stretchr/testify/require" ) func TestConnReceiveSettings(t *testing.T) { var eventRecorder events.Recorder clientConn, serverConn := newConnPair(t, withServerRecorder(&eventRecorder)) conn := newRawConn(serverConn, false, nil, nil, &eventRecorder, nil) b := quicvarint.Append(nil, streamTypeControlStream) sf := &settingsFrame{ MaxFieldSectionSize: 1234, Datagram: true, ExtendedConnect: true, Other: map[uint64]uint64{1337: 42}, } b = sf.Append(b) controlStr, err := clientConn.OpenUniStream() require.NoError(t, err) _, err = controlStr.Write(b) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() serverStr, err := serverConn.AcceptUniStream(ctx) require.NoError(t, err) done := make(chan struct{}) go func() { defer close(done) conn.handleUnidirectionalStream(serverStr, true) }() select { case <-conn.ReceivedSettings(): case <-time.After(time.Second): t.Fatal("timeout waiting for settings") } settings := conn.Settings() require.True(t, settings.EnableDatagrams) require.True(t, settings.EnableExtendedConnect) require.Equal(t, map[uint64]uint64{1337: 42}, settings.Other) expectedLen, expectedPayloadLen := expectedFrameLength(t, sf) require.Equal(t, []qlogwriter.Event{ qlog.FrameParsed{ StreamID: controlStr.StreamID(), Raw: qlog.RawInfo{Length: expectedLen, PayloadLength: expectedPayloadLen}, Frame: qlog.Frame{ Frame: qlog.SettingsFrame{ MaxFieldSectionSize: 1234, Datagram: pointer(true), ExtendedConnect: pointer(true), Other: 
map[uint64]uint64{1337: 42}, }, }, }, }, filterQlogEventsForFrame(eventRecorder.Events(qlog.FrameParsed{}), qlog.SettingsFrame{}), ) } func TestConnRejectDuplicateStreams(t *testing.T) { t.Run("control stream", func(t *testing.T) { testConnRejectDuplicateStreams(t, streamTypeControlStream) }) t.Run("encoder stream", func(t *testing.T) { testConnRejectDuplicateStreams(t, streamTypeQPACKEncoderStream) }) t.Run("decoder stream", func(t *testing.T) { testConnRejectDuplicateStreams(t, streamTypeQPACKDecoderStream) }) } func testConnRejectDuplicateStreams(t *testing.T, typ uint64) { clientConn, serverConn := newConnPair(t) conn := newRawConn(serverConn, false, nil, nil, nil, nil) b := quicvarint.Append(nil, typ) if typ == streamTypeControlStream { b = (&settingsFrame{}).Append(b) } controlStr1, err := clientConn.OpenUniStream() require.NoError(t, err) _, err = controlStr1.Write(b) require.NoError(t, err) controlStr2, err := clientConn.OpenUniStream() require.NoError(t, err) _, err = controlStr2.Write(b) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() serverStr1, err := serverConn.AcceptUniStream(ctx) require.NoError(t, err) serverStr2, err := serverConn.AcceptUniStream(ctx) require.NoError(t, err) done := make(chan struct{}, 2) go func() { defer func() { done <- struct{}{} }() conn.handleUnidirectionalStream(serverStr1, true) }() go func() { defer func() { done <- struct{}{} }() conn.handleUnidirectionalStream(serverStr2, true) }() select { case <-clientConn.Context().Done(): require.ErrorIs(t, context.Cause(clientConn.Context()), &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(ErrCodeStreamCreationError)}, ) case <-time.After(time.Second): t.Fatal("timeout waiting for duplicate stream") } for range 2 { select { case <-done: case <-time.After(time.Second): t.Fatal("timeout") } } } func TestConnResetUnknownUniStream(t *testing.T) { clientConn, serverConn := newConnPair(t) conn := 
newRawConn(serverConn, false, nil, nil, nil, nil) buf := bytes.NewBuffer(quicvarint.Append(nil, 0x1337)) str, err := clientConn.OpenUniStream() require.NoError(t, err) _, err = str.Write(buf.Bytes()) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() serverStr, err := serverConn.AcceptUniStream(ctx) require.NoError(t, err) done := make(chan struct{}) go func() { defer close(done) conn.handleUnidirectionalStream(serverStr, true) }() expectStreamWriteReset(t, str, quic.StreamErrorCode(ErrCodeStreamCreationError)) select { case <-done: case <-time.After(time.Second): t.Fatal("timeout") } } func TestConnControlStreamFailures(t *testing.T) { t.Run("missing SETTINGS", func(t *testing.T) { testConnControlStreamFailures(t, (&dataFrame{}).Append(nil), nil, ErrCodeMissingSettings) }) t.Run("frame error", func(t *testing.T) { testConnControlStreamFailures(t, // 1337 is invalid value for the Extended CONNECT setting (&settingsFrame{Other: map[uint64]uint64{settingExtendedConnect: 1337}}).Append(nil), nil, ErrCodeFrameError, ) }) t.Run("control stream closed before SETTINGS", func(t *testing.T) { testConnControlStreamFailures(t, nil, io.EOF, ErrCodeClosedCriticalStream) }) t.Run("control stream reset before SETTINGS", func(t *testing.T) { testConnControlStreamFailures(t, nil, &quic.StreamError{Remote: true, ErrorCode: 42}, ErrCodeClosedCriticalStream, ) }) } func testConnControlStreamFailures(t *testing.T, data []byte, readErr error, expectedErr ErrCode) { clientConn, serverConn := newConnPair(t) conn := newRawConn(clientConn, false, nil, nil, nil, nil) controlStr, err := serverConn.OpenUniStream() require.NoError(t, err) _, err = controlStr.Write(quicvarint.Append(nil, streamTypeControlStream)) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() clientStr, err := clientConn.AcceptUniStream(ctx) require.NoError(t, err) done := make(chan struct{}) go func() { defer 
close(done) conn.handleUnidirectionalStream(clientStr, false) }() switch readErr { case nil: _, err = controlStr.Write(data) require.NoError(t, err) case io.EOF: _, err = controlStr.Write(data) require.NoError(t, err) require.NoError(t, controlStr.Close()) default: // make sure the stream type is received time.Sleep(scaleDuration(10 * time.Millisecond)) controlStr.CancelWrite(1337) } select { case <-serverConn.Context().Done(): require.ErrorIs(t, context.Cause(serverConn.Context()), &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(expectedErr)}, ) case <-time.After(time.Second): t.Fatal("timeout waiting for close") } select { case <-done: case <-time.After(time.Second): t.Fatal("timeout") } } func TestConnControlStreamHandler(t *testing.T) { t.Run("with handler", func(t *testing.T) { testConnControlStreamHandler(t, true) }) t.Run("without handler", func(t *testing.T) { testConnControlStreamHandler(t, false) }) } func testConnControlStreamHandler(t *testing.T, useHandler bool) { localConn, peerConn := newConnPair(t) handlerCalled := make(chan struct{}) var controlStrHandler func(*quic.ReceiveStream, *frameParser) if useHandler { controlStrHandler = func(*quic.ReceiveStream, *frameParser) { close(handlerCalled) } } conn := newRawConn(localConn, false, nil, controlStrHandler, nil, nil) b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{}).Append(b) str, err := peerConn.OpenUniStream() require.NoError(t, err) _, err = str.Write(b) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() localStr, err := localConn.AcceptUniStream(ctx) require.NoError(t, err) done := make(chan struct{}) go func() { defer close(done) conn.handleUnidirectionalStream(localStr, false) }() select { case <-conn.ReceivedSettings(): case <-time.After(time.Second): t.Fatal("timeout waiting for settings") } if useHandler { select { case <-handlerCalled: case <-time.After(time.Second): 
t.Fatal("timeout waiting for handler to be called") } } else { select { case <-done: case <-time.After(time.Second): t.Fatal("timeout waiting for handler to return") } } } func TestConnRejectPushStream(t *testing.T) { t.Run("client", func(t *testing.T) { testConnRejectPushStream(t, false, ErrCodeIDError) }) t.Run("server", func(t *testing.T) { testConnRejectPushStream(t, true, ErrCodeStreamCreationError) }) } func testConnRejectPushStream(t *testing.T, isServer bool, expectedErr ErrCode) { localConn, peerConn := newConnPair(t) conn := newRawConn(localConn, false, nil, nil, nil, nil) buf := bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream)) str, err := peerConn.OpenUniStream() require.NoError(t, err) _, err = str.Write(buf.Bytes()) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() localStr, err := localConn.AcceptUniStream(ctx) require.NoError(t, err) done := make(chan struct{}) go func() { defer close(done) conn.handleUnidirectionalStream(localStr, isServer) }() select { case <-peerConn.Context().Done(): require.ErrorIs(t, context.Cause(peerConn.Context()), &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(expectedErr)}, ) case <-time.After(time.Second): t.Fatal("timeout waiting for close") } select { case <-done: case <-time.After(time.Second): t.Fatal("timeout") } } func TestConnInconsistentDatagramSupport(t *testing.T) { clientConn, serverConn := newConnPair(t) conn := newRawConn(clientConn, true, nil, nil, nil, nil) b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{Datagram: true}).Append(b) controlStr, err := serverConn.OpenUniStream() require.NoError(t, err) _, err = controlStr.Write(b) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() clientStr, err := clientConn.AcceptUniStream(ctx) require.NoError(t, err) done := make(chan struct{}) go func() { defer close(done) 
conn.handleUnidirectionalStream(clientStr, false) }() select { case <-serverConn.Context().Done(): err := context.Cause(serverConn.Context()) require.ErrorIs(t, err, &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(ErrCodeSettingsError)}) require.ErrorContains(t, err, "missing QUIC Datagram support") case <-time.After(time.Second): t.Fatal("timeout waiting for close") } } func TestConnSendAndReceiveDatagram(t *testing.T) { var eventRecorder events.Recorder clientConn, serverConn := newConnPair(t, withDatagrams(), withClientRecorder(&eventRecorder)) conn := newRawConn(clientConn, true, nil, nil, &eventRecorder, nil) b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{Datagram: true}).Append(b) controlStr, err := serverConn.OpenUniStream() require.NoError(t, err) _, err = controlStr.Write(b) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() clientStr, err := clientConn.AcceptUniStream(ctx) require.NoError(t, err) done := make(chan struct{}) go func() { defer close(done) conn.handleUnidirectionalStream(clientStr, false) }() const strID = 4 // first deliver a datagram... // since the stream is not open yet, it will be dropped quarterStreamID := quicvarint.Append([]byte{}, strID/4) datagram := append(quarterStreamID, []byte("foo")...) 
require.NoError(t, serverConn.SendDatagram(datagram)) time.Sleep(scaleDuration(10 * time.Millisecond)) // give the datagram a chance to be delivered require.Equal(t, []qlogwriter.Event{ qlog.DatagramParsed{ QuaterStreamID: strID / 4, Raw: qlog.RawInfo{Length: len(datagram), PayloadLength: 3}, }, }, eventRecorder.Events(qlog.DatagramParsed{}), ) eventRecorder.Clear() // don't use stream 0, since that makes it hard to test that the quarter stream ID is used str0, err := clientConn.OpenStreamSync(context.Background()) require.NoError(t, err) str0.Close() str, err := clientConn.OpenStream() require.NoError(t, err) require.Equal(t, quic.StreamID(strID), str.StreamID()) datagramStr := conn.TrackStream(str) // now open the stream... require.NoError(t, serverConn.SendDatagram(append(quarterStreamID, []byte("bar")...))) data, err := datagramStr.ReceiveDatagram(ctx) require.NoError(t, err) require.Equal(t, []byte("bar"), data) // now send a datagram require.NoError(t, datagramStr.SendDatagram([]byte("foobaz"))) expected := quicvarint.Append([]byte{}, strID/4) expected = append(expected, []byte("foobaz")...) 
require.Equal(t, []qlogwriter.Event{ qlog.DatagramCreated{ QuaterStreamID: strID / 4, Raw: qlog.RawInfo{PayloadLength: 6, Length: len(expected)}, }, }, eventRecorder.Events(qlog.DatagramCreated{}), ) eventRecorder.Clear() data, err = serverConn.ReceiveDatagram(ctx) require.NoError(t, err) require.Equal(t, expected, data) } func TestConnDatagramFailures(t *testing.T) { t.Run("invalid varint", func(t *testing.T) { testConnDatagramFailures(t, []byte{128}) }) t.Run("invalid quarter stream ID", func(t *testing.T) { testConnDatagramFailures(t, quicvarint.Append([]byte{}, maxQuarterStreamID+1)) }) } func testConnDatagramFailures(t *testing.T, datagram []byte) { localConn, peerConn := newConnPair(t, withDatagrams()) conn := newRawConn(localConn, true, nil, nil, nil, nil) b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{Datagram: true}).Append(b) controlStr, err := peerConn.OpenUniStream() require.NoError(t, err) _, err = controlStr.Write(b) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() localStr, err := localConn.AcceptUniStream(ctx) require.NoError(t, err) go conn.handleUnidirectionalStream(localStr, false) // Wait for SETTINGS to be received and datagram handling to start select { case <-conn.ReceivedSettings(): case <-time.After(time.Second): t.Fatal("timeout waiting for settings") } require.NoError(t, peerConn.SendDatagram(datagram)) select { case <-peerConn.Context().Done(): require.ErrorIs(t, context.Cause(peerConn.Context()), &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(ErrCodeDatagramError)}, ) case <-time.After(time.Second): t.Fatal("timeout waiting for close") } } quic-go-0.59.0/http3/error.go000066400000000000000000000023421513066070600157000ustar00rootroot00000000000000package http3 import ( "errors" "fmt" "github.com/quic-go/quic-go" ) // Error is returned from the round tripper (for HTTP clients) // and inside the HTTP handler (for HTTP servers) 
if an HTTP/3 error occurs. // See section 8 of RFC 9114. type Error struct { Remote bool ErrorCode ErrCode ErrorMessage string } var _ error = &Error{} func (e *Error) Error() string { s := e.ErrorCode.string() if s == "" { s = fmt.Sprintf("H3 error (%#x)", uint64(e.ErrorCode)) } // Usually errors are remote. Only make it explicit for local errors. if !e.Remote { s += " (local)" } if e.ErrorMessage != "" { s += ": " + e.ErrorMessage } return s } func (e *Error) Is(target error) bool { t, ok := target.(*Error) return ok && e.ErrorCode == t.ErrorCode && e.Remote == t.Remote } func maybeReplaceError(err error) error { if err == nil { return nil } var ( e Error strErr *quic.StreamError appErr *quic.ApplicationError ) switch { default: return err case errors.As(err, &strErr): e.Remote = strErr.Remote e.ErrorCode = ErrCode(strErr.ErrorCode) case errors.As(err, &appErr): e.Remote = appErr.Remote e.ErrorCode = ErrCode(appErr.ErrorCode) e.ErrorMessage = appErr.ErrorMessage } return &e } quic-go-0.59.0/http3/error_codes.go000066400000000000000000000045421513066070600170610ustar00rootroot00000000000000package http3 import ( "fmt" "github.com/quic-go/quic-go" ) type ErrCode quic.ApplicationErrorCode const ( ErrCodeNoError ErrCode = 0x100 ErrCodeGeneralProtocolError ErrCode = 0x101 ErrCodeInternalError ErrCode = 0x102 ErrCodeStreamCreationError ErrCode = 0x103 ErrCodeClosedCriticalStream ErrCode = 0x104 ErrCodeFrameUnexpected ErrCode = 0x105 ErrCodeFrameError ErrCode = 0x106 ErrCodeExcessiveLoad ErrCode = 0x107 ErrCodeIDError ErrCode = 0x108 ErrCodeSettingsError ErrCode = 0x109 ErrCodeMissingSettings ErrCode = 0x10a ErrCodeRequestRejected ErrCode = 0x10b ErrCodeRequestCanceled ErrCode = 0x10c ErrCodeRequestIncomplete ErrCode = 0x10d ErrCodeMessageError ErrCode = 0x10e ErrCodeConnectError ErrCode = 0x10f ErrCodeVersionFallback ErrCode = 0x110 ErrCodeDatagramError ErrCode = 0x33 ErrCodeQPACKDecompressionFailed ErrCode = 0x200 ) func (e ErrCode) String() string { s := e.string() 
if s != "" { return s } return fmt.Sprintf("unknown error code: %#x", uint16(e)) } func (e ErrCode) string() string { switch e { case ErrCodeNoError: return "H3_NO_ERROR" case ErrCodeGeneralProtocolError: return "H3_GENERAL_PROTOCOL_ERROR" case ErrCodeInternalError: return "H3_INTERNAL_ERROR" case ErrCodeStreamCreationError: return "H3_STREAM_CREATION_ERROR" case ErrCodeClosedCriticalStream: return "H3_CLOSED_CRITICAL_STREAM" case ErrCodeFrameUnexpected: return "H3_FRAME_UNEXPECTED" case ErrCodeFrameError: return "H3_FRAME_ERROR" case ErrCodeExcessiveLoad: return "H3_EXCESSIVE_LOAD" case ErrCodeIDError: return "H3_ID_ERROR" case ErrCodeSettingsError: return "H3_SETTINGS_ERROR" case ErrCodeMissingSettings: return "H3_MISSING_SETTINGS" case ErrCodeRequestRejected: return "H3_REQUEST_REJECTED" case ErrCodeRequestCanceled: return "H3_REQUEST_CANCELLED" case ErrCodeRequestIncomplete: return "H3_INCOMPLETE_REQUEST" case ErrCodeMessageError: return "H3_MESSAGE_ERROR" case ErrCodeConnectError: return "H3_CONNECT_ERROR" case ErrCodeVersionFallback: return "H3_VERSION_FALLBACK" case ErrCodeDatagramError: return "H3_DATAGRAM_ERROR" case ErrCodeQPACKDecompressionFailed: return "QPACK_DECOMPRESSION_FAILED" default: return "" } } quic-go-0.59.0/http3/error_codes_test.go000066400000000000000000000017701513066070600201200ustar00rootroot00000000000000package http3 import ( "go/ast" "go/parser" "go/token" "path" "runtime" "strconv" "testing" "github.com/stretchr/testify/require" ) func TestErrorCodes(t *testing.T) { // We parse the error code file, extract all constants, and verify that // each of them has a string version. Go FTW! 
_, thisfile, _, ok := runtime.Caller(0) require.True(t, ok, "Failed to get current frame") filename := path.Join(path.Dir(thisfile), "error_codes.go") fileAst, err := parser.ParseFile(token.NewFileSet(), filename, nil, 0) require.NoError(t, err) constSpecs := fileAst.Decls[2].(*ast.GenDecl).Specs require.Greater(t, len(constSpecs), 4) // at time of writing for _, c := range constSpecs { valString := c.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value val, err := strconv.ParseInt(valString, 0, 64) require.NoError(t, err) require.NotEqual(t, "unknown error code", ErrCode(val).String()) } // Test unknown error code require.Equal(t, "unknown error code: 0x1337", ErrCode(0x1337).String()) } quic-go-0.59.0/http3/error_test.go000066400000000000000000000042301513066070600167350ustar00rootroot00000000000000package http3 import ( "testing" "github.com/quic-go/quic-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestErrorConversion(t *testing.T) { tests := []struct { name string input error expected error }{ {name: "nil error", input: nil, expected: nil}, {name: "regular error", input: assert.AnError, expected: assert.AnError}, { name: "stream error", input: &quic.StreamError{ErrorCode: 1337, Remote: true}, expected: &Error{Remote: true, ErrorCode: 1337}, }, { name: "application error", input: &quic.ApplicationError{ErrorCode: 42, Remote: true, ErrorMessage: "foobar"}, expected: &Error{Remote: true, ErrorCode: 42, ErrorMessage: "foobar"}, }, { name: "transport error", input: &quic.TransportError{ErrorCode: 42, Remote: true, ErrorMessage: "foobar"}, expected: &quic.TransportError{ErrorCode: 42, Remote: true, ErrorMessage: "foobar"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := maybeReplaceError(tt.input) if tt.expected == nil { require.Nil(t, result) } else { require.ErrorIs(t, tt.expected, result) } }) } } func TestErrorString(t *testing.T) { tests := []struct { name string err *Error expected string }{ { 
name: "remote error", err: &Error{ErrorCode: 0x10c, Remote: true}, expected: "H3_REQUEST_CANCELLED", }, { name: "remote error with message", err: &Error{ErrorCode: 0x10c, Remote: true, ErrorMessage: "foobar"}, expected: "H3_REQUEST_CANCELLED: foobar", }, { name: "local error", err: &Error{ErrorCode: 0x10c, Remote: false}, expected: "H3_REQUEST_CANCELLED (local)", }, { name: "local error with message", err: &Error{ErrorCode: 0x10c, Remote: false, ErrorMessage: "foobar"}, expected: "H3_REQUEST_CANCELLED (local): foobar", }, { name: "unknown error code", err: &Error{ErrorCode: 0x1337, Remote: true}, expected: "H3 error (0x1337)", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { require.Equal(t, tt.expected, tt.err.Error()) }) } } quic-go-0.59.0/http3/frames.go000066400000000000000000000211161513066070600160240ustar00rootroot00000000000000package http3 import ( "bytes" "errors" "fmt" "io" "maps" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/quic-go/quicvarint" ) // FrameType is the frame type of a HTTP/3 frame type FrameType uint64 type frame any // The maximum length of an encoded HTTP/3 frame header is 16: // The frame has a type and length field, both QUIC varints (maximum 8 bytes in length) const frameHeaderLen = 16 type countingByteReader struct { quicvarint.Reader NumRead int } func (r *countingByteReader) ReadByte() (byte, error) { b, err := r.Reader.ReadByte() if err == nil { r.NumRead++ } return b, err } func (r *countingByteReader) Read(b []byte) (int, error) { n, err := r.Reader.Read(b) r.NumRead += n return n, err } func (r *countingByteReader) Reset() { r.NumRead = 0 } type frameParser struct { r io.Reader streamID quic.StreamID closeConn func(quic.ApplicationErrorCode, string) error } func (p *frameParser) ParseNext(qlogger qlogwriter.Recorder) (frame, error) { r := &countingByteReader{Reader: quicvarint.NewReader(p.r)} for { t, err := quicvarint.Read(r) 
if err != nil { return nil, err } l, err := quicvarint.Read(r) if err != nil { return nil, err } switch t { case 0x0: // DATA if qlogger != nil { qlogger.RecordEvent(qlog.FrameParsed{ StreamID: p.streamID, Raw: qlog.RawInfo{ Length: int(l) + r.NumRead, PayloadLength: int(l), }, Frame: qlog.Frame{Frame: qlog.DataFrame{}}, }) } return &dataFrame{Length: l}, nil case 0x1: // HEADERS return &headersFrame{ Length: l, headerLen: r.NumRead, }, nil case 0x4: // SETTINGS return parseSettingsFrame(r, l, p.streamID, qlogger) case 0x3: // unsupported: CANCEL_PUSH if qlogger != nil { qlogger.RecordEvent(qlog.FrameParsed{ StreamID: p.streamID, Raw: qlog.RawInfo{Length: r.NumRead, PayloadLength: int(l)}, Frame: qlog.Frame{Frame: qlog.CancelPushFrame{}}, }) } case 0x5: // unsupported: PUSH_PROMISE if qlogger != nil { qlogger.RecordEvent(qlog.FrameParsed{ StreamID: p.streamID, Raw: qlog.RawInfo{Length: r.NumRead, PayloadLength: int(l)}, Frame: qlog.Frame{Frame: qlog.PushPromiseFrame{}}, }) } case 0x7: // GOAWAY return parseGoAwayFrame(r, l, p.streamID, qlogger) case 0xd: // unsupported: MAX_PUSH_ID if qlogger != nil { qlogger.RecordEvent(qlog.FrameParsed{ StreamID: p.streamID, Raw: qlog.RawInfo{Length: r.NumRead, PayloadLength: int(l)}, Frame: qlog.Frame{Frame: qlog.MaxPushIDFrame{}}, }) } case 0x2, 0x6, 0x8, 0x9: // reserved frame types if qlogger != nil { qlogger.RecordEvent(qlog.FrameParsed{ StreamID: p.streamID, Raw: qlog.RawInfo{Length: r.NumRead + int(l), PayloadLength: int(l)}, Frame: qlog.Frame{Frame: qlog.ReservedFrame{Type: t}}, }) } p.closeConn(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") return nil, fmt.Errorf("http3: reserved frame type: %d", t) default: // unknown frame types if qlogger != nil { qlogger.RecordEvent(qlog.FrameParsed{ StreamID: p.streamID, Raw: qlog.RawInfo{Length: r.NumRead, PayloadLength: int(l)}, Frame: qlog.Frame{Frame: qlog.UnknownFrame{Type: t}}, }) } } // skip over the payload if _, err := io.CopyN(io.Discard, r, int64(l)); err != nil { 
return nil, err } r.Reset() } } type dataFrame struct { Length uint64 } func (f *dataFrame) Append(b []byte) []byte { b = quicvarint.Append(b, 0x0) return quicvarint.Append(b, f.Length) } type headersFrame struct { Length uint64 headerLen int // number of bytes read for type and length field } func (f *headersFrame) Append(b []byte) []byte { b = quicvarint.Append(b, 0x1) return quicvarint.Append(b, f.Length) } const ( // SETTINGS_MAX_FIELD_SECTION_SIZE settingMaxFieldSectionSize = 0x6 // Extended CONNECT, RFC 9220 settingExtendedConnect = 0x8 // HTTP Datagrams, RFC 9297 settingDatagram = 0x33 ) type settingsFrame struct { MaxFieldSectionSize int64 // SETTINGS_MAX_FIELD_SECTION_SIZE, -1 if not set Datagram bool // HTTP Datagrams, RFC 9297 ExtendedConnect bool // Extended CONNECT, RFC 9220 Other map[uint64]uint64 // all settings that we don't explicitly recognize } func pointer[T any](v T) *T { return &v } func parseSettingsFrame(r *countingByteReader, l uint64, streamID quic.StreamID, qlogger qlogwriter.Recorder) (*settingsFrame, error) { if l > 8*(1<<10) { return nil, fmt.Errorf("unexpected size for SETTINGS frame: %d", l) } buf := make([]byte, l) if _, err := io.ReadFull(r, buf); err != nil { if err == io.ErrUnexpectedEOF { return nil, io.EOF } return nil, err } frame := &settingsFrame{MaxFieldSectionSize: -1} b := bytes.NewReader(buf) settingsFrame := qlog.SettingsFrame{MaxFieldSectionSize: -1} var readMaxFieldSectionSize, readDatagram, readExtendedConnect bool for b.Len() > 0 { id, err := quicvarint.Read(b) if err != nil { // should not happen. We allocated the whole frame already. return nil, err } val, err := quicvarint.Read(b) if err != nil { // should not happen. We allocated the whole frame already. 
return nil, err } switch id { case settingMaxFieldSectionSize: if readMaxFieldSectionSize { return nil, fmt.Errorf("duplicate setting: %d", id) } readMaxFieldSectionSize = true frame.MaxFieldSectionSize = int64(val) settingsFrame.MaxFieldSectionSize = int64(val) case settingExtendedConnect: if readExtendedConnect { return nil, fmt.Errorf("duplicate setting: %d", id) } readExtendedConnect = true if val != 0 && val != 1 { return nil, fmt.Errorf("invalid value for SETTINGS_ENABLE_CONNECT_PROTOCOL: %d", val) } frame.ExtendedConnect = val == 1 if qlogger != nil { settingsFrame.ExtendedConnect = pointer(frame.ExtendedConnect) } case settingDatagram: if readDatagram { return nil, fmt.Errorf("duplicate setting: %d", id) } readDatagram = true if val != 0 && val != 1 { return nil, fmt.Errorf("invalid value for SETTINGS_H3_DATAGRAM: %d", val) } frame.Datagram = val == 1 if qlogger != nil { settingsFrame.Datagram = pointer(frame.Datagram) } default: if _, ok := frame.Other[id]; ok { return nil, fmt.Errorf("duplicate setting: %d", id) } if frame.Other == nil { frame.Other = make(map[uint64]uint64) } frame.Other[id] = val } } if qlogger != nil { settingsFrame.Other = maps.Clone(frame.Other) qlogger.RecordEvent(qlog.FrameParsed{ StreamID: streamID, Raw: qlog.RawInfo{ Length: r.NumRead, PayloadLength: int(l), }, Frame: qlog.Frame{Frame: settingsFrame}, }) } return frame, nil } func (f *settingsFrame) Append(b []byte) []byte { b = quicvarint.Append(b, 0x4) var l int if f.MaxFieldSectionSize >= 0 { l += quicvarint.Len(settingMaxFieldSectionSize) + quicvarint.Len(uint64(f.MaxFieldSectionSize)) } for id, val := range f.Other { l += quicvarint.Len(id) + quicvarint.Len(val) } if f.Datagram { l += quicvarint.Len(settingDatagram) + quicvarint.Len(1) } if f.ExtendedConnect { l += quicvarint.Len(settingExtendedConnect) + quicvarint.Len(1) } b = quicvarint.Append(b, uint64(l)) if f.MaxFieldSectionSize >= 0 { b = quicvarint.Append(b, settingMaxFieldSectionSize) b = quicvarint.Append(b, 
uint64(f.MaxFieldSectionSize)) } if f.Datagram { b = quicvarint.Append(b, settingDatagram) b = quicvarint.Append(b, 1) } if f.ExtendedConnect { b = quicvarint.Append(b, settingExtendedConnect) b = quicvarint.Append(b, 1) } for id, val := range f.Other { b = quicvarint.Append(b, id) b = quicvarint.Append(b, val) } return b } type goAwayFrame struct { StreamID quic.StreamID } func parseGoAwayFrame(r *countingByteReader, l uint64, streamID quic.StreamID, qlogger qlogwriter.Recorder) (*goAwayFrame, error) { frame := &goAwayFrame{} startLen := r.NumRead id, err := quicvarint.Read(r) if err != nil { return nil, err } if r.NumRead-startLen != int(l) { return nil, errors.New("GOAWAY frame: inconsistent length") } frame.StreamID = quic.StreamID(id) if qlogger != nil { qlogger.RecordEvent(qlog.FrameParsed{ StreamID: streamID, Raw: qlog.RawInfo{Length: r.NumRead, PayloadLength: int(l)}, Frame: qlog.Frame{Frame: qlog.GoAwayFrame{StreamID: frame.StreamID}}, }) } return frame, nil } func (f *goAwayFrame) Append(b []byte) []byte { b = quicvarint.Append(b, 0x7) b = quicvarint.Append(b, uint64(quicvarint.Len(uint64(f.StreamID)))) return quicvarint.Append(b, uint64(f.StreamID)) } quic-go-0.59.0/http3/frames_test.go000066400000000000000000000270171513066070600170710ustar00rootroot00000000000000package http3 import ( "bytes" "context" "fmt" "io" "testing" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/quic-go/quicvarint" "github.com/quic-go/quic-go/testutils/events" "github.com/stretchr/testify/require" ) func testFrameParserEOF(t *testing.T, data []byte) { t.Helper() for i := range data { b := make([]byte, i) copy(b, data[:i]) fp := frameParser{r: bytes.NewReader(b)} _, err := fp.ParseNext(nil) require.Error(t, err) require.ErrorIs(t, err, io.EOF) } } func TestParserReservedFrameType(t *testing.T) { for _, ft := range []uint64{0x2, 0x6, 0x8, 0x9} { t.Run(fmt.Sprintf("type %#x", ft), func(t 
*testing.T) { var eventRecorder events.Recorder client, server := newConnPair(t, withDatagrams(), withServerRecorder(&eventRecorder)) data := quicvarint.Append(nil, ft) data = quicvarint.Append(data, 6) data = append(data, []byte("foobar")...) fp := frameParser{ streamID: 42, r: bytes.NewReader(data), closeConn: client.CloseWithError, } _, err := fp.ParseNext(&eventRecorder) require.Error(t, err) require.ErrorContains(t, err, "http3: reserved frame type") select { case <-server.Context().Done(): require.ErrorIs(t, context.Cause(server.Context()), &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(ErrCodeFrameUnexpected)}, ) case <-time.After(time.Second): t.Fatal("timeout") } require.Equal(t, []qlogwriter.Event{ qlog.FrameParsed{ StreamID: 42, Raw: qlog.RawInfo{Length: len(data), PayloadLength: 6}, Frame: qlog.Frame{Frame: qlog.ReservedFrame{Type: ft}}, }, }, eventRecorder.Events(qlog.FrameParsed{}), ) }) } } func TestParserUnknownFrameType(t *testing.T) { data := quicvarint.Append(nil, 0xdead) data = quicvarint.Append(data, 6) data = append(data, []byte("foobar")...) data = quicvarint.Append(data, 0xbeef) data = quicvarint.Append(data, 3) data = append(data, []byte("baz")...) hf := &headersFrame{Length: 3} data = hf.Append(data) data = append(data, []byte("foo")...) 
r := bytes.NewReader(data) fp := frameParser{r: r} f, err := fp.ParseNext(nil) require.NoError(t, err) require.IsType(t, &headersFrame{}, f) hf = f.(*headersFrame) require.Equal(t, uint64(3), hf.Length) payload := make([]byte, 3) _, err = io.ReadFull(r, payload) require.NoError(t, err) require.Equal(t, []byte("foo"), payload) } func TestParserUnsupportedFrameTypes(t *testing.T) { for _, tc := range []struct { name string ft uint64 qf any }{ {name: "CANCEL_PUSH", ft: 0x3, qf: qlog.CancelPushFrame{}}, {name: "PUSH_PROMISE", ft: 0x5, qf: qlog.PushPromiseFrame{}}, {name: "MAX_PUSH_ID", ft: 0xd, qf: qlog.MaxPushIDFrame{}}, } { t.Run(tc.name, func(t *testing.T) { var eventRecorder events.Recorder data := quicvarint.Append(nil, tc.ft) data = quicvarint.Append(data, 6) data = append(data, []byte("foobar")...) df := &dataFrame{Length: 3} data = df.Append(data) data = append(data, []byte("foo")...) r := bytes.NewReader(data) fp := frameParser{streamID: 42, r: r} f, err := fp.ParseNext(&eventRecorder) require.NoError(t, err) require.IsType(t, &dataFrame{}, f) df = f.(*dataFrame) require.Equal(t, uint64(3), df.Length) payload := make([]byte, 3) _, err = io.ReadFull(r, payload) require.NoError(t, err) require.Equal(t, []byte("foo"), payload) headerLen := quicvarint.Len(tc.ft) + quicvarint.Len(6) dfLen, _ := expectedFrameLength(t, df) require.Equal(t, []qlogwriter.Event{ qlog.FrameParsed{ StreamID: 42, Raw: qlog.RawInfo{Length: headerLen, PayloadLength: 6}, Frame: qlog.Frame{Frame: tc.qf}, }, qlog.FrameParsed{ StreamID: 42, Raw: qlog.RawInfo{Length: dfLen, PayloadLength: 3}, Frame: qlog.Frame{Frame: qlog.DataFrame{}}, }, }, eventRecorder.Events(qlog.FrameParsed{}), ) }) } } func TestParserHeadersFrame(t *testing.T) { data := quicvarint.Append(nil, 1) // type byte data = quicvarint.Append(data, 0x1337) fp := frameParser{r: bytes.NewReader(data)} // incomplete data results in an io.EOF testFrameParserEOF(t, data) // parse f1, err := fp.ParseNext(nil) require.NoError(t, err) 
require.IsType(t, &headersFrame{}, f1) require.Equal(t, uint64(0x1337), f1.(*headersFrame).Length) // write and parse fp = frameParser{r: bytes.NewReader(f1.(*headersFrame).Append(nil))} f2, err := fp.ParseNext(nil) require.NoError(t, err) require.Equal(t, f1, f2) } func TestDataFrame(t *testing.T) { data := quicvarint.Append(nil, 0) // type byte data = quicvarint.Append(data, 0x1337) fp := frameParser{r: bytes.NewReader(data)} // incomplete data results in an io.EOF testFrameParserEOF(t, data) // parse f1, err := fp.ParseNext(nil) require.NoError(t, err) require.IsType(t, &dataFrame{}, f1) require.Equal(t, uint64(0x1337), f1.(*dataFrame).Length) // write and parse fp = frameParser{r: bytes.NewReader(f1.(*dataFrame).Append(nil))} f2, err := fp.ParseNext(nil) require.NoError(t, err) require.Equal(t, f1, f2) } func appendSetting(b []byte, key, value uint64) []byte { b = quicvarint.Append(b, key) b = quicvarint.Append(b, value) return b } func TestParserSettingsFrame(t *testing.T) { settings := appendSetting(nil, 13, 37) settings = appendSetting(settings, 0xdead, 0xbeef) data := quicvarint.Append(nil, 4) // type byte data = quicvarint.Append(data, uint64(len(settings))) data = append(data, settings...) 
// incomplete data results in an io.EOF testFrameParserEOF(t, data) fp := frameParser{r: bytes.NewReader(data)} frame, err := fp.ParseNext(nil) require.NoError(t, err) require.IsType(t, &settingsFrame{}, frame) sf := frame.(*settingsFrame) require.Len(t, sf.Other, 2) require.Equal(t, uint64(37), sf.Other[uint64(13)]) require.Equal(t, uint64(0xbeef), sf.Other[uint64(0xdead)]) // write and parse fp = frameParser{r: bytes.NewReader(sf.Append(nil))} f2, err := fp.ParseNext(nil) require.NoError(t, err) require.IsType(t, &settingsFrame{}, f2) sf2 := f2.(*settingsFrame) require.Len(t, sf2.Other, len(sf.Other)) require.Equal(t, sf.Other, sf2.Other) } func TestParserSettingsFrameDuplicateSettings(t *testing.T) { for _, tc := range []struct { name string num uint64 val uint64 }{ { name: "other setting", num: 13, val: 37, }, { name: "extended connect", num: settingExtendedConnect, val: 1, }, { name: "max field section size", num: settingMaxFieldSectionSize, val: 1337, }, { name: "datagram", num: settingDatagram, val: 1, }, } { t.Run(tc.name, func(t *testing.T) { settings := appendSetting(nil, tc.num, tc.val) settings = appendSetting(settings, tc.num, tc.val) data := quicvarint.Append(nil, 4) // type byte data = quicvarint.Append(data, uint64(len(settings))) data = append(data, settings...) 
fp := frameParser{r: bytes.NewReader(data)} _, err := fp.ParseNext(nil) require.Error(t, err) require.EqualError(t, err, fmt.Sprintf("duplicate setting: %d", tc.num)) }) } } func TestParserSettingsFrameMaxFieldSectionSize(t *testing.T) { t.Run("absent", func(t *testing.T) { testParserSettingsFrameMaxFieldSectionSize(t, false) }) t.Run("with value", func(t *testing.T) { testParserSettingsFrameMaxFieldSectionSize(t, true) }) } func testParserSettingsFrameMaxFieldSectionSize(t *testing.T, present bool) { var settings []byte if present { settings = appendSetting(nil, settingMaxFieldSectionSize, 1337) } data := quicvarint.Append(nil, 4) // type byte data = quicvarint.Append(data, uint64(len(settings))) data = append(data, settings...) fp := frameParser{r: bytes.NewReader(data)} f, err := fp.ParseNext(nil) require.NoError(t, err) require.IsType(t, &settingsFrame{}, f) sf := f.(*settingsFrame) if present { require.EqualValues(t, 1337, sf.MaxFieldSectionSize) } else { require.EqualValues(t, -1, sf.MaxFieldSectionSize) } fp = frameParser{r: bytes.NewReader(sf.Append(nil))} f2, err := fp.ParseNext(nil) require.NoError(t, err) require.Equal(t, sf, f2) } func TestParserSettingsFrameDatagram(t *testing.T) { t.Run("enabled", func(t *testing.T) { testParserSettingsFrameDatagram(t, true) }) t.Run("disabled", func(t *testing.T) { testParserSettingsFrameDatagram(t, false) }) } func testParserSettingsFrameDatagram(t *testing.T, enabled bool) { var settings []byte switch enabled { case true: settings = appendSetting(nil, settingDatagram, 1) case false: settings = appendSetting(nil, settingDatagram, 0) } data := quicvarint.Append(nil, 4) // type byte data = quicvarint.Append(data, uint64(len(settings))) data = append(data, settings...) 
fp := frameParser{r: bytes.NewReader(data)} f, err := fp.ParseNext(nil) require.NoError(t, err) require.IsType(t, &settingsFrame{}, f) sf := f.(*settingsFrame) require.Equal(t, enabled, sf.Datagram) fp = frameParser{r: bytes.NewReader(sf.Append(nil))} f2, err := fp.ParseNext(nil) require.NoError(t, err) require.Equal(t, sf, f2) } func TestParserSettingsFrameDatagramInvalidValue(t *testing.T) { settings := quicvarint.Append(nil, settingDatagram) settings = quicvarint.Append(settings, 1337) data := quicvarint.Append(nil, 4) // type byte data = quicvarint.Append(data, uint64(len(settings))) data = append(data, settings...) fp := frameParser{r: bytes.NewReader(data)} _, err := fp.ParseNext(nil) require.EqualError(t, err, "invalid value for SETTINGS_H3_DATAGRAM: 1337") } func TestParserSettingsFrameExtendedConnect(t *testing.T) { t.Run("enabled", func(t *testing.T) { testParserSettingsFrameExtendedConnect(t, true) }) t.Run("disabled", func(t *testing.T) { testParserSettingsFrameExtendedConnect(t, false) }) } func testParserSettingsFrameExtendedConnect(t *testing.T, enabled bool) { var settings []byte switch enabled { case true: settings = appendSetting(nil, settingExtendedConnect, 1) case false: settings = appendSetting(nil, settingExtendedConnect, 0) } data := quicvarint.Append(nil, 4) // type byte data = quicvarint.Append(data, uint64(len(settings))) data = append(data, settings...) 
fp := frameParser{r: bytes.NewReader(data)} f, err := fp.ParseNext(nil) require.NoError(t, err) require.IsType(t, &settingsFrame{}, f) sf := f.(*settingsFrame) require.Equal(t, enabled, sf.ExtendedConnect) fp = frameParser{r: bytes.NewReader(sf.Append(nil))} f2, err := fp.ParseNext(nil) require.NoError(t, err) require.Equal(t, sf, f2) } func TestParserSettingsFrameExtendedConnectInvalidValue(t *testing.T) { settings := quicvarint.Append(nil, settingExtendedConnect) settings = quicvarint.Append(settings, 1337) data := quicvarint.Append(nil, 4) // type byte data = quicvarint.Append(data, uint64(len(settings))) data = append(data, settings...) fp := frameParser{r: bytes.NewReader(data)} _, err := fp.ParseNext(nil) require.EqualError(t, err, "invalid value for SETTINGS_ENABLE_CONNECT_PROTOCOL: 1337") } func TestParserGoAwayFrame(t *testing.T) { data := quicvarint.Append(nil, 7) // type byte data = quicvarint.Append(data, uint64(quicvarint.Len(100))) data = quicvarint.Append(data, 100) // incomplete data results in an io.EOF testFrameParserEOF(t, data) fp := frameParser{r: bytes.NewReader(data)} f, err := fp.ParseNext(nil) require.NoError(t, err) require.IsType(t, &goAwayFrame{}, f) require.Equal(t, quic.StreamID(100), f.(*goAwayFrame).StreamID) // write and parse fp = frameParser{r: bytes.NewReader(f.(*goAwayFrame).Append(nil))} f2, err := fp.ParseNext(nil) require.NoError(t, err) require.Equal(t, f, f2) } quic-go-0.59.0/http3/gzip_reader.go000066400000000000000000000014361513066070600170450ustar00rootroot00000000000000package http3 // copied from net/transport.go // gzipReader wraps a response body so it can lazily // call gzip.NewReader on the first call to Read import ( "compress/gzip" "io" ) // call gzip.NewReader on the first call to Read type gzipReader struct { body io.ReadCloser // underlying Response.Body zr *gzip.Reader // lazily-initialized gzip reader zerr error // sticky error } func newGzipReader(body io.ReadCloser) io.ReadCloser { return 
&gzipReader{body: body} } func (gz *gzipReader) Read(p []byte) (n int, err error) { if gz.zerr != nil { return 0, gz.zerr } if gz.zr == nil { gz.zr, err = gzip.NewReader(gz.body) if err != nil { gz.zerr = err return 0, err } } return gz.zr.Read(p) } func (gz *gzipReader) Close() error { return gz.body.Close() } quic-go-0.59.0/http3/headers.go000066400000000000000000000264421513066070600161710ustar00rootroot00000000000000package http3 import ( "bytes" "errors" "fmt" "io" "net/http" "net/textproto" "net/url" "strconv" "strings" "golang.org/x/net/http/httpguts" "github.com/quic-go/qpack" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3/qlog" "github.com/quic-go/quic-go/qlogwriter" ) type qpackError struct{ err error } func (e *qpackError) Error() string { return fmt.Sprintf("qpack: %v", e.err) } func (e *qpackError) Unwrap() error { return e.err } var errHeaderTooLarge = errors.New("http3: headers too large") type header struct { // Pseudo header fields defined in RFC 9114 Path string Method string Authority string Scheme string Status string // for Extended connect Protocol string // parsed and deduplicated. 
-1 if no Content-Length header is sent ContentLength int64 // all non-pseudo headers Headers http.Header } // connection-specific header fields must not be sent on HTTP/3 var invalidHeaderFields = [...]string{ "connection", "keep-alive", "proxy-connection", "transfer-encoding", "upgrade", } func parseHeaders(decodeFn qpack.DecodeFunc, isRequest bool, sizeLimit int, headerFields *[]qpack.HeaderField) (header, error) { hdr := header{Headers: make(http.Header)} var readFirstRegularHeader, readContentLength bool var contentLengthStr string for { h, err := decodeFn() if err != nil { if err == io.EOF { break } return header{}, &qpackError{err} } if headerFields != nil { *headerFields = append(*headerFields, h) } // RFC 9114, section 4.2.2: // The size of a field list is calculated based on the uncompressed size of fields, // including the length of the name and value in bytes plus an overhead of 32 bytes for each field. sizeLimit -= len(h.Name) + len(h.Value) + 32 if sizeLimit < 0 { return header{}, errHeaderTooLarge } // field names need to be lowercase, see section 4.2 of RFC 9114 if strings.ToLower(h.Name) != h.Name { return header{}, fmt.Errorf("header field is not lower-case: %s", h.Name) } if !httpguts.ValidHeaderFieldValue(h.Value) { return header{}, fmt.Errorf("invalid header field value for %s: %q", h.Name, h.Value) } if h.IsPseudo() { if readFirstRegularHeader { // all pseudo headers must appear before regular header fields, see section 4.3 of RFC 9114 return header{}, fmt.Errorf("received pseudo header %s after a regular header field", h.Name) } var isResponsePseudoHeader bool // pseudo headers are either valid for requests or for responses var isDuplicatePseudoHeader bool // pseudo headers are allowed to appear exactly once switch h.Name { case ":path": isDuplicatePseudoHeader = hdr.Path != "" hdr.Path = h.Value case ":method": isDuplicatePseudoHeader = hdr.Method != "" hdr.Method = h.Value case ":authority": isDuplicatePseudoHeader = hdr.Authority != "" 
hdr.Authority = h.Value case ":protocol": isDuplicatePseudoHeader = hdr.Protocol != "" hdr.Protocol = h.Value case ":scheme": isDuplicatePseudoHeader = hdr.Scheme != "" hdr.Scheme = h.Value case ":status": isDuplicatePseudoHeader = hdr.Status != "" hdr.Status = h.Value isResponsePseudoHeader = true default: return header{}, fmt.Errorf("unknown pseudo header: %s", h.Name) } if isDuplicatePseudoHeader { return header{}, fmt.Errorf("duplicate pseudo header: %s", h.Name) } if isRequest && isResponsePseudoHeader { return header{}, fmt.Errorf("invalid request pseudo header: %s", h.Name) } if !isRequest && !isResponsePseudoHeader { return header{}, fmt.Errorf("invalid response pseudo header: %s", h.Name) } } else { if !httpguts.ValidHeaderFieldName(h.Name) { return header{}, fmt.Errorf("invalid header field name: %q", h.Name) } for _, invalidField := range invalidHeaderFields { if h.Name == invalidField { return header{}, fmt.Errorf("invalid header field name: %q", h.Name) } } if h.Name == "te" && h.Value != "trailers" { return header{}, fmt.Errorf("invalid TE header field value: %q", h.Value) } readFirstRegularHeader = true switch h.Name { case "content-length": // Ignore duplicate Content-Length headers. // Fail if the duplicates differ. 
if !readContentLength { readContentLength = true contentLengthStr = h.Value } else if contentLengthStr != h.Value { return header{}, fmt.Errorf("contradicting content lengths (%s and %s)", contentLengthStr, h.Value) } default: hdr.Headers.Add(h.Name, h.Value) } } } hdr.ContentLength = -1 if len(contentLengthStr) > 0 { // use ParseUint instead of ParseInt, so that parsing fails on negative values cl, err := strconv.ParseUint(contentLengthStr, 10, 63) if err != nil { return header{}, fmt.Errorf("invalid content length: %w", err) } hdr.Headers.Set("Content-Length", contentLengthStr) hdr.ContentLength = int64(cl) } return hdr, nil } func parseTrailers(decodeFn qpack.DecodeFunc, headerFields *[]qpack.HeaderField) (http.Header, error) { h := make(http.Header) for { hf, err := decodeFn() if err != nil { if err == io.EOF { break } return nil, &qpackError{err} } if headerFields != nil { *headerFields = append(*headerFields, hf) } if hf.IsPseudo() { return nil, fmt.Errorf("http3: received pseudo header in trailer: %s", hf.Name) } h.Add(hf.Name, hf.Value) } return h, nil } func requestFromHeaders(decodeFn qpack.DecodeFunc, sizeLimit int, headerFields *[]qpack.HeaderField) (*http.Request, error) { hdr, err := parseHeaders(decodeFn, true, sizeLimit, headerFields) if err != nil { return nil, err } // concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4 if len(hdr.Headers["Cookie"]) > 0 { hdr.Headers.Set("Cookie", strings.Join(hdr.Headers["Cookie"], "; ")) } isConnect := hdr.Method == http.MethodConnect // Extended CONNECT, see https://datatracker.ietf.org/doc/html/rfc8441#section-4 isExtendedConnected := isConnect && hdr.Protocol != "" if isExtendedConnected { if hdr.Scheme == "" || hdr.Path == "" || hdr.Authority == "" { return nil, errors.New("extended CONNECT: :scheme, :path and :authority must not be empty") } } else if isConnect { if hdr.Path != "" || hdr.Authority == "" { // normal CONNECT return nil, errors.New(":path must be empty and 
:authority must not be empty") } } else if len(hdr.Path) == 0 || len(hdr.Authority) == 0 || len(hdr.Method) == 0 { return nil, errors.New(":path, :authority and :method must not be empty") } if !isExtendedConnected && len(hdr.Protocol) > 0 { return nil, errors.New(":protocol must be empty") } var u *url.URL var requestURI string protocol := "HTTP/3.0" if isConnect { u = &url.URL{} if isExtendedConnected { u, err = url.ParseRequestURI(hdr.Path) if err != nil { return nil, err } protocol = hdr.Protocol } else { u.Path = hdr.Path } u.Scheme = hdr.Scheme u.Host = hdr.Authority requestURI = hdr.Authority } else { u, err = url.ParseRequestURI(hdr.Path) if err != nil { return nil, fmt.Errorf("invalid content length: %w", err) } requestURI = hdr.Path } req := &http.Request{ Method: hdr.Method, URL: u, Proto: protocol, ProtoMajor: 3, ProtoMinor: 0, Header: hdr.Headers, Body: nil, ContentLength: hdr.ContentLength, Host: hdr.Authority, RequestURI: requestURI, } req.Trailer = extractAnnouncedTrailers(req.Header) return req, nil } // updateResponseFromHeaders sets up http.Response as an HTTP/3 response, // using the decoded qpack header filed. // It is only called for the HTTP header (and not the HTTP trailer). // It takes an http.Response as an argument to allow the caller to set the trailer later on. 
func updateResponseFromHeaders(rsp *http.Response, decodeFn qpack.DecodeFunc, sizeLimit int, headerFields *[]qpack.HeaderField) error { hdr, err := parseHeaders(decodeFn, false, sizeLimit, headerFields) if err != nil { return err } if hdr.Status == "" { return errors.New("missing :status field") } rsp.Proto = "HTTP/3.0" rsp.ProtoMajor = 3 rsp.Header = hdr.Headers rsp.Trailer = extractAnnouncedTrailers(rsp.Header) rsp.ContentLength = hdr.ContentLength status, err := strconv.Atoi(hdr.Status) if err != nil { return fmt.Errorf("invalid status code: %w", err) } rsp.StatusCode = status rsp.Status = hdr.Status + " " + http.StatusText(status) return nil } // extractAnnouncedTrailers extracts trailer keys from the "Trailer" header. // It returns a map with the announced keys set to nil values, and removes the "Trailer" header. // It handles both duplicate as well as comma-separated values for the Trailer header. // For example: // // Trailer: Trailer1, Trailer2 // Trailer: Trailer3 // // Will result in a map containing the keys "Trailer1", "Trailer2", "Trailer3" with nil values. func extractAnnouncedTrailers(header http.Header) http.Header { rawTrailers, ok := header["Trailer"] if !ok { return nil } trailers := make(http.Header) for _, rawVal := range rawTrailers { for _, val := range strings.Split(rawVal, ",") { trailers[http.CanonicalHeaderKey(textproto.TrimString(val))] = nil } } delete(header, "Trailer") return trailers } // writeTrailers encodes and writes HTTP trailers as a HEADERS frame. // It returns true if trailers were written, false if there were no trailers to write. 
func writeTrailers(wr io.Writer, trailers http.Header, streamID quic.StreamID, qlogger qlogwriter.Recorder) (bool, error) { var hasValues bool for k, vals := range trailers { if httpguts.ValidTrailerHeader(k) && len(vals) > 0 { hasValues = true break } } if !hasValues { return false, nil } var buf bytes.Buffer enc := qpack.NewEncoder(&buf) var headerFields []qlog.HeaderField if qlogger != nil { headerFields = make([]qlog.HeaderField, 0, len(trailers)) } for k, vals := range trailers { if len(vals) == 0 { continue } if !httpguts.ValidTrailerHeader(k) { continue } lowercaseKey := strings.ToLower(k) for _, v := range vals { if err := enc.WriteField(qpack.HeaderField{Name: lowercaseKey, Value: v}); err != nil { return false, err } if qlogger != nil { headerFields = append(headerFields, qlog.HeaderField{Name: lowercaseKey, Value: v}) } } } b := make([]byte, 0, frameHeaderLen+buf.Len()) b = (&headersFrame{Length: uint64(buf.Len())}).Append(b) b = append(b, buf.Bytes()...) if qlogger != nil { qlogCreatedHeadersFrame(qlogger, streamID, len(b), buf.Len(), headerFields) } _, err := wr.Write(b) return true, err } func decodeTrailers(r io.Reader, hf *headersFrame, maxHeaderBytes int, decoder *qpack.Decoder, qlogger qlogwriter.Recorder, streamID quic.StreamID) (http.Header, error) { if hf.Length > uint64(maxHeaderBytes) { maybeQlogInvalidHeadersFrame(qlogger, streamID, hf.Length) return nil, fmt.Errorf("http3: HEADERS frame too large: %d bytes (max: %d)", hf.Length, maxHeaderBytes) } b := make([]byte, hf.Length) if _, err := io.ReadFull(r, b); err != nil { return nil, err } decodeFn := decoder.Decode(b) var fields []qpack.HeaderField if qlogger != nil { fields = make([]qpack.HeaderField, 0, 16) } trailers, err := parseTrailers(decodeFn, &fields) if err != nil { maybeQlogInvalidHeadersFrame(qlogger, streamID, hf.Length) return nil, err } if qlogger != nil { qlogParsedHeadersFrame(qlogger, streamID, hf, fields) } return trailers, nil } 
quic-go-0.59.0/http3/headers_test.go000066400000000000000000000366661513066070600172410ustar00rootroot00000000000000package http3

import (
	"bytes"
	"fmt"
	"io"
	"math"
	"net/http"
	"testing"

	"github.com/quic-go/qpack"

	"github.com/stretchr/testify/require"
)

// decodeFromSlice returns a qpack.DecodeFunc that yields the given header fields
// in order, followed by io.EOF.
func decodeFromSlice(headers []qpack.HeaderField) qpack.DecodeFunc {
	var i int
	return func() (qpack.HeaderField, error) {
		if i >= len(headers) {
			return qpack.HeaderField{}, io.EOF
		}
		h := headers[i]
		i++
		return h, nil
	}
}

func TestRequestHeaderParsing(t *testing.T) {
	t.Run("regular path", func(t *testing.T) {
		testRequestHeaderParsing(t, "/foo")
	})
	// see https://github.com/quic-go/quic-go/pull/1898
	t.Run("path starting with //", func(t *testing.T) {
		testRequestHeaderParsing(t, "//foo")
	})
}

// testRequestHeaderParsing parses a GET request for the given :path and checks
// every field of the resulting http.Request.
func testRequestHeaderParsing(t *testing.T, path string) {
	headers := []qpack.HeaderField{
		{Name: ":path", Value: path},
		{Name: ":authority", Value: "quic-go.net"},
		{Name: ":method", Value: http.MethodGet},
		{Name: "content-length", Value: "42"},
	}
	req, err := requestFromHeaders(decodeFromSlice(headers), math.MaxInt, nil)
	require.NoError(t, err)
	require.Equal(t, http.MethodGet, req.Method)
	require.Equal(t, path, req.URL.Path)
	require.Equal(t, "", req.URL.Host)
	require.Equal(t, "HTTP/3.0", req.Proto)
	require.Equal(t, 3, req.ProtoMajor)
	require.Zero(t, req.ProtoMinor)
	require.Equal(t, int64(42), req.ContentLength)
	require.Equal(t, 1, len(req.Header))
	require.Equal(t, "42", req.Header.Get("Content-Length"))
	require.Nil(t, req.Body)
	require.Equal(t, "quic-go.net", req.Host)
	require.Equal(t, path, req.RequestURI)
}

func TestRequestHeadersContentLength(t *testing.T) {
	t.Run("no content length", func(t *testing.T) {
		headers := []qpack.HeaderField{
			{Name: ":path", Value: "/"},
			{Name: ":authority", Value: "quic-go.net"},
			{Name: ":method", Value: http.MethodGet},
		}
		req, err := requestFromHeaders(decodeFromSlice(headers), math.MaxInt, nil)
		require.NoError(t, err)
		require.Equal(t, int64(-1), req.ContentLength)
	})

	t.Run("multiple content lengths", func(t *testing.T) {
		headers := []qpack.HeaderField{
			{Name: ":path", Value: "/"},
			{Name: ":authority", Value: "quic-go.net"},
			{Name: ":method", Value: http.MethodGet},
			{Name: "content-length", Value: "42"},
			{Name: "content-length", Value: "42"},
		}
		req, err := requestFromHeaders(decodeFromSlice(headers), math.MaxInt, nil)
		require.NoError(t, err)
		require.Equal(t, "42", req.Header.Get("Content-Length"))
	})
}

func TestRequestHeadersContentLengthValidation(t *testing.T) {
	for _, tc := range []struct {
		name        string
		headers     []qpack.HeaderField
		err         string
		errContains string
	}{
		{
			name: "negative content length",
			headers: []qpack.HeaderField{
				{Name: "content-length", Value: "-42"},
			},
			errContains: "invalid content length",
		},
		{
			name: "multiple differing content lengths",
			headers: []qpack.HeaderField{
				{Name: "content-length", Value: "42"},
				{Name: "content-length", Value: "1337"},
			},
			err: "contradicting content lengths (42 and 1337)",
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			_, err := requestFromHeaders(decodeFromSlice(tc.headers), math.MaxInt, nil)
			if tc.errContains != "" {
				require.ErrorContains(t, err, tc.errContains)
			}
			if tc.err != "" {
				require.EqualError(t, err, tc.err)
			}
		})
	}
}

func TestRequestHeadersValidation(t *testing.T) {
	for _, tc := range []struct {
		name    string
		headers []qpack.HeaderField
		err     string
	}{
		{
			name: "upper-case field name",
			headers: []qpack.HeaderField{
				{Name: "Content-Length", Value: "42"},
			},
			err: "header field is not lower-case: Content-Length",
		},
		{
			name: "unknown pseudo header",
			headers: []qpack.HeaderField{
				{Name: ":foo", Value: "bar"},
			},
			err: "unknown pseudo header: :foo",
		},
		{
			name: "pseudo header after regular header",
			headers: []qpack.HeaderField{
				{Name: ":path", Value: "/foo"},
				{Name: "content-length", Value: "42"},
				{Name: ":authority", Value: "quic-go.net"},
			},
			err: "received pseudo header :authority after a regular header field",
		},
		{
			name: "invalid field name",
			headers: []qpack.HeaderField{
				{Name: "@", Value: "42"},
			},
			err: `invalid header field name: "@"`,
		},
		{
			name: "invalid field value",
			headers: []qpack.HeaderField{
				{Name: "content", Value: "\n"},
			},
			err: `invalid header field value for content: "\n"`,
		},
		{
			name: ":status header field", // :status is a response pseudo header
			headers: []qpack.HeaderField{
				{Name: ":status", Value: "404"},
			},
			err: "invalid request pseudo header: :status",
		},
		{
			name: "missing :path",
			headers: []qpack.HeaderField{
				{Name: ":authority", Value: "quic-go.net"},
				{Name: ":method", Value: http.MethodGet},
			},
			err: ":path, :authority and :method must not be empty",
		},
		{
			name: "missing :authority",
			headers: []qpack.HeaderField{
				{Name: ":path", Value: "/foo"},
				{Name: ":method", Value: http.MethodGet},
			},
			err: ":path, :authority and :method must not be empty",
		},
		{
			name: "missing :method",
			headers: []qpack.HeaderField{
				{Name: ":path", Value: "/foo"},
				{Name: ":authority", Value: "quic-go.net"},
			},
			err: ":path, :authority and :method must not be empty",
		},
		{
			name: "duplicate :path",
			headers: []qpack.HeaderField{
				{Name: ":path", Value: "/foo"},
				{Name: ":path", Value: "/foo"},
			},
			err: "duplicate pseudo header: :path",
		},
		{
			name: "duplicate :authority",
			headers: []qpack.HeaderField{
				{Name: ":authority", Value: "quic-go.net"},
				{Name: ":authority", Value: "quic-go.net"},
			},
			err: "duplicate pseudo header: :authority",
		},
		{
			name: "duplicate :method",
			headers: []qpack.HeaderField{
				{Name: ":method", Value: http.MethodGet},
				{Name: ":method", Value: http.MethodGet},
			},
			err: "duplicate pseudo header: :method",
		},
		{
			name: "invalid :protocol",
			headers: []qpack.HeaderField{
				{Name: ":path", Value: "/foo"},
				{Name: ":authority", Value: "quic-go.net"},
				{Name: ":method", Value: http.MethodGet},
				{Name: ":protocol", Value: "connect-udp"},
			},
			err: ":protocol must be empty",
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			_, err := requestFromHeaders(decodeFromSlice(tc.headers), math.MaxInt, nil)
			require.EqualError(t, err, tc.err)
			require.NotErrorAs(t, err, new(*qpackError))
		})
	}
}

func TestCookieHeader(t *testing.T) {
	headers := []qpack.HeaderField{
		{Name: ":path", Value: "/foo"},
		{Name: ":authority", Value: "quic-go.net"},
		{Name: ":method", Value: http.MethodGet},
		{Name: "cookie", Value: "cookie1=foobar1"},
		{Name: "cookie", Value: "cookie2=foobar2"},
	}
	req, err := requestFromHeaders(decodeFromSlice(headers), math.MaxInt, nil)
	require.NoError(t, err)
	require.Equal(t, http.Header{
		"Cookie": []string{"cookie1=foobar1; cookie2=foobar2"},
	}, req.Header)
}

func TestHeadersConcatenation(t *testing.T) {
	headers := []qpack.HeaderField{
		{Name: ":path", Value: "/foo"},
		{Name: ":authority", Value: "quic-go.net"},
		{Name: ":method", Value: http.MethodGet},
		{Name: "cache-control", Value: "max-age=0"},
		{Name: "duplicate-header", Value: "1"},
		{Name: "duplicate-header", Value: "2"},
	}
	req, err := requestFromHeaders(decodeFromSlice(headers), math.MaxInt, nil)
	require.NoError(t, err)
	require.Equal(t, http.Header{
		"Cache-Control":    []string{"max-age=0"},
		"Duplicate-Header": []string{"1", "2"},
	}, req.Header)
}

func TestRequestHeadersConnect(t *testing.T) {
	headers := []qpack.HeaderField{
		{Name: ":authority", Value: "quic-go.net"},
		{Name: ":method", Value: http.MethodConnect},
	}
	req, err := requestFromHeaders(decodeFromSlice(headers), math.MaxInt, nil)
	require.NoError(t, err)
	require.Equal(t, http.MethodConnect, req.Method)
	require.Equal(t, "HTTP/3.0", req.Proto)
	require.Equal(t, "quic-go.net", req.RequestURI)
}

func TestRequestHeadersConnectValidation(t *testing.T) {
	for _, tc := range []struct {
		name    string
		headers []qpack.HeaderField
		err     string
	}{
		{
			name: "missing :authority",
			headers: []qpack.HeaderField{
				{Name: ":method", Value: http.MethodConnect},
			},
			err: ":path must be empty and :authority must not be empty",
		},
		{
			name: ":path set",
			headers: []qpack.HeaderField{
				{Name: ":path", Value: "/foo"},
				{Name: ":method", Value: http.MethodConnect},
			},
			err: ":path must be empty and :authority must not be empty",
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			_, err := requestFromHeaders(decodeFromSlice(tc.headers), math.MaxInt, nil)
			require.EqualError(t, err, tc.err)
		})
	}
}

func TestRequestHeadersExtendedConnect(t *testing.T) {
	headers := []qpack.HeaderField{
		{Name: ":protocol", Value: "webtransport"},
		{Name: ":scheme", Value: "ftp"},
		{Name: ":method", Value: http.MethodConnect},
		{Name: ":authority", Value: "quic-go.net"},
		{Name: ":path", Value: "/foo?val=1337"},
	}
	req, err := requestFromHeaders(decodeFromSlice(headers), math.MaxInt, nil)
	require.NoError(t, err)
	require.Equal(t, http.MethodConnect, req.Method)
	require.Equal(t, "webtransport", req.Proto)
	require.Equal(t, "ftp://quic-go.net/foo?val=1337", req.URL.String())
	require.Equal(t, "1337", req.URL.Query().Get("val"))
}

func TestRequestHeadersExtendedConnectRequestValidation(t *testing.T) {
	headers := []qpack.HeaderField{
		{Name: ":protocol", Value: "webtransport"},
		{Name: ":method", Value: http.MethodConnect},
		{Name: ":authority", Value: "quic.clemente.io"},
		{Name: ":path", Value: "/foo"},
	}
	_, err := requestFromHeaders(decodeFromSlice(headers), math.MaxInt, nil)
	require.EqualError(t, err, "extended CONNECT: :scheme, :path and :authority must not be empty")
}

func TestResponseHeaderParsing(t *testing.T) {
	headers := []qpack.HeaderField{
		{Name: ":status", Value: "200"},
		{Name: "content-length", Value: "42"},
	}
	rsp := &http.Response{}
	require.NoError(t, updateResponseFromHeaders(rsp, decodeFromSlice(headers), math.MaxInt, nil))
	require.Equal(t, "HTTP/3.0", rsp.Proto)
	require.Equal(t, 3, rsp.ProtoMajor)
	require.Zero(t, rsp.ProtoMinor)
	require.Equal(t, int64(42), rsp.ContentLength)
	require.Equal(t, 1, len(rsp.Header))
	require.Equal(t, "42", rsp.Header.Get("Content-Length"))
	require.Nil(t, rsp.Body)
	require.Equal(t, 200, rsp.StatusCode)
	require.Equal(t, "200 OK", rsp.Status)
}

func TestResponseHeaderParsingValidation(t *testing.T) {
	for _, tc := range []struct {
		name        string
		headers     []qpack.HeaderField
		err         string
		errContains string
	}{
		{
			name: "missing :status",
			headers: []qpack.HeaderField{
				{Name: "content-length", Value: "42"},
			},
			err: "missing :status field",
		},
		{
			name: "invalid status code",
			headers: []qpack.HeaderField{
				{Name: ":status", Value: "foobar"},
			},
			errContains: "invalid status code",
		},
		{
			name: ":method header field", // :method is a request pseudo header
			headers: []qpack.HeaderField{
				{Name: ":method", Value: http.MethodGet},
			},
			err: "invalid response pseudo header: :method",
		},
		{
			name: "duplicate :status",
			headers: []qpack.HeaderField{
				{Name: ":status", Value: "200"},
				{Name: ":status", Value: "404"},
			},
			err: "duplicate pseudo header: :status",
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			err := updateResponseFromHeaders(&http.Response{}, decodeFromSlice(tc.headers), math.MaxInt, nil)
			if tc.errContains != "" {
				require.ErrorContains(t, err, tc.errContains)
			}
			if tc.err != "" {
				require.EqualError(t, err, tc.err)
			}
		})
	}

	for _, tc := range []struct {
		name         string
		invalidField string
	}{
		{name: "connection", invalidField: "connection"},
		{name: "keep-alive", invalidField: "keep-alive"},
		{name: "proxy-connection", invalidField: "proxy-connection"},
		{name: "transfer-encoding", invalidField: "transfer-encoding"},
		{name: "upgrade", invalidField: "upgrade"},
	} {
		t.Run("invalid field: "+tc.name, func(t *testing.T) {
			headers := []qpack.HeaderField{
				{Name: ":status", Value: "404"},
				{Name: tc.invalidField, Value: "some-value"},
			}
			err := updateResponseFromHeaders(&http.Response{}, decodeFromSlice(headers), math.MaxInt, nil)
			require.EqualError(t, err, fmt.Sprintf("invalid header field name: %q", tc.invalidField))
		})
	}
}

func TestResponseTrailerFields(t *testing.T) {
	headers := []qpack.HeaderField{
		{Name: ":status", Value: "200"},
		{Name: "trailer", Value: "Trailer1, Trailer2"},
		{Name: "trailer", Value: "TRAILER3"},
	}
	var rsp http.Response
	require.NoError(t, updateResponseFromHeaders(&rsp, decodeFromSlice(headers), math.MaxInt, nil))
	require.Equal(t, 0, len(rsp.Header))
	require.Equal(t, http.Header(map[string][]string{
		"Trailer1": nil,
		"Trailer2": nil,
		"Trailer3": nil,
	}), rsp.Trailer)
}

func TestResponseTrailerParsingTE(t *testing.T) {
	headers := []qpack.HeaderField{
		{Name: ":status", Value: "404"},
		{Name: "te", Value: "trailers"},
	}
	require.NoError(t, updateResponseFromHeaders(&http.Response{}, decodeFromSlice(headers), math.MaxInt, nil))
	headers = []qpack.HeaderField{
		{Name: ":status", Value: "404"},
		{Name: "te", Value: "not-trailers"},
	}
	require.EqualError(t,
		updateResponseFromHeaders(&http.Response{}, decodeFromSlice(headers), math.MaxInt, nil),
		`invalid TE header field value: "not-trailers"`,
	)
}

func TestResponseTrailerParsing(t *testing.T) {
	trailerHdr, err := parseTrailers(decodeFromSlice([]qpack.HeaderField{
		{Name: "content-length", Value: "42"},
	}), nil)
	require.NoError(t, err)
	require.Equal(t, "42", trailerHdr.Get("Content-Length"))
}

func TestResponseTrailerParsingValidation(t *testing.T) {
	headers := []qpack.HeaderField{
		{Name: ":status", Value: "200"},
	}
	_, err := parseTrailers(decodeFromSlice(headers), nil)
	require.EqualError(t, err, "http3: received pseudo header in trailer: :status")
}

// TestQpackError checks that QPACK decoding failures are surfaced as *qpackError
// by the request header, response header and trailer parsers.
func TestQpackError(t *testing.T) {
	buf := &bytes.Buffer{}
	enc := qpack.NewEncoder(buf)
	require.NoError(t, enc.WriteField(qpack.HeaderField{Name: ":status", Value: "200"}))
	enc.Close()
	// Truncate the encoded header block to half its length, so that decoding fails.
	truncated := buf.Bytes()[:len(buf.Bytes())/2]

	t.Run("request header parsing", func(t *testing.T) {
		dec := qpack.NewDecoder()
		decodeFn := dec.Decode(truncated)
		_, err := requestFromHeaders(decodeFn, math.MaxInt, nil)
		require.ErrorAs(t, err, new(*qpackError))
	})

	t.Run("response header parsing", func(t *testing.T) {
		dec := qpack.NewDecoder()
		decodeFn := dec.Decode(truncated)
		err := updateResponseFromHeaders(&http.Response{}, decodeFn, math.MaxInt, nil)
		require.ErrorAs(t, err, new(*qpackError))
	})

	t.Run("trailer parsing", func(t *testing.T) {
		dec := qpack.NewDecoder()
		decodeFn := dec.Decode(truncated)
		_, err := parseTrailers(decodeFn, nil)
		require.ErrorAs(t, err, new(*qpackError))
	})
}

func BenchmarkRequestFromHeaders(b *testing.B) {
	b.ReportAllocs()
	headers := []qpack.HeaderField{
		{Name: ":path", Value: "/api/v1/users/12345"},
		{Name: ":authority", Value: "quic-go.net"},
		{Name: ":method", Value: http.MethodPost},
		{Name: "content-type", Value: "application/json"},
		{Name: "content-length", Value: "1024"},
		{Name: "user-agent", Value: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/26.0 Safari/605.1.15"},
		{Name: "accept", Value: "application/json, text/plain, */*"},
		{Name: "accept-encoding", Value: "gzip, deflate, br"},
		{Name: "accept-language", Value: "en-US,en;q=0.9"},
		{Name: "cache-control", Value: "no-cache"},
		{Name: "cookie", Value: "session_id=abc123"},
		{Name: "cookie", Value: "user_pref=dark_mode"},
		{Name: "referer", Value: "https://quic-go.net/docs/http3/"},
	}
	var buf bytes.Buffer
	enc := qpack.NewEncoder(&buf)
	for _, hf := range headers {
		require.NoError(b, enc.WriteField(hf))
	}
	dec := qpack.NewDecoder()

	for b.Loop() {
		decodeFn := dec.Decode(buf.Bytes())
		if _, err := requestFromHeaders(decodeFn, math.MaxInt, nil); err != nil {
			b.Fatalf("failed to parse request: %v", err)
		}
	}
}
quic-go-0.59.0/http3/http3_helper_test.go000066400000000000000000000217611513066070600202150ustar00rootroot00000000000000package http3

import (
	"bytes"
	"context"
	"crypto"
	"crypto/ed25519"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"errors"
	"io"
	"math/big"
	"net"
	"net/http"
	"os"
	"reflect"
	"strconv"
	"testing"
	"time"

	"github.com/quic-go/qpack"
	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3/qlog"
	"github.com/quic-go/quic-go/qlogwriter"
	"github.com/quic-go/quic-go/quicvarint"

	"github.com/stretchr/testify/require"
)

// maxByteCount is the maximum value of a ByteCount
const maxByteCount = uint64(1<<62 - 1)

// newUDPConnLocalhost opens a UDP socket on an ephemeral port on 127.0.0.1
// and closes it when the test finishes.
func newUDPConnLocalhost(t testing.TB) *net.UDPConn {
	t.Helper()
	conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
	require.NoError(t, err)
	t.Cleanup(func() { conn.Close() })
	return conn
}

// scaleDuration multiplies t by the integer in the TIMESCALE_FACTOR environment
// variable (default 1). It panics if the factor is 0.
func scaleDuration(t time.Duration) time.Duration {
	scaleFactor := 1
	if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set
		scaleFactor = f
	}
	if scaleFactor == 0 {
		panic("TIMESCALE_FACTOR is 0")
	}
	return time.Duration(scaleFactor) * t
}

var tlsConfig, tlsClientConfig *tls.Config

func init() {
	ca, caPrivateKey, err := generateCA()
	if err != nil {
		panic(err)
	}
	leafCert, leafPrivateKey, err := generateLeafCert(ca, caPrivateKey)
	if err != nil {
		panic(err)
	}
	tlsConfig = &tls.Config{
		Certificates: []tls.Certificate{{
			Certificate: [][]byte{leafCert.Raw},
			PrivateKey:  leafPrivateKey,
		}},
		NextProtos: []string{NextProtoH3},
	}

	root := x509.NewCertPool()
	root.AddCert(ca)
	tlsClientConfig = &tls.Config{
		ServerName: "localhost",
		RootCAs:    root,
		NextProtos: []string{NextProtoH3},
	}
}

// generateCA creates a self-signed Ed25519 CA certificate valid for 24 hours.
func generateCA() (*x509.Certificate, crypto.PrivateKey, error) {
	certTempl := &x509.Certificate{
		SerialNumber:          big.NewInt(2019),
		Subject:               pkix.Name{},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(24 * time.Hour),
		IsCA:                  true,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		BasicConstraintsValid: true,
	}
	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		return nil, nil, err
	}
	caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, pub, priv)
	if err != nil {
		return nil, nil, err
	}
	ca, err := x509.ParseCertificate(caBytes)
	if err != nil {
		return nil, nil, err
	}
	return ca, priv, nil
}

// generateLeafCert creates an Ed25519 certificate for "localhost" and 127.0.0.1,
// signed by the given CA and valid for 24 hours.
func generateLeafCert(ca *x509.Certificate, caPriv crypto.PrivateKey) (*x509.Certificate, crypto.PrivateKey, error) {
	certTempl := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		DNSNames:     []string{"localhost"},
		IPAddresses:  []net.IP{net.IPv4(127, 0, 0, 1)},
		NotBefore:    time.Now(),
		NotAfter:     time.Now().Add(24 * time.Hour),
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
		KeyUsage:     x509.KeyUsageDigitalSignature,
	}
	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		return nil, nil, err
	}
	certBytes, err := x509.CreateCertificate(rand.Reader, certTempl, ca, pub, caPriv)
	if err != nil {
		return nil, nil, err
	}
	cert, err := x509.ParseCertificate(certBytes)
	if err != nil {
		return nil, nil, err
	}
	return cert, priv, nil
}

func getTLSConfig() *tls.Config       { return tlsConfig.Clone() }
func getTLSClientConfig() *tls.Config { return tlsClientConfig.Clone() }

// qlogTrace is a qlogwriter.Trace that hands out a fixed recorder.
type qlogTrace struct {
	recorder qlogwriter.Recorder
}

func (t *qlogTrace) SupportsSchemas(schema string) bool { return true }

func (t *qlogTrace) AddProducer() qlogwriter.Recorder {
	return t.recorder
}

type connPairOpts struct {
	clientRecorder        qlogwriter.Recorder
	serverRecorder        qlogwriter.Recorder
	serverBidiStreamLimit int64
	enableDatagrams       bool
}

type connPairOpt func(*connPairOpts)

func withClientRecorder(r qlogwriter.Recorder) connPairOpt {
	return func(o *connPairOpts) { o.clientRecorder = r }
}

func withServerRecorder(r qlogwriter.Recorder) connPairOpt {
	return func(o *connPairOpts) { o.serverRecorder = r }
}

func withDatagrams() connPairOpt {
	return func(o *connPairOpts) { o.enableDatagrams = true }
}

func withServerBidiStreamLimit(limit int64) connPairOpt {
	return func(o *connPairOpts) { o.serverBidiStreamLimit = limit }
}

// newConnPair establishes a QUIC connection over localhost and returns both ends
// once the server side has completed the handshake (or fails the test after 1s).
// Both connections are closed when the test finishes.
func newConnPair(t *testing.T, opts ...connPairOpt) (client, server *quic.Conn) {
	t.Helper()

	var o connPairOpts
	for _, opt := range opts {
		opt(&o)
	}

	ln, err := quic.ListenEarly(
		newUDPConnLocalhost(t),
		getTLSConfig(),
		&quic.Config{
			InitialStreamReceiveWindow:     maxByteCount,
			InitialConnectionReceiveWindow: maxByteCount,
			MaxIncomingStreams:             o.serverBidiStreamLimit,
			EnableDatagrams:                o.enableDatagrams,
			Tracer: func(ctx context.Context, isClient bool, connID quic.ConnectionID) qlogwriter.Trace {
				return &qlogTrace{recorder: o.serverRecorder}
			},
		},
	)
	require.NoError(t, err)

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	cl, err := quic.DialEarly(
		ctx,
		newUDPConnLocalhost(t),
		ln.Addr(),
		getTLSClientConfig(),
		&quic.Config{
			EnableDatagrams: o.enableDatagrams,
			Tracer: func(ctx context.Context, isClient bool, connID quic.ConnectionID) qlogwriter.Trace {
				return &qlogTrace{recorder: o.clientRecorder}
			},
		},
	)
	require.NoError(t, err)
	t.Cleanup(func() { cl.CloseWithError(0, "") })

	conn, err := ln.Accept(ctx)
	require.NoError(t, err)
	t.Cleanup(func() { conn.CloseWithError(0, "") })
	select {
	case <-conn.HandshakeComplete():
	case <-ctx.Done():
		t.Fatal("timeout")
	}
	return cl, conn
}

type quicReceiveStream interface {
	io.Reader
	SetReadDeadline(time.Time) error
}

// expectStreamReadReset asserts that reading from str fails within one second
// with a *quic.StreamError carrying errCode.
func expectStreamReadReset(t *testing.T, str quicReceiveStream, errCode quic.StreamErrorCode) {
	t.Helper()
	str.SetReadDeadline(time.Now().Add(time.Second))
	_, err := str.Read([]byte{0})
	require.Error(t, err)
	if errors.Is(err, os.ErrDeadlineExceeded) {
		t.Fatal("didn't receive a stream reset")
	}
	var strErr *quic.StreamError
	require.ErrorAs(t, err, &strErr)
	require.Equal(t, errCode, strErr.ErrorCode)
}

type quicSendStream interface {
	io.Writer
	Context() context.Context
}

// expectStreamWriteReset waits up to one second for str's context to be done,
// then asserts that writing fails with a *quic.StreamError carrying errCode.
func expectStreamWriteReset(t *testing.T, str quicSendStream, errCode quic.StreamErrorCode) {
	t.Helper()
	select {
	case <-str.Context().Done():
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	_, err := str.Write([]byte{0})
	require.Error(t, err)
	var strErr *quic.StreamError
	require.ErrorAs(t, err, &strErr)
	require.Equal(t, errCode, strErr.ErrorCode)
}

// encodeRequest serializes req into a HEADERS frame, followed by a single DATA
// frame containing the whole body if req.Body is non-nil.
func encodeRequest(t *testing.T, req *http.Request) []byte {
	t.Helper()
	var buf bytes.Buffer
	rw := newRequestWriter()
	require.NoError(t, rw.WriteRequestHeader(&buf, req, false, 0, nil))
	if req.Body != nil {
		body, err := io.ReadAll(req.Body)
		require.NoError(t, err)
		buf.Write((&dataFrame{Length: uint64(len(body))}).Append(nil))
		buf.Write(body)
	}
	return buf.Bytes()
}

// decodeHeader reads a HEADERS frame from r and returns its decoded fields,
// keyed by field name.
func decodeHeader(t *testing.T, r io.Reader) map[string][]string {
	t.Helper()
	fields := make(map[string][]string)
	frame, err := (&frameParser{r: r}).ParseNext(nil)
	require.NoError(t, err)
	require.IsType(t, &headersFrame{}, frame)
	headersFrame := frame.(*headersFrame)
	data := make([]byte, headersFrame.Length)
	_, err = io.ReadFull(r, data)
	require.NoError(t, err)
	hfs := decodeQpackHeaderFields(t, data)
	for _, p := range hfs {
		fields[p.Name] = append(fields[p.Name], p.Value)
	}
	return fields
}

// decodeQpackHeaderFields decodes a complete QPACK header block into its fields.
func decodeQpackHeaderFields(t *testing.T, data []byte) []qpack.HeaderField {
	t.Helper()
	decoder := qpack.NewDecoder()
	decodeFn := decoder.Decode(data)
	var hfs []qpack.HeaderField
	for {
		hf, err := decodeFn()
		if err == io.EOF {
			break
		}
		require.NoError(t, err)
		hfs = append(hfs, hf)
	}
	return hfs
}

// filterQlogEventsForFrame filters the events for the given frame type,
// for both FrameCreated and FrameParsed events.
// It returns the events that match the given frame type.
func filterQlogEventsForFrame(events []qlogwriter.Event, frame any) []qlogwriter.Event {
	var filtered []qlogwriter.Event
	for _, ev := range events {
		switch e := ev.(type) {
		case qlog.FrameCreated:
			if reflect.TypeOf(e.Frame.Frame) == reflect.TypeOf(frame) {
				filtered = append(filtered, ev)
			}
		case qlog.FrameParsed:
			if reflect.TypeOf(e.Frame.Frame) == reflect.TypeOf(frame) {
				filtered = append(filtered, ev)
			}
		}
	}
	return filtered
}

// expectedFrameLength returns the serialized length of frame and the length of
// its payload (the part after the frame type and length fields).
func expectedFrameLength(t *testing.T, frame any) (length, payloadLength int) {
	t.Helper()
	switch f := frame.(type) {
	case *dataFrame:
		return len(f.Append(nil)) + int(f.Length), int(f.Length)
	case *headersFrame:
		return len(f.Append(nil)) + int(f.Length), int(f.Length)
	case *goAwayFrame:
		return len(f.Append(nil)), quicvarint.Len(uint64(f.StreamID))
	case *settingsFrame:
		data := f.Append(nil)
		r := bytes.NewReader(data)
		_, err := quicvarint.Read(r) // type
		require.NoError(t, err)
		_, err = quicvarint.Read(r) // length
		require.NoError(t, err)
		return len(data), r.Len()
	default:
		t.Fatalf("unexpected frame type: %T", frame)
	}
	panic("unreachable")
}
quic-go-0.59.0/http3/internal/000077500000000000000000000000001513066070600160335ustar00rootroot00000000000000quic-go-0.59.0/http3/internal/testdata/000077500000000000000000000000001513066070600176445ustar00rootroot00000000000000quic-go-0.59.0/http3/internal/testdata/ca.pem000066400000000000000000000020051513066070600207270ustar00rootroot00000000000000-----BEGIN CERTIFICATE----- MIICzDCCAbQCCQDA+rLymNnfJzANBgkqhkiG9w0BAQsFADAoMSYwJAYDVQQKDB1x dWljLWdvIENlcnRpZmljYXRlIEF1dGhvcml0eTAeFw0yMDA4MTgwOTIxMzVaFw0z MDA4MTYwOTIxMzVaMCgxJjAkBgNVBAoMHXF1aWMtZ28gQ2VydGlmaWNhdGUgQXV0 aG9yaXR5MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA1OcsYrVaSDfh iDppl6oteVspOY3yFb96T9Y/biaGPJAkBO9VGKcqwOUPmUeiWpedRAUB9LE7Srs6 qBX4mnl90Icjp8jbIs5cPgIWLkIu8Qm549RghFzB3bn+EmCQSe4cxvyDMN3ndClp 3YMXpZgXWgJGiPOylVi/OwHDdWDBorw4hvry+6yDtpQo2TuI2A/xtxXPT7BgsEJD WGffdgZOYXChcFA0c1XVLIYlu2w2JhxS8c2TUF6uSDlmcoONNKVoiNCuu1Z9MorS Qmg7a2G7dSPu123KcTcSQFcmJrt+1G81gOBtHB69kacD8xDmgksj09h/ODPL/gIU 1ZcU2ci1/QIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQB0Tb1JbLXp/BvWovSAhO/j wG7UEaUA1rCtkDB+fV2HS9bxCbV5eErdg8AMHKgB51ygUrq95vm/baZmUILr84XK uTEoxxrw5S9Z7SrhtbOpKCumoSeTsCPjDvCcwFExHv4XHFk+CPqZwbMHueVIMT0+ nGWss/KecCPdJLdnUgMRz0tIuXzkoRuOiUiZfUeyBNVNbDFSrLigYshTeAPGaYjX CypoHxkeS93nWfOMUu8FTYLYkvGMU5i076zDoFGKJiEtbjSiNW+Hei7u2aSEuCzp qyTKzYPWYffAq3MM2MKJgZdL04e9GEGeuce/qhM1o3q77aI/XJImwEDdut2LDec1 -----END CERTIFICATE----- quic-go-0.59.0/http3/internal/testdata/cert.go000066400000000000000000000023041513066070600211270ustar00rootroot00000000000000package testdata import ( "crypto/tls" "crypto/x509" "os" "path" "runtime" ) var certPath string func init() { _, filename, _, ok := runtime.Caller(0) if !ok { panic("Failed to get current frame") } certPath = path.Dir(filename) } // GetCertificatePaths returns the paths to certificate and key func GetCertificatePaths() (string, string) { return path.Join(certPath, "cert.pem"), path.Join(certPath, "priv.key") } // GetTLSConfig returns a tls config for quic.clemente.io func 
GetTLSConfig() *tls.Config { cert, err := tls.LoadX509KeyPair(GetCertificatePaths()) if err != nil { panic(err) } return &tls.Config{ MinVersion: tls.VersionTLS13, Certificates: []tls.Certificate{cert}, } } // AddRootCA adds the root CA certificate to a cert pool func AddRootCA(certPool *x509.CertPool) { caCertPath := path.Join(certPath, "ca.pem") caCertRaw, err := os.ReadFile(caCertPath) if err != nil { panic(err) } if ok := certPool.AppendCertsFromPEM(caCertRaw); !ok { panic("Could not add root ceritificate to pool.") } } // GetRootCA returns an x509.CertPool containing (only) the CA certificate func GetRootCA() *x509.CertPool { pool := x509.NewCertPool() AddRootCA(pool) return pool } quic-go-0.59.0/http3/internal/testdata/cert.pem000066400000000000000000000020221513066070600213000ustar00rootroot00000000000000-----BEGIN CERTIFICATE----- MIIC1TCCAb2gAwIBAgIJAK2fcqC0BVA7MA0GCSqGSIb3DQEBCwUAMCgxJjAkBgNV BAoMHXF1aWMtZ28gQ2VydGlmaWNhdGUgQXV0aG9yaXR5MB4XDTIwMDgxODA5MjEz NVoXDTMwMDgxNjA5MjEzNVowEjEQMA4GA1UECgwHcXVpYy1nbzCCASIwDQYJKoZI hvcNAQEBBQADggEPADCCAQoCggEBAN/YwrigSXdJCL/bdBGhb0UpqtU8H+krV870 +w1yCSykLImH8x3qHZEXt9sr/vgjcJoV6Z15RZmnbEqnAx84sIClIBoIgnk0VPxu WF+/U/dElbftCfYcfJAddhRckdmGB+yb3Wogb32UJ+q3my++h6NjHsYb+OwpJPnQ meXjOE7Kkf+bXfFywHF3R8kzVdh5JUFYeKbxYmYgxRps1YTsbCrZCrSy1CbQ9FJw Wg5C8t+7yvVFmOeWPECypBCz2xS2mu+kycMNIjIWMl0SL7oVM5cBkRKPeVIG/KcM i5+/4lRSLoPh0Txh2TKBWfpzLbIOdPU8/O7cAukIGWx0XsfHUQMCAwEAAaMYMBYw FAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBCwUAA4IBAQAyxxvebdMz shp5pt1SxMOSXbo8sTa1cpaf2rTmb4nxjXs6KPBEn53hSBz9bhe5wXE4f94SHadf 636rLh3d75KgrLUwO9Yq0HfCxMo1jUV/Ug++XwcHCI9vk58Tk/H4hqEM6C8RrdTj fYeuegQ0/oNLJ4uTw2P2A8TJbL6FC2dcICEAvUGZUcVyZ8m8tHXNRYYh6MZ7ubCh hinvL+AA5fY6EVlc5G/P4DN6fYxGn1cFNbiL4uZP4+W3dOmP+NV0YV9ihTyMzz0R vSoOZ9FeVkyw8EhMb3LoyXYKazvJy2VQST1ltzAGit9RiM1Gv4vuna74WsFzrn1U A/TbaR0ih/qG -----END CERTIFICATE----- 
quic-go-0.59.0/http3/internal/testdata/cert_test.go000066400000000000000000000011241513066070600221650ustar00rootroot00000000000000package testdata import ( "crypto/tls" "io" "testing" "github.com/stretchr/testify/require" ) func TestCertificates(t *testing.T) { ln, err := tls.Listen("tcp", "localhost:4433", GetTLSConfig()) require.NoError(t, err) go func() { conn, err := ln.Accept() require.NoError(t, err) defer conn.Close() _, err = conn.Write([]byte("foobar")) require.NoError(t, err) }() conn, err := tls.Dial("tcp", "localhost:4433", &tls.Config{RootCAs: GetRootCA()}) require.NoError(t, err) data, err := io.ReadAll(conn) require.NoError(t, err) require.Equal(t, "foobar", string(data)) } quic-go-0.59.0/http3/internal/testdata/generate_key.sh000077500000000000000000000012441513066070600226460ustar00rootroot00000000000000#!/bin/bash set -e echo "Generating CA key and certificate:" openssl req -x509 -sha256 -nodes -days 3650 -newkey rsa:2048 \ -keyout ca.key -out ca.pem \ -subj "/O=quic-go Certificate Authority/" echo "Generating CSR" openssl req -out cert.csr -new -newkey rsa:2048 -nodes -keyout priv.key \ -subj "/O=quic-go/" echo "Sign certificate:" openssl x509 -req -sha256 -days 3650 -in cert.csr -out cert.pem \ -CA ca.pem -CAkey ca.key -CAcreateserial \ -extfile <(printf "subjectAltName=DNS:localhost") # debug output the certificate openssl x509 -noout -text -in cert.pem # we don't need the CA key, the serial number and the CSR any more rm ca.key cert.csr ca.srl quic-go-0.59.0/http3/internal/testdata/priv.key000066400000000000000000000032501513066070600213360ustar00rootroot00000000000000-----BEGIN PRIVATE KEY----- MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDf2MK4oEl3SQi/ 23QRoW9FKarVPB/pK1fO9PsNcgkspCyJh/Md6h2RF7fbK/74I3CaFemdeUWZp2xK pwMfOLCApSAaCIJ5NFT8blhfv1P3RJW37Qn2HHyQHXYUXJHZhgfsm91qIG99lCfq t5svvoejYx7GG/jsKST50Jnl4zhOypH/m13xcsBxd0fJM1XYeSVBWHim8WJmIMUa bNWE7Gwq2Qq0stQm0PRScFoOQvLfu8r1RZjnljxAsqQQs9sUtprvpMnDDSIyFjJd 
Ei+6FTOXAZESj3lSBvynDIufv+JUUi6D4dE8YdkygVn6cy2yDnT1PPzu3ALpCBls dF7Hx1EDAgMBAAECggEBAMm+mLDBdbUWk9YmuZNyRdC13wvT5obF05vo26OglXgw dxt09b6OVBuCnuff3SpS9pdJDIYq2HnFlSorH/sxopIvQKF17fHDIp1n7ipNTCXd IHrmHkY8Il/YzaVIUQMVc2rih0mw9greTqOS20DKnYC6QvAWIeDmrDaitTGl+ge3 hm7e2lsgZi13R6fTNwQs9geEQSGzP2k7bFceHQFDChOYiQraR5+VZZ8S8AMGjk47 AUa5EsKeUe6O9t2xuDSFxzYz5eadOAiErKGDos5KXXr3VQgFcC8uPEFFjcJ/yl+8 tOe4iLeVwGSDJhTAThdR2deJOjaDcarWM7ixmxA3DAECgYEA/WVwmY4gWKwv49IJ Jnh1Gu93P772GqliMNpukdjTI+joQxfl4jRSt2hk4b1KRwyT9aaKfvdz0HFlXo/r 9NVSAYT3/3vbcw61bfvPhhtz44qRAAKua6b5cUM6XqxVt1hqdP8lrf/blvA5ln+u O51S8+wpxZMuqKz/29zdWSG6tAMCgYEA4iWXMXX9dZajI6abVkWwuosvOakXdLk4 tUy7zd+JPF7hmUzzj2gtg4hXoiQPAOi+GY3TX+1Nza3s1LD7iWaXSKeOWvvligw9 Q/wVTNW2P1+tdhScJf9QudzW69xOm5HNBgx9uWV2cHfjC12vg5aTH0k5axvaq15H 9WBXlH5q3wECgYBYoYGYBDFmMpvxmMagkSOMz1OrlVSpkLOKmOxx0SBRACc1SIec 7mY8RqR6nOX9IfYixyTMMittLiyhvb9vfKnZZDQGRcFFZlCpbplws+t+HDqJgWaW uumm5zfkY2z7204pLBF24fZhvha2gGRl76pTLTiTJd79Gr3HnmJByd1vFwKBgHL7 vfYuEeM55lT4Hz8sTAFtR2O/7+cvTgAQteSlZbfGXlp939DonUulhTkxsFc7/3wq unCpzcdoSWSTYDGqcf1FBIKKVVltg7EPeR0KBJIQabgCHqrLOBZojPZ7m5RJ+765 lysuxZvFuTFMPzNe2gssRf+JuBMt6tR+WclsxZYBAoGAEEFs1ppDil1xlP5rdH7T d3TSw/u4eU/X8Ei1zi25hdRUiV76fP9fBELYFmSrPBhugYv91vtSv/LmD4zLfLv/ yzwAD9j1lGbgM8Of8klCkk+XSJ88ryUwnMTJ5loQJW8t4L+zLv5Le7Ca9SAT0kJ1 jT0GzDymgLMGp8RPdBkpk+w= -----END PRIVATE KEY----- quic-go-0.59.0/http3/ip_addr.go000066400000000000000000000025021513066070600161470ustar00rootroot00000000000000package http3 import ( "net" "strings" ) // An addrList represents a list of network endpoint addresses. // Copy from [net.addrList] and change type from [net.Addr] to [net.IPAddr] type addrList []net.IPAddr // isIPv4 reports whether addr contains an IPv4 address. func isIPv4(addr net.IPAddr) bool { return addr.IP.To4() != nil } // isNotIPv4 reports whether addr does not contain an IPv4 address. 
func isNotIPv4(addr net.IPAddr) bool { return !isIPv4(addr) } // forResolve returns the most appropriate address in address for // a call to ResolveTCPAddr, ResolveUDPAddr, or ResolveIPAddr. // IPv4 is preferred, unless addr contains an IPv6 literal. func (addrs addrList) forResolve(network, addr string) net.IPAddr { var want6 bool switch network { case "ip": // IPv6 literal (addr does NOT contain a port) want6 = strings.ContainsRune(addr, ':') case "tcp", "udp": // IPv6 literal. (addr contains a port, so look for '[') want6 = strings.ContainsRune(addr, '[') } if want6 { return addrs.first(isNotIPv4) } return addrs.first(isIPv4) } // first returns the first address which satisfies strategy, or if // none do, then the first address of any kind. func (addrs addrList) first(strategy func(net.IPAddr) bool) net.IPAddr { for _, addr := range addrs { if strategy(addr) { return addr } } return addrs[0] } quic-go-0.59.0/http3/mock_clientconn_test.go000066400000000000000000000125561513066070600207630ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/http3 (interfaces: TestClientConnInterface) // // Generated by this command: // // mockgen -typed -build_flags=-tags=gomock -mock_names=TestClientConnInterface=MockClientConn -package http3 -destination mock_clientconn_test.go github.com/quic-go/quic-go/http3 TestClientConnInterface // // Package http3 is a generated GoMock package. package http3 import ( context "context" http "net/http" reflect "reflect" quic "github.com/quic-go/quic-go" gomock "go.uber.org/mock/gomock" ) // MockClientConn is a mock of TestClientConnInterface interface. type MockClientConn struct { ctrl *gomock.Controller recorder *MockClientConnMockRecorder isgomock struct{} } // MockClientConnMockRecorder is the mock recorder for MockClientConn. type MockClientConnMockRecorder struct { mock *MockClientConn } // NewMockClientConn creates a new mock instance. 
func NewMockClientConn(ctrl *gomock.Controller) *MockClientConn { mock := &MockClientConn{ctrl: ctrl} mock.recorder = &MockClientConnMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockClientConn) EXPECT() *MockClientConnMockRecorder { return m.recorder } // OpenRequestStream mocks base method. func (m *MockClientConn) OpenRequestStream(arg0 context.Context) (*RequestStream, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OpenRequestStream", arg0) ret0, _ := ret[0].(*RequestStream) ret1, _ := ret[1].(error) return ret0, ret1 } // OpenRequestStream indicates an expected call of OpenRequestStream. func (mr *MockClientConnMockRecorder) OpenRequestStream(arg0 any) *MockClientConnOpenRequestStreamCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenRequestStream", reflect.TypeOf((*MockClientConn)(nil).OpenRequestStream), arg0) return &MockClientConnOpenRequestStreamCall{Call: call} } // MockClientConnOpenRequestStreamCall wrap *gomock.Call type MockClientConnOpenRequestStreamCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockClientConnOpenRequestStreamCall) Return(arg0 *RequestStream, arg1 error) *MockClientConnOpenRequestStreamCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockClientConnOpenRequestStreamCall) Do(f func(context.Context) (*RequestStream, error)) *MockClientConnOpenRequestStreamCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockClientConnOpenRequestStreamCall) DoAndReturn(f func(context.Context) (*RequestStream, error)) *MockClientConnOpenRequestStreamCall { c.Call = c.Call.DoAndReturn(f) return c } // RoundTrip mocks base method. 
func (m *MockClientConn) RoundTrip(arg0 *http.Request) (*http.Response, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RoundTrip", arg0) ret0, _ := ret[0].(*http.Response) ret1, _ := ret[1].(error) return ret0, ret1 } // RoundTrip indicates an expected call of RoundTrip. func (mr *MockClientConnMockRecorder) RoundTrip(arg0 any) *MockClientConnRoundTripCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RoundTrip", reflect.TypeOf((*MockClientConn)(nil).RoundTrip), arg0) return &MockClientConnRoundTripCall{Call: call} } // MockClientConnRoundTripCall wrap *gomock.Call type MockClientConnRoundTripCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockClientConnRoundTripCall) Return(arg0 *http.Response, arg1 error) *MockClientConnRoundTripCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockClientConnRoundTripCall) Do(f func(*http.Request) (*http.Response, error)) *MockClientConnRoundTripCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockClientConnRoundTripCall) DoAndReturn(f func(*http.Request) (*http.Response, error)) *MockClientConnRoundTripCall { c.Call = c.Call.DoAndReturn(f) return c } // handleUnidirectionalStream mocks base method. func (m *MockClientConn) handleUnidirectionalStream(arg0 *quic.ReceiveStream) { m.ctrl.T.Helper() m.ctrl.Call(m, "handleUnidirectionalStream", arg0) } // handleUnidirectionalStream indicates an expected call of handleUnidirectionalStream. 
func (mr *MockClientConnMockRecorder) handleUnidirectionalStream(arg0 any) *MockClientConnhandleUnidirectionalStreamCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleUnidirectionalStream", reflect.TypeOf((*MockClientConn)(nil).handleUnidirectionalStream), arg0) return &MockClientConnhandleUnidirectionalStreamCall{Call: call} } // MockClientConnhandleUnidirectionalStreamCall wrap *gomock.Call type MockClientConnhandleUnidirectionalStreamCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockClientConnhandleUnidirectionalStreamCall) Return() *MockClientConnhandleUnidirectionalStreamCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockClientConnhandleUnidirectionalStreamCall) Do(f func(*quic.ReceiveStream)) *MockClientConnhandleUnidirectionalStreamCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockClientConnhandleUnidirectionalStreamCall) DoAndReturn(f func(*quic.ReceiveStream)) *MockClientConnhandleUnidirectionalStreamCall { c.Call = c.Call.DoAndReturn(f) return c } quic-go-0.59.0/http3/mock_datagram_stream_test.go000066400000000000000000000421271513066070600217570ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/http3 (interfaces: DatagramStream) // // Generated by this command: // // mockgen -typed -build_flags=-tags=gomock -mock_names=DatagramStream=MockDatagramStream -package http3 -destination mock_datagram_stream_test.go github.com/quic-go/quic-go/http3 DatagramStream // // Package http3 is a generated GoMock package. package http3 import ( context "context" reflect "reflect" time "time" quic "github.com/quic-go/quic-go" gomock "go.uber.org/mock/gomock" ) // MockDatagramStream is a mock of DatagramStream interface. 
type MockDatagramStream struct { ctrl *gomock.Controller recorder *MockDatagramStreamMockRecorder isgomock struct{} } // MockDatagramStreamMockRecorder is the mock recorder for MockDatagramStream. type MockDatagramStreamMockRecorder struct { mock *MockDatagramStream } // NewMockDatagramStream creates a new mock instance. func NewMockDatagramStream(ctrl *gomock.Controller) *MockDatagramStream { mock := &MockDatagramStream{ctrl: ctrl} mock.recorder = &MockDatagramStreamMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockDatagramStream) EXPECT() *MockDatagramStreamMockRecorder { return m.recorder } // CancelRead mocks base method. func (m *MockDatagramStream) CancelRead(arg0 quic.StreamErrorCode) { m.ctrl.T.Helper() m.ctrl.Call(m, "CancelRead", arg0) } // CancelRead indicates an expected call of CancelRead. func (mr *MockDatagramStreamMockRecorder) CancelRead(arg0 any) *MockDatagramStreamCancelReadCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockDatagramStream)(nil).CancelRead), arg0) return &MockDatagramStreamCancelReadCall{Call: call} } // MockDatagramStreamCancelReadCall wrap *gomock.Call type MockDatagramStreamCancelReadCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockDatagramStreamCancelReadCall) Return() *MockDatagramStreamCancelReadCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockDatagramStreamCancelReadCall) Do(f func(quic.StreamErrorCode)) *MockDatagramStreamCancelReadCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockDatagramStreamCancelReadCall) DoAndReturn(f func(quic.StreamErrorCode)) *MockDatagramStreamCancelReadCall { c.Call = c.Call.DoAndReturn(f) return c } // CancelWrite mocks base method. 
func (m *MockDatagramStream) CancelWrite(arg0 quic.StreamErrorCode) { m.ctrl.T.Helper() m.ctrl.Call(m, "CancelWrite", arg0) } // CancelWrite indicates an expected call of CancelWrite. func (mr *MockDatagramStreamMockRecorder) CancelWrite(arg0 any) *MockDatagramStreamCancelWriteCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockDatagramStream)(nil).CancelWrite), arg0) return &MockDatagramStreamCancelWriteCall{Call: call} } // MockDatagramStreamCancelWriteCall wrap *gomock.Call type MockDatagramStreamCancelWriteCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockDatagramStreamCancelWriteCall) Return() *MockDatagramStreamCancelWriteCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockDatagramStreamCancelWriteCall) Do(f func(quic.StreamErrorCode)) *MockDatagramStreamCancelWriteCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockDatagramStreamCancelWriteCall) DoAndReturn(f func(quic.StreamErrorCode)) *MockDatagramStreamCancelWriteCall { c.Call = c.Call.DoAndReturn(f) return c } // Close mocks base method. func (m *MockDatagramStream) Close() error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Close") ret0, _ := ret[0].(error) return ret0 } // Close indicates an expected call of Close. 
func (mr *MockDatagramStreamMockRecorder) Close() *MockDatagramStreamCloseCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockDatagramStream)(nil).Close)) return &MockDatagramStreamCloseCall{Call: call} } // MockDatagramStreamCloseCall wrap *gomock.Call type MockDatagramStreamCloseCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockDatagramStreamCloseCall) Return(arg0 error) *MockDatagramStreamCloseCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockDatagramStreamCloseCall) Do(f func() error) *MockDatagramStreamCloseCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockDatagramStreamCloseCall) DoAndReturn(f func() error) *MockDatagramStreamCloseCall { c.Call = c.Call.DoAndReturn(f) return c } // Context mocks base method. func (m *MockDatagramStream) Context() context.Context { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Context") ret0, _ := ret[0].(context.Context) return ret0 } // Context indicates an expected call of Context. 
func (mr *MockDatagramStreamMockRecorder) Context() *MockDatagramStreamContextCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockDatagramStream)(nil).Context)) return &MockDatagramStreamContextCall{Call: call} } // MockDatagramStreamContextCall wrap *gomock.Call type MockDatagramStreamContextCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockDatagramStreamContextCall) Return(arg0 context.Context) *MockDatagramStreamContextCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockDatagramStreamContextCall) Do(f func() context.Context) *MockDatagramStreamContextCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockDatagramStreamContextCall) DoAndReturn(f func() context.Context) *MockDatagramStreamContextCall { c.Call = c.Call.DoAndReturn(f) return c } // QUICStream mocks base method. func (m *MockDatagramStream) QUICStream() *quic.Stream { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "QUICStream") ret0, _ := ret[0].(*quic.Stream) return ret0 } // QUICStream indicates an expected call of QUICStream. 
func (mr *MockDatagramStreamMockRecorder) QUICStream() *MockDatagramStreamQUICStreamCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QUICStream", reflect.TypeOf((*MockDatagramStream)(nil).QUICStream)) return &MockDatagramStreamQUICStreamCall{Call: call} } // MockDatagramStreamQUICStreamCall wrap *gomock.Call type MockDatagramStreamQUICStreamCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockDatagramStreamQUICStreamCall) Return(arg0 *quic.Stream) *MockDatagramStreamQUICStreamCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockDatagramStreamQUICStreamCall) Do(f func() *quic.Stream) *MockDatagramStreamQUICStreamCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockDatagramStreamQUICStreamCall) DoAndReturn(f func() *quic.Stream) *MockDatagramStreamQUICStreamCall { c.Call = c.Call.DoAndReturn(f) return c } // Read mocks base method. func (m *MockDatagramStream) Read(p []byte) (int, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Read", p) ret0, _ := ret[0].(int) ret1, _ := ret[1].(error) return ret0, ret1 } // Read indicates an expected call of Read. 
func (mr *MockDatagramStreamMockRecorder) Read(p any) *MockDatagramStreamReadCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockDatagramStream)(nil).Read), p) return &MockDatagramStreamReadCall{Call: call} } // MockDatagramStreamReadCall wrap *gomock.Call type MockDatagramStreamReadCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockDatagramStreamReadCall) Return(n int, err error) *MockDatagramStreamReadCall { c.Call = c.Call.Return(n, err) return c } // Do rewrite *gomock.Call.Do func (c *MockDatagramStreamReadCall) Do(f func([]byte) (int, error)) *MockDatagramStreamReadCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockDatagramStreamReadCall) DoAndReturn(f func([]byte) (int, error)) *MockDatagramStreamReadCall { c.Call = c.Call.DoAndReturn(f) return c } // ReceiveDatagram mocks base method. func (m *MockDatagramStream) ReceiveDatagram(ctx context.Context) ([]byte, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ReceiveDatagram", ctx) ret0, _ := ret[0].([]byte) ret1, _ := ret[1].(error) return ret0, ret1 } // ReceiveDatagram indicates an expected call of ReceiveDatagram. 
func (mr *MockDatagramStreamMockRecorder) ReceiveDatagram(ctx any) *MockDatagramStreamReceiveDatagramCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveDatagram", reflect.TypeOf((*MockDatagramStream)(nil).ReceiveDatagram), ctx) return &MockDatagramStreamReceiveDatagramCall{Call: call} } // MockDatagramStreamReceiveDatagramCall wrap *gomock.Call type MockDatagramStreamReceiveDatagramCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockDatagramStreamReceiveDatagramCall) Return(arg0 []byte, arg1 error) *MockDatagramStreamReceiveDatagramCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockDatagramStreamReceiveDatagramCall) Do(f func(context.Context) ([]byte, error)) *MockDatagramStreamReceiveDatagramCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockDatagramStreamReceiveDatagramCall) DoAndReturn(f func(context.Context) ([]byte, error)) *MockDatagramStreamReceiveDatagramCall { c.Call = c.Call.DoAndReturn(f) return c } // SendDatagram mocks base method. func (m *MockDatagramStream) SendDatagram(b []byte) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SendDatagram", b) ret0, _ := ret[0].(error) return ret0 } // SendDatagram indicates an expected call of SendDatagram. 
func (mr *MockDatagramStreamMockRecorder) SendDatagram(b any) *MockDatagramStreamSendDatagramCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendDatagram", reflect.TypeOf((*MockDatagramStream)(nil).SendDatagram), b) return &MockDatagramStreamSendDatagramCall{Call: call} } // MockDatagramStreamSendDatagramCall wrap *gomock.Call type MockDatagramStreamSendDatagramCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockDatagramStreamSendDatagramCall) Return(arg0 error) *MockDatagramStreamSendDatagramCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockDatagramStreamSendDatagramCall) Do(f func([]byte) error) *MockDatagramStreamSendDatagramCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockDatagramStreamSendDatagramCall) DoAndReturn(f func([]byte) error) *MockDatagramStreamSendDatagramCall { c.Call = c.Call.DoAndReturn(f) return c } // SetDeadline mocks base method. func (m *MockDatagramStream) SetDeadline(arg0 time.Time) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetDeadline", arg0) ret0, _ := ret[0].(error) return ret0 } // SetDeadline indicates an expected call of SetDeadline. 
func (mr *MockDatagramStreamMockRecorder) SetDeadline(arg0 any) *MockDatagramStreamSetDeadlineCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockDatagramStream)(nil).SetDeadline), arg0) return &MockDatagramStreamSetDeadlineCall{Call: call} } // MockDatagramStreamSetDeadlineCall wrap *gomock.Call type MockDatagramStreamSetDeadlineCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockDatagramStreamSetDeadlineCall) Return(arg0 error) *MockDatagramStreamSetDeadlineCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockDatagramStreamSetDeadlineCall) Do(f func(time.Time) error) *MockDatagramStreamSetDeadlineCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockDatagramStreamSetDeadlineCall) DoAndReturn(f func(time.Time) error) *MockDatagramStreamSetDeadlineCall { c.Call = c.Call.DoAndReturn(f) return c } // SetReadDeadline mocks base method. func (m *MockDatagramStream) SetReadDeadline(arg0 time.Time) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetReadDeadline", arg0) ret0, _ := ret[0].(error) return ret0 } // SetReadDeadline indicates an expected call of SetReadDeadline. 
func (mr *MockDatagramStreamMockRecorder) SetReadDeadline(arg0 any) *MockDatagramStreamSetReadDeadlineCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockDatagramStream)(nil).SetReadDeadline), arg0) return &MockDatagramStreamSetReadDeadlineCall{Call: call} } // MockDatagramStreamSetReadDeadlineCall wrap *gomock.Call type MockDatagramStreamSetReadDeadlineCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockDatagramStreamSetReadDeadlineCall) Return(arg0 error) *MockDatagramStreamSetReadDeadlineCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockDatagramStreamSetReadDeadlineCall) Do(f func(time.Time) error) *MockDatagramStreamSetReadDeadlineCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockDatagramStreamSetReadDeadlineCall) DoAndReturn(f func(time.Time) error) *MockDatagramStreamSetReadDeadlineCall { c.Call = c.Call.DoAndReturn(f) return c } // SetWriteDeadline mocks base method. func (m *MockDatagramStream) SetWriteDeadline(arg0 time.Time) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetWriteDeadline", arg0) ret0, _ := ret[0].(error) return ret0 } // SetWriteDeadline indicates an expected call of SetWriteDeadline. 
func (mr *MockDatagramStreamMockRecorder) SetWriteDeadline(arg0 any) *MockDatagramStreamSetWriteDeadlineCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockDatagramStream)(nil).SetWriteDeadline), arg0) return &MockDatagramStreamSetWriteDeadlineCall{Call: call} } // MockDatagramStreamSetWriteDeadlineCall wrap *gomock.Call type MockDatagramStreamSetWriteDeadlineCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockDatagramStreamSetWriteDeadlineCall) Return(arg0 error) *MockDatagramStreamSetWriteDeadlineCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockDatagramStreamSetWriteDeadlineCall) Do(f func(time.Time) error) *MockDatagramStreamSetWriteDeadlineCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockDatagramStreamSetWriteDeadlineCall) DoAndReturn(f func(time.Time) error) *MockDatagramStreamSetWriteDeadlineCall { c.Call = c.Call.DoAndReturn(f) return c } // StreamID mocks base method. func (m *MockDatagramStream) StreamID() quic.StreamID { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "StreamID") ret0, _ := ret[0].(quic.StreamID) return ret0 } // StreamID indicates an expected call of StreamID. 
func (mr *MockDatagramStreamMockRecorder) StreamID() *MockDatagramStreamStreamIDCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockDatagramStream)(nil).StreamID)) return &MockDatagramStreamStreamIDCall{Call: call} } // MockDatagramStreamStreamIDCall wrap *gomock.Call type MockDatagramStreamStreamIDCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockDatagramStreamStreamIDCall) Return(arg0 quic.StreamID) *MockDatagramStreamStreamIDCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockDatagramStreamStreamIDCall) Do(f func() quic.StreamID) *MockDatagramStreamStreamIDCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockDatagramStreamStreamIDCall) DoAndReturn(f func() quic.StreamID) *MockDatagramStreamStreamIDCall { c.Call = c.Call.DoAndReturn(f) return c } // Write mocks base method. func (m *MockDatagramStream) Write(p []byte) (int, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Write", p) ret0, _ := ret[0].(int) ret1, _ := ret[1].(error) return ret0, ret1 } // Write indicates an expected call of Write. 
func (mr *MockDatagramStreamMockRecorder) Write(p any) *MockDatagramStreamWriteCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockDatagramStream)(nil).Write), p) return &MockDatagramStreamWriteCall{Call: call} } // MockDatagramStreamWriteCall wrap *gomock.Call type MockDatagramStreamWriteCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockDatagramStreamWriteCall) Return(n int, err error) *MockDatagramStreamWriteCall { c.Call = c.Call.Return(n, err) return c } // Do rewrite *gomock.Call.Do func (c *MockDatagramStreamWriteCall) Do(f func([]byte) (int, error)) *MockDatagramStreamWriteCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockDatagramStreamWriteCall) DoAndReturn(f func([]byte) (int, error)) *MockDatagramStreamWriteCall { c.Call = c.Call.DoAndReturn(f) return c } quic-go-0.59.0/http3/mock_quic_listener_test.go000066400000000000000000000110121513066070600214570ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/http3 (interfaces: QUICListener) // // Generated by this command: // // mockgen -typed -package http3 -destination mock_quic_listener_test.go github.com/quic-go/quic-go/http3 QUICListener // // Package http3 is a generated GoMock package. package http3 import ( context "context" net "net" reflect "reflect" quic "github.com/quic-go/quic-go" gomock "go.uber.org/mock/gomock" ) // MockQUICListener is a mock of QUICListener interface. type MockQUICListener struct { ctrl *gomock.Controller recorder *MockQUICListenerMockRecorder isgomock struct{} } // MockQUICListenerMockRecorder is the mock recorder for MockQUICListener. type MockQUICListenerMockRecorder struct { mock *MockQUICListener } // NewMockQUICListener creates a new mock instance. 
func NewMockQUICListener(ctrl *gomock.Controller) *MockQUICListener { mock := &MockQUICListener{ctrl: ctrl} mock.recorder = &MockQUICListenerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockQUICListener) EXPECT() *MockQUICListenerMockRecorder { return m.recorder } // Accept mocks base method. func (m *MockQUICListener) Accept(arg0 context.Context) (*quic.Conn, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Accept", arg0) ret0, _ := ret[0].(*quic.Conn) ret1, _ := ret[1].(error) return ret0, ret1 } // Accept indicates an expected call of Accept. func (mr *MockQUICListenerMockRecorder) Accept(arg0 any) *MockQUICListenerAcceptCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accept", reflect.TypeOf((*MockQUICListener)(nil).Accept), arg0) return &MockQUICListenerAcceptCall{Call: call} } // MockQUICListenerAcceptCall wrap *gomock.Call type MockQUICListenerAcceptCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockQUICListenerAcceptCall) Return(arg0 *quic.Conn, arg1 error) *MockQUICListenerAcceptCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockQUICListenerAcceptCall) Do(f func(context.Context) (*quic.Conn, error)) *MockQUICListenerAcceptCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockQUICListenerAcceptCall) DoAndReturn(f func(context.Context) (*quic.Conn, error)) *MockQUICListenerAcceptCall { c.Call = c.Call.DoAndReturn(f) return c } // Addr mocks base method. func (m *MockQUICListener) Addr() net.Addr { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Addr") ret0, _ := ret[0].(net.Addr) return ret0 } // Addr indicates an expected call of Addr. 
func (mr *MockQUICListenerMockRecorder) Addr() *MockQUICListenerAddrCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Addr", reflect.TypeOf((*MockQUICListener)(nil).Addr)) return &MockQUICListenerAddrCall{Call: call} } // MockQUICListenerAddrCall wrap *gomock.Call type MockQUICListenerAddrCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockQUICListenerAddrCall) Return(arg0 net.Addr) *MockQUICListenerAddrCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockQUICListenerAddrCall) Do(f func() net.Addr) *MockQUICListenerAddrCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockQUICListenerAddrCall) DoAndReturn(f func() net.Addr) *MockQUICListenerAddrCall { c.Call = c.Call.DoAndReturn(f) return c } // Close mocks base method. func (m *MockQUICListener) Close() error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Close") ret0, _ := ret[0].(error) return ret0 } // Close indicates an expected call of Close. 
func (mr *MockQUICListenerMockRecorder) Close() *MockQUICListenerCloseCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockQUICListener)(nil).Close)) return &MockQUICListenerCloseCall{Call: call} } // MockQUICListenerCloseCall wrap *gomock.Call type MockQUICListenerCloseCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockQUICListenerCloseCall) Return(arg0 error) *MockQUICListenerCloseCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockQUICListenerCloseCall) Do(f func() error) *MockQUICListenerCloseCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockQUICListenerCloseCall) DoAndReturn(f func() error) *MockQUICListenerCloseCall { c.Call = c.Call.DoAndReturn(f) return c } quic-go-0.59.0/http3/mockgen.go000066400000000000000000000013401513066070600161670ustar00rootroot00000000000000//go:build gomock || generate package http3 //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -mock_names=TestClientConnInterface=MockClientConn -package http3 -destination mock_clientconn_test.go github.com/quic-go/quic-go/http3 TestClientConnInterface" type TestClientConnInterface = clientConn //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -mock_names=DatagramStream=MockDatagramStream -package http3 -destination mock_datagram_stream_test.go github.com/quic-go/quic-go/http3 DatagramStream" type DatagramStream = datagramStream //go:generate sh -c "go tool mockgen -typed -package http3 -destination mock_quic_listener_test.go github.com/quic-go/quic-go/http3 QUICListener" quic-go-0.59.0/http3/qlog.go000066400000000000000000000027751513066070600155230ustar00rootroot00000000000000package http3 import ( "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/qpack" ) func 
maybeQlogInvalidHeadersFrame(qlogger qlogwriter.Recorder, streamID quic.StreamID, l uint64) { if qlogger != nil { qlogger.RecordEvent(qlog.FrameParsed{ StreamID: streamID, Raw: qlog.RawInfo{PayloadLength: int(l)}, Frame: qlog.Frame{Frame: qlog.HeadersFrame{}}, }) } } func qlogParsedHeadersFrame(qlogger qlogwriter.Recorder, streamID quic.StreamID, hf *headersFrame, hfs []qpack.HeaderField) { headerFields := make([]qlog.HeaderField, len(hfs)) for i, hf := range hfs { headerFields[i] = qlog.HeaderField{ Name: hf.Name, Value: hf.Value, } } qlogger.RecordEvent(qlog.FrameParsed{ StreamID: streamID, Raw: qlog.RawInfo{ Length: int(hf.Length) + hf.headerLen, PayloadLength: int(hf.Length), }, Frame: qlog.Frame{Frame: qlog.HeadersFrame{ HeaderFields: headerFields, }}, }) } func qlogCreatedHeadersFrame(qlogger qlogwriter.Recorder, streamID quic.StreamID, length, payloadLength int, hfs []qlog.HeaderField) { headerFields := make([]qlog.HeaderField, len(hfs)) for i, hf := range hfs { headerFields[i] = qlog.HeaderField{ Name: hf.Name, Value: hf.Value, } } qlogger.RecordEvent(qlog.FrameCreated{ StreamID: streamID, Raw: qlog.RawInfo{Length: length, PayloadLength: payloadLength}, Frame: qlog.Frame{Frame: qlog.HeadersFrame{ HeaderFields: headerFields, }}, }) } quic-go-0.59.0/http3/qlog/000077500000000000000000000000001513066070600151615ustar00rootroot00000000000000quic-go-0.59.0/http3/qlog/event.go000066400000000000000000000065151513066070600166400ustar00rootroot00000000000000package qlog import ( "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/qlogwriter/jsontext" ) type encoderHelper struct { enc *jsontext.Encoder err error } func (h *encoderHelper) WriteToken(t jsontext.Token) { if h.err != nil { return } h.err = h.enc.WriteToken(t) } type RawInfo struct { Length int // full packet length, including header and AEAD authentication tag PayloadLength int // length of the packet payload, excluding AEAD tag } func (i RawInfo) HasValues() bool { return i.Length != 0 || 
i.PayloadLength != 0 } func (i RawInfo) encode(enc *jsontext.Encoder) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) if i.Length != 0 { h.WriteToken(jsontext.String("length")) h.WriteToken(jsontext.Uint(uint64(i.Length))) } if i.PayloadLength != 0 { h.WriteToken(jsontext.String("payload_length")) h.WriteToken(jsontext.Uint(uint64(i.PayloadLength))) } h.WriteToken(jsontext.EndObject) return h.err } type FrameParsed struct { StreamID quic.StreamID Raw RawInfo Frame Frame } func (e FrameParsed) Name() string { return "http3:frame_parsed" } func (e FrameParsed) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("stream_id")) h.WriteToken(jsontext.Uint(uint64(e.StreamID))) if e.Raw.HasValues() { h.WriteToken(jsontext.String("raw")) if err := e.Raw.encode(enc); err != nil { return err } } h.WriteToken(jsontext.String("frame")) if err := e.Frame.encode(enc); err != nil { return err } h.WriteToken(jsontext.EndObject) return h.err } type FrameCreated struct { StreamID quic.StreamID Raw RawInfo Frame Frame } func (e FrameCreated) Name() string { return "http3:frame_created" } func (e FrameCreated) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("stream_id")) h.WriteToken(jsontext.Uint(uint64(e.StreamID))) if e.Raw.HasValues() { h.WriteToken(jsontext.String("raw")) if err := e.Raw.encode(enc); err != nil { return err } } h.WriteToken(jsontext.String("frame")) if err := e.Frame.encode(enc); err != nil { return err } h.WriteToken(jsontext.EndObject) return h.err } type DatagramCreated struct { QuaterStreamID uint64 Raw RawInfo } func (e DatagramCreated) Name() string { return "http3:datagram_created" } func (e DatagramCreated) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) 
h.WriteToken(jsontext.String("quater_stream_id")) h.WriteToken(jsontext.Uint(e.QuaterStreamID)) h.WriteToken(jsontext.String("raw")) if err := e.Raw.encode(enc); err != nil { return err } h.WriteToken(jsontext.EndObject) return h.err } type DatagramParsed struct { QuaterStreamID uint64 Raw RawInfo } func (e DatagramParsed) Name() string { return "http3:datagram_parsed" } func (e DatagramParsed) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("quater_stream_id")) h.WriteToken(jsontext.Uint(e.QuaterStreamID)) h.WriteToken(jsontext.String("raw")) if err := e.Raw.encode(enc); err != nil { return err } h.WriteToken(jsontext.EndObject) return h.err } quic-go-0.59.0/http3/qlog/event_test.go000066400000000000000000000046241513066070600176760ustar00rootroot00000000000000package qlog import ( "bytes" "encoding/json" "io" "testing" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/qlogwriter" "github.com/stretchr/testify/require" ) type nopWriteCloserImpl struct{ io.Writer } func (nopWriteCloserImpl) Close() error { return nil } func nopWriteCloser(w io.Writer) io.WriteCloser { return &nopWriteCloserImpl{Writer: w} } func testEventEncoding(t *testing.T, ev qlogwriter.Event) (string, map[string]any) { t.Helper() var buf bytes.Buffer synctest.Test(t, func(t *testing.T) { tr := qlogwriter.NewConnectionFileSeq( nopWriteCloser(&buf), true, quic.ConnectionIDFromBytes([]byte{1, 2, 3, 4}), []string{"http3"}, ) go tr.Run() producer := tr.AddProducer() synctest.Wait() time.Sleep(42 * time.Second) producer.RecordEvent(ev) producer.Close() }) return decode(t, buf.String()) } func decode(t *testing.T, data string) (string, map[string]any) { t.Helper() var result map[string]any lines := bytes.Split([]byte(data), []byte{'\n'}) require.Len(t, lines, 3) // the first line is the trace header, the second line is the event, the third line is 
empty require.Empty(t, lines[2]) require.Equal(t, qlogwriter.RecordSeparator, lines[1][0], "expected record separator at start of line") require.NoError(t, json.Unmarshal(lines[1][1:], &result)) require.Equal(t, 42*time.Second, time.Duration(result["time"].(float64)*1e6)*time.Nanosecond) return result["name"].(string), result["data"].(map[string]any) } func TestFrameParsedEvent(t *testing.T) { name, ev := testEventEncoding(t, FrameParsed{ StreamID: quic.StreamID(4), Raw: RawInfo{ Length: 1500, PayloadLength: 100, }, Frame: Frame{Frame: &DataFrame{}}, }) require.Equal(t, "http3:frame_parsed", name) require.Equal(t, float64(4), ev["stream_id"]) require.NotContains(t, ev, "name") require.Contains(t, ev, "frame") } func TestFrameCreatedEvent(t *testing.T) { name, ev := testEventEncoding(t, FrameCreated{ StreamID: quic.StreamID(8), Raw: RawInfo{ PayloadLength: 200, }, Frame: Frame{Frame: &HeadersFrame{ HeaderFields: []HeaderField{ {Name: ":status", Value: "200"}, {Name: "content-type", Value: "text/html"}, }, }}, }) require.Equal(t, "http3:frame_created", name) require.Equal(t, float64(8), ev["stream_id"]) require.NotContains(t, ev, "name") require.Contains(t, ev, "frame") } quic-go-0.59.0/http3/qlog/frame.go000066400000000000000000000140621513066070600166050ustar00rootroot00000000000000package qlog import ( "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/qlogwriter/jsontext" ) // Frame represents an HTTP/3 frame. 
type Frame struct { Frame any } func (f Frame) encode(enc *jsontext.Encoder) error { switch frame := f.Frame.(type) { case DataFrame: return frame.encode(enc) case HeadersFrame: return frame.encode(enc) case GoAwayFrame: return frame.encode(enc) case SettingsFrame: return frame.encode(enc) case PushPromiseFrame: return frame.encode(enc) case CancelPushFrame: return frame.encode(enc) case MaxPushIDFrame: return frame.encode(enc) case ReservedFrame: return frame.encode(enc) case UnknownFrame: return frame.encode(enc) } // This shouldn't happen if the code is correctly logging frames. // Write a null token to produce valid JSON. return enc.WriteToken(jsontext.Null) } // A DataFrame is a DATA frame type DataFrame struct{} func (f *DataFrame) encode(enc *jsontext.Encoder) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("data")) h.WriteToken(jsontext.EndObject) return h.err } type HeaderField struct { Name string Value string } // A HeadersFrame is a HEADERS frame type HeadersFrame struct { HeaderFields []HeaderField } func (f *HeadersFrame) encode(enc *jsontext.Encoder) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("headers")) if len(f.HeaderFields) > 0 { h.WriteToken(jsontext.String("header_fields")) h.WriteToken(jsontext.BeginArray) for _, f := range f.HeaderFields { h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("name")) h.WriteToken(jsontext.String(f.Name)) h.WriteToken(jsontext.String("value")) h.WriteToken(jsontext.String(f.Value)) h.WriteToken(jsontext.EndObject) } h.WriteToken(jsontext.EndArray) } h.WriteToken(jsontext.EndObject) return h.err } // A GoAwayFrame is a GOAWAY frame type GoAwayFrame struct { StreamID quic.StreamID } func (f *GoAwayFrame) encode(enc *jsontext.Encoder) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) 
h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("goaway")) h.WriteToken(jsontext.String("id")) h.WriteToken(jsontext.Uint(uint64(f.StreamID))) h.WriteToken(jsontext.EndObject) return h.err } type SettingsFrame struct { MaxFieldSectionSize int64 Datagram *bool ExtendedConnect *bool Other map[uint64]uint64 } func (f *SettingsFrame) encode(enc *jsontext.Encoder) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("settings")) h.WriteToken(jsontext.String("settings")) h.WriteToken(jsontext.BeginArray) if f.MaxFieldSectionSize >= 0 { h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("name")) h.WriteToken(jsontext.String("settings_max_field_section_size")) h.WriteToken(jsontext.String("value")) h.WriteToken(jsontext.Uint(uint64(f.MaxFieldSectionSize))) h.WriteToken(jsontext.EndObject) } if f.Datagram != nil { h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("name")) h.WriteToken(jsontext.String("settings_h3_datagram")) h.WriteToken(jsontext.String("value")) h.WriteToken(jsontext.Bool(*f.Datagram)) h.WriteToken(jsontext.EndObject) } if f.ExtendedConnect != nil { h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("name")) h.WriteToken(jsontext.String("settings_enable_connect_protocol")) h.WriteToken(jsontext.String("value")) h.WriteToken(jsontext.Bool(*f.ExtendedConnect)) h.WriteToken(jsontext.EndObject) } if len(f.Other) > 0 { for k, v := range f.Other { h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("name")) h.WriteToken(jsontext.String("unknown")) h.WriteToken(jsontext.String("name_bytes")) h.WriteToken(jsontext.Uint(k)) h.WriteToken(jsontext.String("value")) h.WriteToken(jsontext.Uint(v)) h.WriteToken(jsontext.EndObject) } } h.WriteToken(jsontext.EndArray) h.WriteToken(jsontext.EndObject) return h.err } // A PushPromiseFrame is a PUSH_PROMISE frame type PushPromiseFrame struct{} func (f 
*PushPromiseFrame) encode(enc *jsontext.Encoder) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("push_promise")) h.WriteToken(jsontext.EndObject) return h.err } // A CancelPushFrame is a CANCEL_PUSH frame type CancelPushFrame struct{} func (f *CancelPushFrame) encode(enc *jsontext.Encoder) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("cancel_push")) h.WriteToken(jsontext.EndObject) return h.err } // A MaxPushIDFrame is a MAX_PUSH_ID frame type MaxPushIDFrame struct{} func (f *MaxPushIDFrame) encode(enc *jsontext.Encoder) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("max_push_id")) h.WriteToken(jsontext.EndObject) return h.err } // A ReservedFrame is one of the reserved frame types type ReservedFrame struct { Type uint64 } func (f *ReservedFrame) encode(enc *jsontext.Encoder) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("reserved")) h.WriteToken(jsontext.String("frame_type_bytes")) h.WriteToken(jsontext.Uint(f.Type)) h.WriteToken(jsontext.EndObject) return h.err } // An UnknownFrame is an unknown frame type type UnknownFrame struct { Type uint64 } func (f *UnknownFrame) encode(enc *jsontext.Encoder) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("unknown")) h.WriteToken(jsontext.String("frame_type_bytes")) h.WriteToken(jsontext.Uint(f.Type)) h.WriteToken(jsontext.EndObject) return h.err } quic-go-0.59.0/http3/qlog/frame_test.go000066400000000000000000000112471513066070600176460ustar00rootroot00000000000000package qlog import ( "bytes" "encoding/json" "testing" 
"github.com/quic-go/quic-go/qlogwriter/jsontext" "github.com/stretchr/testify/require" ) func check(t *testing.T, f any, expected map[string]any) { t.Helper() var buf bytes.Buffer enc := jsontext.NewEncoder(&buf) require.NoError(t, (Frame{Frame: f}).encode(enc)) data := buf.Bytes() require.True(t, json.Valid(data), "invalid JSON: %s", string(data)) checkEncoding(t, data, expected) } func checkEncoding(t *testing.T, data []byte, expected map[string]any) { t.Helper() m := make(map[string]any) require.NoError(t, json.Unmarshal(data, &m)) require.Len(t, m, len(expected)) for key, value := range expected { switch v := value.(type) { case bool, string, map[string]any: require.Equal(t, v, m[key]) case int: require.Equal(t, float64(v), m[key]) case float64: require.Equal(t, v, m[key]) case []map[string]any: // used for header fields require.Contains(t, m, key) slice, ok := m[key].([]any) require.True(t, ok) require.Len(t, slice, len(v)) for i, expectedField := range v { field, ok := slice[i].(map[string]any) require.True(t, ok) require.Equal(t, expectedField, field) } default: t.Fatalf("unexpected type: %T", v) } } } func TestDataFrame(t *testing.T) { check(t, DataFrame{}, map[string]any{ "frame_type": "data", }) } func TestHeadersFrame(t *testing.T) { check(t, HeadersFrame{ HeaderFields: []HeaderField{ {Name: ":status", Value: "200"}, {Name: "content-type", Value: "application/json"}, }, }, map[string]any{ "frame_type": "headers", "header_fields": []map[string]any{ {"name": ":status", "value": "200"}, {"name": "content-type", "value": "application/json"}, }, }) } func TestGoAwayFrame(t *testing.T) { check(t, GoAwayFrame{StreamID: 1337}, map[string]any{ "frame_type": "goaway", "id": 1337, }) } func pointer[T any](v T) *T { return &v } func TestSettingsFrame(t *testing.T) { tests := []struct { name string frame SettingsFrame expected map[string]any }{ { name: "datagram: true", frame: SettingsFrame{ MaxFieldSectionSize: -1, Datagram: pointer(true), }, expected: 
map[string]any{ "frame_type": "settings", "settings": []map[string]any{{ "name": "settings_h3_datagram", "value": true, }}, }, }, { name: "extended_connect: false", frame: SettingsFrame{ MaxFieldSectionSize: -1, ExtendedConnect: pointer(false), }, expected: map[string]any{ "frame_type": "settings", "settings": []map[string]any{{ "name": "settings_enable_connect_protocol", "value": false, }}, }, }, { name: "max_field_section_size", frame: SettingsFrame{MaxFieldSectionSize: 1337}, expected: map[string]any{ "frame_type": "settings", "settings": []map[string]any{{ "name": "settings_max_field_section_size", "value": float64(1337), }}, }, }, { name: "datagram: false, extended_connect: false", frame: SettingsFrame{ MaxFieldSectionSize: -1, Datagram: pointer(false), ExtendedConnect: pointer(false), }, expected: map[string]any{ "frame_type": "settings", "settings": []map[string]any{ {"name": "settings_h3_datagram", "value": false}, {"name": "settings_enable_connect_protocol", "value": false}, }, }, }, { name: "unknowns", // Only test a single unknown setting. // Testing multiple unknown settings doesn't add a lot of value, // and would require us to deal with non-deterministic map iteration order. 
frame: SettingsFrame{ MaxFieldSectionSize: -1, Other: map[uint64]uint64{0xdead: 0xbeef}, }, expected: map[string]any{ "frame_type": "settings", "settings": []map[string]any{{ "name": "unknown", "name_bytes": float64(0xdead), "value": float64(0xbeef), }}, }, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { check(t, tc.frame, tc.expected) }) } } func TestPushPromiseFrame(t *testing.T) { check(t, PushPromiseFrame{}, map[string]any{ "frame_type": "push_promise", }) } func TestCancelPushFrame(t *testing.T) { check(t, CancelPushFrame{}, map[string]any{ "frame_type": "cancel_push", }) } func TestMaxPushIDFrame(t *testing.T) { check(t, MaxPushIDFrame{}, map[string]any{ "frame_type": "max_push_id", }) } func TestReservedFrame(t *testing.T) { check(t, ReservedFrame{Type: 0x1f}, map[string]any{ "frame_type": "reserved", "frame_type_bytes": 0x1f, }) } func TestUnknownFrame(t *testing.T) { check(t, UnknownFrame{Type: 0x2a}, map[string]any{ "frame_type": "unknown", "frame_type_bytes": 0x2a, }) } quic-go-0.59.0/http3/qlog/qlog_dir.go000066400000000000000000000006541513066070600173150ustar00rootroot00000000000000package qlog import ( "context" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/qlogwriter" ) const EventSchema = "urn:ietf:params:qlog:events:http3-12" func DefaultConnectionTracer(ctx context.Context, isClient bool, connID quic.ConnectionID) qlogwriter.Trace { return qlog.DefaultConnectionTracerWithSchemas(ctx, isClient, connID, []string{qlog.EventSchema, EventSchema}) } quic-go-0.59.0/http3/qlog/qlog_dir_test.go000066400000000000000000000017201513066070600203470ustar00rootroot00000000000000package qlog import ( "context" "os" "path/filepath" "testing" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/qlog" "github.com/stretchr/testify/require" ) func TestQLOGDIRSet(t *testing.T) { tmpDir := t.TempDir() connID := quic.ConnectionIDFromBytes([]byte{1, 2, 3, 4}) qlogDir := filepath.Join(tmpDir, "qlogs") 
t.Setenv("QLOGDIR", qlogDir) tracer := DefaultConnectionTracer(context.Background(), true, connID) require.NotNil(t, tracer) // adding and closing a producer makes the tracer close the file recorder := tracer.AddProducer() recorder.Close() _, err := os.Stat(qlogDir) qlogDirCreated := !os.IsNotExist(err) require.True(t, qlogDirCreated) entries, err := os.ReadDir(qlogDir) require.NoError(t, err) require.Len(t, entries, 1) data, err := os.ReadFile(filepath.Join(qlogDir, entries[0].Name())) require.NoError(t, err) require.Contains(t, string(data), EventSchema) require.Contains(t, string(data), qlog.EventSchema) } quic-go-0.59.0/http3/request_writer.go000066400000000000000000000224301513066070600176330ustar00rootroot00000000000000package http3 import ( "bytes" "errors" "fmt" "io" "net" "net/http" "net/http/httptrace" "strconv" "strings" "sync" "golang.org/x/net/http/httpguts" "golang.org/x/net/http2/hpack" "golang.org/x/net/idna" "github.com/quic-go/qpack" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3/qlog" "github.com/quic-go/quic-go/qlogwriter" ) const bodyCopyBufferSize = 8 * 1024 type requestWriter struct { mutex sync.Mutex encoder *qpack.Encoder headerBuf *bytes.Buffer } func newRequestWriter() *requestWriter { headerBuf := &bytes.Buffer{} encoder := qpack.NewEncoder(headerBuf) return &requestWriter{ encoder: encoder, headerBuf: headerBuf, } } func (w *requestWriter) WriteRequestHeader(wr io.Writer, req *http.Request, gzip bool, streamID quic.StreamID, qlogger qlogwriter.Recorder) error { buf := &bytes.Buffer{} if err := w.writeHeaders(buf, req, gzip, streamID, qlogger); err != nil { return err } if _, err := wr.Write(buf.Bytes()); err != nil { return err } trace := httptrace.ContextClientTrace(req.Context()) traceWroteHeaders(trace) return nil } func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool, streamID quic.StreamID, qlogger qlogwriter.Recorder) error { w.mutex.Lock() defer w.mutex.Unlock() defer w.encoder.Close() 
defer w.headerBuf.Reset() var trailers string if len(req.Trailer) > 0 { keys := make([]string, 0, len(req.Trailer)) for k := range req.Trailer { if httpguts.ValidTrailerHeader(k) { keys = append(keys, k) } } trailers = strings.Join(keys, ", ") } headerFields, err := w.encodeHeaders(req, gzip, trailers, actualContentLength(req), qlogger != nil) if err != nil { return err } b := make([]byte, 0, 128) b = (&headersFrame{Length: uint64(w.headerBuf.Len())}).Append(b) if qlogger != nil { qlogCreatedHeadersFrame(qlogger, streamID, len(b)+w.headerBuf.Len(), w.headerBuf.Len(), headerFields) } if _, err := wr.Write(b); err != nil { return err } _, err = wr.Write(w.headerBuf.Bytes()) return err } func isExtendedConnectRequest(req *http.Request) bool { return req.Method == http.MethodConnect && req.Proto != "" && req.Proto != "HTTP/1.1" } // copied from net/transport.go // Modified to support Extended CONNECT: // Contrary to what the godoc for the http.Request says, // we do respect the Proto field if the method is CONNECT. // // The returned header fields are only set if doQlog is true. 
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64, doQlog bool) ([]qlog.HeaderField, error) { host := req.Host if host == "" { host = req.URL.Host } host, err := httpguts.PunycodeHostPort(host) if err != nil { return nil, err } if !httpguts.ValidHostHeader(host) { return nil, errors.New("http3: invalid Host header") } // http.NewRequest sets this field to HTTP/1.1 isExtendedConnect := isExtendedConnectRequest(req) var path string if req.Method != http.MethodConnect || isExtendedConnect { path = req.URL.RequestURI() if !validPseudoPath(path) { orig := path path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host) if !validPseudoPath(path) { if req.URL.Opaque != "" { return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) } else { return nil, fmt.Errorf("invalid request :path %q", orig) } } } } // Check for any invalid headers and return an error before we // potentially pollute our hpack state. (We want to be able to // continue to reuse the hpack encoder for future requests) for k, vv := range req.Header { if !httpguts.ValidHeaderFieldName(k) { return nil, fmt.Errorf("invalid HTTP header name %q", k) } for _, v := range vv { if !httpguts.ValidHeaderFieldValue(v) { return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k) } } } enumerateHeaders := func(f func(name, value string)) { // 8.1.2.3 Request Pseudo-Header Fields // The :path pseudo-header field includes the path and query parts of the // target URI (the path-absolute production and optionally a '?' character // followed by the query production (see Sections 3.3 and 3.4 of // [RFC3986]). 
f(":authority", host) f(":method", req.Method) if req.Method != http.MethodConnect || isExtendedConnect { f(":path", path) f(":scheme", req.URL.Scheme) } if isExtendedConnect { f(":protocol", req.Proto) } if trailers != "" { f("trailer", trailers) } var didUA bool for k, vv := range req.Header { if strings.EqualFold(k, "host") || strings.EqualFold(k, "content-length") { // Host is :authority, already sent. // Content-Length is automatic, set below. continue } else if strings.EqualFold(k, "connection") || strings.EqualFold(k, "proxy-connection") || strings.EqualFold(k, "transfer-encoding") || strings.EqualFold(k, "upgrade") || strings.EqualFold(k, "keep-alive") { // Per 8.1.2.2 Connection-Specific Header // Fields, don't send connection-specific // fields. We have already checked if any // are error-worthy so just ignore the rest. continue } else if strings.EqualFold(k, "user-agent") { // Match Go's http1 behavior: at most one // User-Agent. If set to nil or empty string, // then omit it. Otherwise if not mentioned, // include the default (below). didUA = true if len(vv) < 1 { continue } vv = vv[:1] if vv[0] == "" { continue } } for _, v := range vv { f(k, v) } } if shouldSendReqContentLength(req.Method, contentLength) { f("content-length", strconv.FormatInt(contentLength, 10)) } if addGzipHeader { f("accept-encoding", "gzip") } if !didUA { f("user-agent", defaultUserAgent) } } // Do a first pass over the headers counting bytes to ensure // we don't exceed cc.peerMaxHeaderListSize. This is done as a // separate pass before encoding the headers to prevent // modifying the hpack state. 
hlSize := uint64(0) enumerateHeaders(func(name, value string) { hf := hpack.HeaderField{Name: name, Value: value} hlSize += uint64(hf.Size()) }) // TODO: check maximum header list size // if hlSize > cc.peerMaxHeaderListSize { // return errRequestHeaderListSize // } trace := httptrace.ContextClientTrace(req.Context()) traceHeaders := traceHasWroteHeaderField(trace) // Header list size is ok. Write the headers. var headerFields []qlog.HeaderField if doQlog { headerFields = make([]qlog.HeaderField, 0, len(req.Header)) } enumerateHeaders(func(name, value string) { name = strings.ToLower(name) w.encoder.WriteField(qpack.HeaderField{Name: name, Value: value}) if traceHeaders { traceWroteHeaderField(trace, name, value) } if doQlog { headerFields = append(headerFields, qlog.HeaderField{Name: name, Value: value}) } }) return headerFields, nil } // authorityAddr returns a given authority (a host/IP, or host:port / ip:port) // and returns a host:port. The port 443 is added if needed. func authorityAddr(authority string) (addr string) { host, port, err := net.SplitHostPort(authority) if err != nil { // authority didn't have a port port = "443" host = authority } if a, err := idna.ToASCII(host); err == nil { host = a } // IPv6 address literal, without a port: if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { return host + ":" + port } return net.JoinHostPort(host, port) } // validPseudoPath reports whether v is a valid :path pseudo-header // value. It must be either: // // *) a non-empty string starting with '/' // *) the string '*', for OPTIONS requests. // // For now this is only used a quick check for deciding when to clean // up Opaque URLs before sending requests from the Transport. // See golang.org/issue/16847 // // We used to enforce that the path also didn't start with "//", but // Google's GFE accepts such paths and Chrome sends them, so ignore // that part of the spec. See golang.org/issue/19103. 
func validPseudoPath(v string) bool {
	return (len(v) > 0 && v[0] == '/') || v == "*"
}

// actualContentLength returns a sanitized version of req.ContentLength,
// where 0 actually means zero (not unknown) and -1 means unknown.
//
// A nil Body and http.NoBody both denote a request without a body and yield
// 0 (this matches golang.org/x/net/http2, which this logic was copied from).
// For any other body, a ContentLength of 0 means the length is unknown and
// is reported as -1.
func actualContentLength(req *http.Request) int64 {
	if req.Body == nil || req.Body == http.NoBody {
		return 0
	}
	if req.ContentLength != 0 {
		return req.ContentLength
	}
	return -1
}

// shouldSendReqContentLength reports whether the client should send a
// "content-length" request header. This logic is basically a copy of the
// net/http transferWriter.shouldSendContentLength.
// contentLength is the corrected value returned by actualContentLength:
// 0 means actually 0 (not unknown), -1 means unknown.
func shouldSendReqContentLength(method string, contentLength int64) bool {
	switch {
	case contentLength > 0:
		return true
	case contentLength < 0:
		return false
	}
	// For zero bodies, whether we send a content-length depends on the method.
	// It also kinda doesn't matter for HTTP/2 and HTTP/3 either way, since the
	// end of the stream marks the end of the body.
	switch method {
	case "POST", "PUT", "PATCH":
		return true
	default:
		return false
	}
}

// WriteRequestTrailer writes HTTP trailers to the stream.
// It should be called after the request body has been fully written.
func (w *requestWriter) WriteRequestTrailer(wr io.Writer, req *http.Request, streamID quic.StreamID, qlogger qlogwriter.Recorder) error { _, err := writeTrailers(wr, req.Trailer, streamID, qlogger) return err } quic-go-0.59.0/http3/request_writer_test.go000066400000000000000000000121631513066070600206740ustar00rootroot00000000000000package http3 import ( "bytes" "io" "net/http" "net/http/httptest" "testing" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/quic-go/testutils/events" "github.com/stretchr/testify/require" ) func decodeRequest(t *testing.T, str io.Reader, streamID quic.StreamID, eventRecorder *events.Recorder) map[string]string { t.Helper() r := io.LimitedReader{R: str, N: 1000} fp := frameParser{r: &r} frame, err := fp.ParseNext(nil) require.NoError(t, err) require.IsType(t, &headersFrame{}, frame) headersFrame := frame.(*headersFrame) data := make([]byte, headersFrame.Length) _, err = io.ReadFull(&r, data) require.NoError(t, err) hfs := decodeQpackHeaderFields(t, data) values := make(map[string]string) for _, hf := range hfs { values[hf.Name] = hf.Value } headerFields := make([]qlog.HeaderField, len(hfs)) for i, hf := range hfs { headerFields[i] = qlog.HeaderField{Name: hf.Name, Value: hf.Value} } require.Equal(t, []qlogwriter.Event{ qlog.FrameCreated{ StreamID: streamID, Raw: qlog.RawInfo{ Length: int(1000 - r.N), PayloadLength: int(headersFrame.Length), }, Frame: qlog.Frame{Frame: qlog.HeadersFrame{HeaderFields: headerFields}}, }, }, eventRecorder.Events(qlog.FrameCreated{}), ) return values } func TestRequestWriterGetRequestGzip(t *testing.T) { t.Run("gzip", func(t *testing.T) { testRequestWriterGzip(t, true) }) t.Run("no gzip", func(t *testing.T) { testRequestWriterGzip(t, false) }) } func testRequestWriterGzip(t *testing.T, gzip bool) { req := httptest.NewRequest(http.MethodGet, "https://quic-go.net/index.html?foo=bar", nil) req.AddCookie(&http.Cookie{Name: "foo", 
Value: "bar"}) req.AddCookie(&http.Cookie{Name: "baz", Value: "lorem ipsum"}) rw := newRequestWriter() var eventRecorder events.Recorder buf := &bytes.Buffer{} require.NoError(t, rw.WriteRequestHeader(buf, req, gzip, 42, &eventRecorder)) headerFields := decodeRequest(t, buf, 42, &eventRecorder) require.Equal(t, "quic-go.net", headerFields[":authority"]) require.Equal(t, http.MethodGet, headerFields[":method"]) require.Equal(t, "/index.html?foo=bar", headerFields[":path"]) require.Equal(t, "https", headerFields[":scheme"]) require.Equal(t, `foo=bar; baz="lorem ipsum"`, headerFields["cookie"]) switch gzip { case true: require.Equal(t, "gzip", headerFields["accept-encoding"]) case false: require.NotContains(t, headerFields, "accept-encoding") } } func TestRequestWriterInvalidHostHeader(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "https://quic-go.net/index.html?foo=bar", nil) req.Host = "foo@bar" // @ is invalid rw := newRequestWriter() require.EqualError(t, rw.WriteRequestHeader(&bytes.Buffer{}, req, false, 0, nil), "http3: invalid Host header", ) } func TestRequestWriterConnect(t *testing.T) { // httptest.NewRequest does not properly support the CONNECT method req, err := http.NewRequest(http.MethodConnect, "https://quic-go.net/", nil) require.NoError(t, err) rw := newRequestWriter() buf := &bytes.Buffer{} var eventRecorder events.Recorder require.NoError(t, rw.WriteRequestHeader(buf, req, false, 1337, &eventRecorder)) headerFields := decodeRequest(t, buf, 1337, &eventRecorder) require.Equal(t, http.MethodConnect, headerFields[":method"]) require.Equal(t, "quic-go.net", headerFields[":authority"]) require.NotContains(t, headerFields, ":path") require.NotContains(t, headerFields, ":scheme") require.NotContains(t, headerFields, ":protocol") } func TestRequestWriterExtendedConnect(t *testing.T) { // httptest.NewRequest does not properly support the CONNECT method req, err := http.NewRequest(http.MethodConnect, "https://quic-go.net/", nil) 
require.NoError(t, err) req.Proto = "webtransport" rw := newRequestWriter() buf := &bytes.Buffer{} var eventRecorder events.Recorder require.NoError(t, rw.WriteRequestHeader(buf, req, false, 1234, &eventRecorder)) headerFields := decodeRequest(t, buf, 1234, &eventRecorder) require.Equal(t, "quic-go.net", headerFields[":authority"]) require.Equal(t, http.MethodConnect, headerFields[":method"]) require.Equal(t, "/", headerFields[":path"]) require.Equal(t, "https", headerFields[":scheme"]) require.Equal(t, "webtransport", headerFields[":protocol"]) } func TestRequestWriterTrailers(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "https://quic-go.net/upload", nil) req.Trailer = http.Header{ "Trailer1": []string{"foo"}, "Trailer2": []string{"bar"}, "Content-Length": []string{"42"}, // Content-Length is not a valid trailer } rw := newRequestWriter() buf := &bytes.Buffer{} require.NoError(t, rw.WriteRequestHeader(buf, req, false, 42, nil)) headers := decodeHeader(t, buf) require.Len(t, headers["trailer"], 1) require.Contains(t, headers["trailer"][0], "Trailer1") require.Contains(t, headers["trailer"][0], "Trailer2") require.NotContains(t, headers["trailer"][0], "Content-Length") require.NoError(t, rw.WriteRequestTrailer(buf, req, 42, nil)) trailers := decodeHeader(t, buf) require.Equal(t, map[string][]string{ "trailer1": {"foo"}, "trailer2": {"bar"}, }, trailers) } quic-go-0.59.0/http3/response_writer.go000066400000000000000000000230721513066070600200040ustar00rootroot00000000000000package http3 import ( "bytes" "fmt" "log/slog" "net/http" "net/textproto" "strconv" "strings" "time" "github.com/quic-go/qpack" "github.com/quic-go/quic-go/http3/qlog" "golang.org/x/net/http/httpguts" ) // The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented by the http.ResponseWriter. // When a stream is taken over, it's the caller's responsibility to close the stream. 
type HTTPStreamer interface { HTTPStream() *Stream } const maxSmallResponseSize = 4096 type responseWriter struct { str *Stream conn *rawConn header http.Header trailers map[string]struct{} buf []byte status int // status code passed to WriteHeader // for responses smaller than maxSmallResponseSize, we buffer calls to Write, // and automatically add the Content-Length header smallResponseBuf []byte contentLen int64 // if handler set valid Content-Length header numWritten int64 // bytes written headerComplete bool // set once WriteHeader is called with a status code >= 200 headerWritten bool // set once the response header has been serialized to the stream isHead bool trailerWritten bool // set once the response trailers has been serialized to the stream hijacked bool // set on HTTPStream is called logger *slog.Logger } var ( _ http.ResponseWriter = &responseWriter{} _ http.Flusher = &responseWriter{} _ Settingser = &responseWriter{} _ HTTPStreamer = &responseWriter{} // make sure that we implement (some of the) methods used by the http.ResponseController _ interface { SetReadDeadline(time.Time) error SetWriteDeadline(time.Time) error Flush() FlushError() error } = &responseWriter{} ) func newResponseWriter(str *Stream, conn *rawConn, isHead bool, logger *slog.Logger) *responseWriter { return &responseWriter{ str: str, conn: conn, header: http.Header{}, buf: make([]byte, frameHeaderLen), isHead: isHead, logger: logger, } } func (w *responseWriter) Header() http.Header { return w.header } func (w *responseWriter) WriteHeader(status int) { if w.headerComplete { return } // http status must be 3 digits if status < 100 || status > 999 { panic(fmt.Sprintf("invalid WriteHeader code %v", status)) } w.status = status // immediately write 1xx headers if status < 200 { w.writeHeader(status) return } // We're done with headers once we write a status >= 200. w.headerComplete = true // Add Date header. // This is what the standard library does. 
// Can be disabled by setting the Date header to nil. if _, ok := w.header["Date"]; !ok { w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat)) } // Content-Length checking // use ParseUint instead of ParseInt, as negative values are invalid if clen := w.header.Get("Content-Length"); clen != "" { if cl, err := strconv.ParseUint(clen, 10, 63); err == nil { w.contentLen = int64(cl) } else { // emit a warning for malformed Content-Length and remove it logger := w.logger if logger == nil { logger = slog.Default() } logger.Error("Malformed Content-Length", "value", clen) w.header.Del("Content-Length") } } } func (w *responseWriter) sniffContentType(p []byte) { // If no content type, apply sniffing algorithm to body. // We can't use `w.header.Get` here since if the Content-Type was set to nil, we shouldn't do sniffing. _, haveType := w.header["Content-Type"] // If the Content-Encoding was set and is non-blank, we shouldn't sniff the body. hasCE := w.header.Get("Content-Encoding") != "" if !hasCE && !haveType && len(p) > 0 { w.header.Set("Content-Type", http.DetectContentType(p)) } } func (w *responseWriter) Write(p []byte) (int, error) { bodyAllowed := bodyAllowedForStatus(w.status) if !w.headerComplete { w.sniffContentType(p) w.WriteHeader(http.StatusOK) bodyAllowed = true } if !bodyAllowed { return 0, http.ErrBodyNotAllowed } w.numWritten += int64(len(p)) if w.contentLen != 0 && w.numWritten > w.contentLen { return 0, http.ErrContentLength } if w.isHead { return len(p), nil } if !w.headerWritten { // Buffer small responses. // This allows us to automatically set the Content-Length field. if len(w.smallResponseBuf)+len(p) < maxSmallResponseSize { w.smallResponseBuf = append(w.smallResponseBuf, p...) 
return len(p), nil } } return w.doWrite(p) } func (w *responseWriter) doWrite(p []byte) (int, error) { if !w.headerWritten { w.sniffContentType(w.smallResponseBuf) if err := w.writeHeader(w.status); err != nil { return 0, maybeReplaceError(err) } w.headerWritten = true } l := uint64(len(w.smallResponseBuf) + len(p)) if l == 0 { return 0, nil } df := &dataFrame{Length: l} w.buf = w.buf[:0] w.buf = df.Append(w.buf) if w.str.qlogger != nil { w.str.qlogger.RecordEvent(qlog.FrameCreated{ StreamID: w.str.StreamID(), Raw: qlog.RawInfo{Length: len(w.buf) + int(l), PayloadLength: int(l)}, Frame: qlog.Frame{Frame: qlog.DataFrame{}}, }) } if _, err := w.str.writeUnframed(w.buf); err != nil { return 0, maybeReplaceError(err) } if len(w.smallResponseBuf) > 0 { if _, err := w.str.writeUnframed(w.smallResponseBuf); err != nil { return 0, maybeReplaceError(err) } w.smallResponseBuf = nil } var n int if len(p) > 0 { var err error n, err = w.str.writeUnframed(p) if err != nil { return n, maybeReplaceError(err) } } return n, nil } func (w *responseWriter) writeHeader(status int) error { var headerFields []qlog.HeaderField // only used for qlog var headers bytes.Buffer enc := qpack.NewEncoder(&headers) if err := enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}); err != nil { return err } if w.str.qlogger != nil { headerFields = append(headerFields, qlog.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) } // Handle trailer fields if vals, ok := w.header["Trailer"]; ok { for _, val := range vals { for _, trailer := range strings.Split(val, ",") { // We need to convert to the canonical header key value here because this will be called when using // headers.Add or headers.Set. 
trailer = textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(trailer)) w.declareTrailer(trailer) } } } for k, v := range w.header { if _, excluded := w.trailers[k]; excluded { continue } // Ignore "Trailer:" prefixed headers if strings.HasPrefix(k, http.TrailerPrefix) { continue } for index := range v { name := strings.ToLower(k) value := v[index] if err := enc.WriteField(qpack.HeaderField{Name: name, Value: value}); err != nil { return err } if w.str.qlogger != nil { headerFields = append(headerFields, qlog.HeaderField{Name: name, Value: value}) } } } buf := make([]byte, 0, frameHeaderLen+headers.Len()) buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf) buf = append(buf, headers.Bytes()...) if w.str.qlogger != nil { qlogCreatedHeadersFrame(w.str.qlogger, w.str.StreamID(), len(buf), headers.Len(), headerFields) } _, err := w.str.writeUnframed(buf) return err } func (w *responseWriter) FlushError() error { if !w.headerComplete { w.WriteHeader(http.StatusOK) } _, err := w.doWrite(nil) return err } func (w *responseWriter) flushTrailers() { if w.trailerWritten { return } if err := w.writeTrailers(); err != nil { w.logger.Debug("could not write trailers", "error", err) } } func (w *responseWriter) Flush() { if err := w.FlushError(); err != nil { if w.logger != nil { w.logger.Debug("could not flush to stream", "error", err) } } } // declareTrailer adds a trailer to the trailer list, while also validating that the trailer has a // valid name. func (w *responseWriter) declareTrailer(k string) { if !httpguts.ValidTrailerHeader(k) { // Forbidden by RFC 9110, section 6.5.1. w.logger.Debug("ignoring invalid trailer", slog.String("header", k)) return } if w.trailers == nil { w.trailers = make(map[string]struct{}) } w.trailers[k] = struct{}{} } // writeTrailers will write trailers to the stream if there are any. 
func (w *responseWriter) writeTrailers() error { // promote headers added via "Trailer:" convention as trailers, these can be added after // streaming the status/headers have been written. for k := range w.header { if strings.HasPrefix(k, http.TrailerPrefix) { w.declareTrailer(k) } } if len(w.trailers) == 0 { return nil } trailers := make(http.Header, len(w.trailers)) for trailer := range w.trailers { if vals, ok := w.header[trailer]; ok { trailers[strings.TrimPrefix(trailer, http.TrailerPrefix)] = vals } } written, err := writeTrailers(w.str.datagramStream, trailers, w.str.StreamID(), w.str.qlogger) if written { w.trailerWritten = true } return err } func (w *responseWriter) HTTPStream() *Stream { w.hijacked = true w.Flush() return w.str } func (w *responseWriter) wasStreamHijacked() bool { return w.hijacked } func (w *responseWriter) ReceivedSettings() <-chan struct{} { return w.conn.ReceivedSettings() } func (w *responseWriter) Settings() *Settings { return w.conn.Settings() } func (w *responseWriter) SetReadDeadline(deadline time.Time) error { return w.str.SetReadDeadline(deadline) } func (w *responseWriter) SetWriteDeadline(deadline time.Time) error { return w.str.SetWriteDeadline(deadline) } // copied from http2/http2.go // bodyAllowedForStatus reports whether a given response status code // permits a body. See RFC 2616, section 4.4. 
func bodyAllowedForStatus(status int) bool { switch { case status >= 100 && status <= 199: return false case status == http.StatusNoContent: return false case status == http.StatusNotModified: return false } return true } quic-go-0.59.0/http3/response_writer_test.go000066400000000000000000000201671513066070600210450ustar00rootroot00000000000000package http3 import ( "bytes" "io" "log/slog" "net/http" "testing" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3/qlog" "github.com/quic-go/quic-go/testutils/events" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) type testResponseWriter struct { *responseWriter eventRecorder *events.Recorder buf *bytes.Buffer } func (rw *testResponseWriter) DecodeHeaders(t *testing.T, idx int) map[string][]string { t.Helper() rw.Flush() rw.flushTrailers() startLen := rw.buf.Len() frame, err := (&frameParser{r: rw.buf}).ParseNext(nil) require.NoError(t, err) require.IsType(t, &headersFrame{}, frame) payloadLen := frame.(*headersFrame).Length data := make([]byte, payloadLen) headerFrameLen := startLen - rw.buf.Len() + len(data) _, err = io.ReadFull(rw.buf, data) require.NoError(t, err) hfs := decodeQpackHeaderFields(t, data) // check that the decoded header fields are properly logged require.GreaterOrEqual(t, len(rw.eventRecorder.Events(qlog.FrameCreated{})), idx+1) require.IsType(t, qlog.HeadersFrame{}, rw.eventRecorder.Events()[idx].(qlog.FrameCreated).Frame.Frame) ev := rw.eventRecorder.Events()[idx].(qlog.FrameCreated) assert.Equal(t, quic.StreamID(42), ev.StreamID) assert.Equal(t, headerFrameLen, ev.Raw.Length, "raw.Length") assert.Equal(t, int(payloadLen), ev.Raw.PayloadLength, "raw.PayloadLength") fields := make(map[string][]string) for _, p := range hfs { fields[p.Name] = append(fields[p.Name], p.Value) require.Contains(t, ev.Frame.Frame.(qlog.HeadersFrame).HeaderFields, qlog.HeaderField{Name: p.Name, Value: p.Value}, ) } return fields } func (rw 
*testResponseWriter) DecodeBody(t *testing.T) []byte { t.Helper() frame, err := (&frameParser{r: rw.buf}).ParseNext(nil) if err == io.EOF { return nil } require.NoError(t, err) require.IsType(t, &dataFrame{}, frame) body := make([]byte, frame.(*dataFrame).Length) _, err = io.ReadFull(rw.buf, body) require.NoError(t, err) return body } func newTestResponseWriter(t *testing.T) *testResponseWriter { var eventRecorder events.Recorder buf := &bytes.Buffer{} mockCtrl := gomock.NewController(t) str := NewMockDatagramStream(mockCtrl) str.EXPECT().StreamID().Return(quic.StreamID(42)).AnyTimes() str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() str.EXPECT().SetReadDeadline(gomock.Any()).Return(nil).AnyTimes() str.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).AnyTimes() rw := newResponseWriter( newStream(str, nil, nil, func(io.Reader, *headersFrame) error { return nil }, &eventRecorder), nil, false, slog.Default(), ) return &testResponseWriter{ responseWriter: rw, eventRecorder: &eventRecorder, buf: buf, } } func TestResponseWriterInvalidStatus(t *testing.T) { rw := newTestResponseWriter(t) require.Panics(t, func() { rw.WriteHeader(99) }) require.Panics(t, func() { rw.WriteHeader(1000) }) } func TestResponseWriterHeader(t *testing.T) { rw := newTestResponseWriter(t) rw.Header().Add("Content-Length", "42") rw.WriteHeader(http.StatusTeapot) // 418 // repeated WriteHeader calls are ignored rw.WriteHeader(http.StatusInternalServerError) // set cookies http.SetCookie(rw, &http.Cookie{Name: "foo", Value: "bar"}) http.SetCookie(rw, &http.Cookie{Name: "baz", Value: "lorem ipsum"}) // write some data rw.Write([]byte("foobar")) fields := rw.DecodeHeaders(t, 0) require.Equal(t, []string{"418"}, fields[":status"]) require.Equal(t, []string{"42"}, fields["content-length"]) require.Equal(t, []string{"foo=bar", `baz="lorem ipsum"`}, fields["set-cookie"], ) require.Equal(t, []byte("foobar"), rw.DecodeBody(t)) } func TestResponseWriterDataWithoutHeader(t *testing.T) { 
rw := newTestResponseWriter(t) rw.Write([]byte("foobar")) fields := rw.DecodeHeaders(t, 0) require.Equal(t, []string{"200"}, fields[":status"]) require.Equal(t, []byte("foobar"), rw.DecodeBody(t)) } func TestResponseWriterDataStatusWithoutBody(t *testing.T) { rw := newTestResponseWriter(t) rw.WriteHeader(http.StatusNotModified) n, err := rw.Write([]byte("foobar")) require.Zero(t, n) require.ErrorIs(t, err, http.ErrBodyNotAllowed) fields := rw.DecodeHeaders(t, 0) require.Equal(t, []string{"304"}, fields[":status"]) require.Empty(t, rw.DecodeBody(t)) } func TestResponseWriterContentLength(t *testing.T) { rw := newTestResponseWriter(t) rw.Header().Set("Content-Length", "6") n, err := rw.Write([]byte("foobar")) require.Equal(t, 6, n) require.NoError(t, err) n, err = rw.Write([]byte{0x42}) require.Zero(t, n) require.ErrorIs(t, err, http.ErrContentLength) fields := rw.DecodeHeaders(t, 0) require.Equal(t, []string{"200"}, fields[":status"]) require.Equal(t, []string{"6"}, fields["content-length"]) require.Equal(t, []byte("foobar"), rw.DecodeBody(t)) } func TestResponseWriterContentTypeSniffing(t *testing.T) { t.Run("no content type", func(t *testing.T) { testContentTypeSniffing(t, map[string]string{}, "text/html; charset=utf-8") }) t.Run("explicit content type", func(t *testing.T) { testContentTypeSniffing(t, map[string]string{"Content-Type": "text/plain"}, "text/plain") }) t.Run("with content encoding", func(t *testing.T) { testContentTypeSniffing(t, map[string]string{"Content-Encoding": "gzip"}, "") }) } func testContentTypeSniffing(t *testing.T, hdrs map[string]string, expectedContentType string) { rw := newTestResponseWriter(t) for k, v := range hdrs { rw.Header().Set(k, v) } rw.Write([]byte("")) fields := rw.DecodeHeaders(t, 0) require.Equal(t, []string{"200"}, fields[":status"]) if expectedContentType == "" { require.NotContains(t, fields, "content-type") } else { require.Equal(t, []string{expectedContentType}, fields["content-type"]) } } func 
TestResponseWriterEarlyHints(t *testing.T) { rw := newTestResponseWriter(t) rw.Header().Add("Link", "; rel=preload; as=style") rw.Header().Add("Link", "; rel=preload; as=script") rw.WriteHeader(http.StatusEarlyHints) // status 103 n, err := rw.Write([]byte("foobar")) require.Equal(t, 6, n) require.NoError(t, err) // Early Hints must have been received fields := rw.DecodeHeaders(t, 0) require.Equal(t, 2, len(fields)) require.Equal(t, []string{"103"}, fields[":status"]) require.Equal(t, []string{"; rel=preload; as=style", "; rel=preload; as=script"}, fields["link"], ) // headers sent in the informational response must also be included in the final response fields = rw.DecodeHeaders(t, 1) require.Equal(t, 4, len(fields)) require.Equal(t, []string{"200"}, fields[":status"]) require.Contains(t, fields, "date") require.Contains(t, fields, "content-type") require.Equal(t, []string{"; rel=preload; as=style", "; rel=preload; as=script"}, fields["link"], ) require.Equal(t, []byte("foobar"), rw.DecodeBody(t)) } func TestResponseWriterTrailers(t *testing.T) { rw := newTestResponseWriter(t) rw.Header().Add("Trailer", "key, Content-Length") // Content-Length is not a valid trailer n, err := rw.Write([]byte("foobar")) require.Equal(t, 6, n) require.NoError(t, err) // writeTrailers needs to be called after writing the full body headers := rw.DecodeHeaders(t, 0) require.Equal(t, []string{"key, Content-Length"}, headers["trailer"]) require.NotContains(t, headers, "foo") require.Equal(t, []byte("foobar"), rw.DecodeBody(t)) // headers set after writing the body are trailers rw.Header().Set("key", "value") // announced trailer rw.Header().Set("foo", "bar") // this trailer was not announced, and will therefore be ignored rw.Header().Set(http.TrailerPrefix+"lorem", "ipsum") // unannounced trailer with trailer prefix rw.Header().Set("Content-Length", "999") // invalid trailer, will be ignored require.NoError(t, rw.writeTrailers()) trailers := rw.DecodeHeaders(t, 2) require.Equal(t, 
[]string{"value"}, trailers["key"]) require.Equal(t, []string{"ipsum"}, trailers["lorem"]) // trailers without the trailer prefix that were not announced are ignored require.NotContains(t, trailers, "foo") // invalid trailers are ignored require.NotContains(t, trailers, "content-length") } quic-go-0.59.0/http3/server.go000066400000000000000000000530751513066070600160660ustar00rootroot00000000000000package http3 import ( "context" "crypto/tls" "errors" "fmt" "io" "log/slog" "net" "net/http" "slices" "strings" "sync" "sync/atomic" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3/qlog" "github.com/quic-go/quic-go/qlogwriter" ) // NextProtoH3 is the ALPN protocol negotiated during the TLS handshake, for QUIC v1 and v2. const NextProtoH3 = "h3" // StreamType is the stream type of a unidirectional stream. type StreamType uint64 const ( streamTypeControlStream = 0 streamTypePushStream = 1 streamTypeQPACKEncoderStream = 2 streamTypeQPACKDecoderStream = 3 ) // A QUICListener listens for incoming QUIC connections. type QUICListener interface { Accept(context.Context) (*quic.Conn, error) Addr() net.Addr io.Closer } var _ QUICListener = &quic.EarlyListener{} // ConfigureTLSConfig creates a new tls.Config which can be used // to create a quic.Listener meant for serving HTTP/3. func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config { // Workaround for https://github.com/golang/go/issues/60506. // This initializes the session tickets _before_ cloning the config. _, _ = tlsConf.DecryptTicket(nil, tls.ConnectionState{}) config := tlsConf.Clone() config.NextProtos = []string{NextProtoH3} if gfc := config.GetConfigForClient; gfc != nil { config.GetConfigForClient = func(ch *tls.ClientHelloInfo) (*tls.Config, error) { conf, err := gfc(ch) if conf == nil || err != nil { return conf, err } return ConfigureTLSConfig(conf), nil } } return config } // contextKey is a value for use with context.WithValue. 
It's used as // a pointer so it fits in an interface{} without allocation. type contextKey struct { name string } func (k *contextKey) String() string { return "quic-go/http3 context value " + k.name } // ServerContextKey is a context key. It can be used in HTTP // handlers with Context.Value to access the server that // started the handler. The associated value will be of // type *http3.Server. var ServerContextKey = &contextKey{"http3-server"} // RemoteAddrContextKey is a context key. It can be used in // HTTP handlers with Context.Value to access the remote // address of the connection. The associated value will be of // type net.Addr. // // Use this value instead of [http.Request.RemoteAddr] if you // require access to the remote address of the connection rather // than its string representation. var RemoteAddrContextKey = &contextKey{"remote-addr"} // listener contains info about specific listener added with addListener type listener struct { ln *QUICListener port int // 0 means that no info about port is available // if this listener was constructed by the application, it won't be closed when the server is closed createdLocally bool } // Server is a HTTP/3 server. type Server struct { // Addr optionally specifies the UDP address for the server to listen on, // in the form "host:port". // // When used by ListenAndServe and ListenAndServeTLS methods, if empty, // ":https" (port 443) is used. See net.Dial for details of the address // format. // // Otherwise, if Port is not set and underlying QUIC listeners do not // have valid port numbers, the port part is used in Alt-Svc headers set // with SetQUICHeaders. Addr string // Port is used in Alt-Svc response headers set with SetQUICHeaders. If // needed Port can be manually set when the Server is created. // // This is useful when a Layer 4 firewall is redirecting UDP traffic and // clients must use a port different from the port the Server is // listening on. 
Port int // TLSConfig provides a TLS configuration for use by server. It must be // set for ListenAndServe and Serve methods. TLSConfig *tls.Config // QUICConfig provides the parameters for QUIC connection created with Serve. // If nil, it uses reasonable default values. // // Configured versions are also used in Alt-Svc response header set with SetQUICHeaders. QUICConfig *quic.Config // Handler is the HTTP request handler to use. If not set, defaults to // http.NotFound. Handler http.Handler // EnableDatagrams enables support for HTTP/3 datagrams (RFC 9297). // If set to true, QUICConfig.EnableDatagrams will be set. EnableDatagrams bool // MaxHeaderBytes controls the maximum number of bytes the server will // read parsing the request HEADERS frame. It does not limit the size of // the request body. If zero or negative, http.DefaultMaxHeaderBytes is // used. MaxHeaderBytes int // AdditionalSettings specifies additional HTTP/3 settings. // It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams). AdditionalSettings map[uint64]uint64 // IdleTimeout specifies how long until idle clients connection should be // closed. Idle refers only to the HTTP/3 layer, activity at the QUIC layer // like PING frames are not considered. // If zero or negative, there is no timeout. IdleTimeout time.Duration // ConnContext optionally specifies a function that modifies the context used for a new connection c. // The provided ctx has a ServerContextKey value. 
ConnContext func(ctx context.Context, c *quic.Conn) context.Context Logger *slog.Logger mutex sync.RWMutex listeners []listener closed bool closeCtx context.Context // canceled when the server is closed closeCancel context.CancelFunc // cancels the closeCtx graceCtx context.Context // canceled when the server is closed or gracefully closed graceCancel context.CancelFunc // cancels the graceCtx connCount atomic.Int64 connHandlingDone chan struct{} altSvcHeader string } // ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections. // // If s.Addr is blank, ":https" is used. func (s *Server) ListenAndServe() error { ln, err := s.setupListenerForConn(s.TLSConfig, nil) if err != nil { return err } defer s.removeListener(ln) return s.serveListener(*ln) } // ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections. // // If s.Addr is blank, ":https" is used. func (s *Server) ListenAndServeTLS(certFile, keyFile string) error { var err error certs := make([]tls.Certificate, 1) certs[0], err = tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return err } // We currently only use the cert-related stuff from tls.Config, // so we don't need to make a full copy. ln, err := s.setupListenerForConn(&tls.Config{Certificates: certs}, nil) if err != nil { return err } defer s.removeListener(ln) return s.serveListener(*ln) } // Serve an existing UDP connection. // It is possible to reuse the same connection for outgoing connections. // Closing the server does not close the connection. func (s *Server) Serve(conn net.PacketConn) error { ln, err := s.setupListenerForConn(s.TLSConfig, conn) if err != nil { return err } defer s.removeListener(ln) return s.serveListener(*ln) } // init initializes the contexts used for shutting down the server. // It must be called with the mutex held. 
func (s *Server) init() { if s.closeCtx == nil { s.closeCtx, s.closeCancel = context.WithCancel(context.Background()) s.graceCtx, s.graceCancel = context.WithCancel(s.closeCtx) } s.connHandlingDone = make(chan struct{}, 1) } func (s *Server) decreaseConnCount() { if s.connCount.Add(-1) == 0 && s.graceCtx.Err() != nil { close(s.connHandlingDone) } } // ServeQUICConn serves a single QUIC connection. func (s *Server) ServeQUICConn(conn *quic.Conn) error { s.mutex.Lock() if s.closed { s.mutex.Unlock() return http.ErrServerClosed } s.init() s.mutex.Unlock() s.connCount.Add(1) defer s.decreaseConnCount() return s.handleConn(conn) } // ServeListener serves an existing QUIC listener. // Make sure you use http3.ConfigureTLSConfig to configure a tls.Config // and use it to construct a http3-friendly QUIC listener. // Closing the server does not close the listener. It is the application's responsibility to close them. // ServeListener always returns a non-nil error. After Shutdown or Close, the returned error is http.ErrServerClosed. 
func (s *Server) ServeListener(ln QUICListener) error { s.mutex.Lock() if err := s.addListener(&ln, false); err != nil { s.mutex.Unlock() return err } s.mutex.Unlock() defer s.removeListener(&ln) return s.serveListener(ln) } func (s *Server) serveListener(ln QUICListener) error { for { conn, err := ln.Accept(s.graceCtx) // server closed if errors.Is(err, quic.ErrServerClosed) || s.graceCtx.Err() != nil { return http.ErrServerClosed } if err != nil { return err } s.connCount.Add(1) go func() { defer s.decreaseConnCount() if err := s.handleConn(conn); err != nil { if s.Logger != nil { s.Logger.Debug("handling connection failed", "error", err) } } }() } } var errServerWithoutTLSConfig = errors.New("use of http3.Server without TLSConfig") func (s *Server) setupListenerForConn(tlsConf *tls.Config, conn net.PacketConn) (*QUICListener, error) { if tlsConf == nil { return nil, errServerWithoutTLSConfig } baseConf := ConfigureTLSConfig(tlsConf) quicConf := s.QUICConfig if quicConf == nil { quicConf = &quic.Config{Allow0RTT: true} } else { quicConf = s.QUICConfig.Clone() } if s.EnableDatagrams { quicConf.EnableDatagrams = true } s.mutex.Lock() defer s.mutex.Unlock() closed := s.closed if closed { return nil, http.ErrServerClosed } var ln QUICListener var err error if conn == nil { addr := s.Addr if addr == "" { addr = ":https" } ln, err = quic.ListenAddrEarly(addr, baseConf, quicConf) } else { ln, err = quic.ListenEarly(conn, baseConf, quicConf) } if err != nil { return nil, err } if err := s.addListener(&ln, true); err != nil { return nil, err } return &ln, nil } func extractPort(addr string) (int, error) { _, portStr, err := net.SplitHostPort(addr) if err != nil { return 0, err } portInt, err := net.LookupPort("tcp", portStr) if err != nil { return 0, err } return portInt, nil } func (s *Server) generateAltSvcHeader() { if len(s.listeners) == 0 { // Don't announce any ports since no one is listening for connections s.altSvcHeader = "" return } // This code assumes that we 
will use protocol.SupportedVersions if no quic.Config is passed. var altSvc []string addPort := func(port int) { altSvc = append(altSvc, fmt.Sprintf(`%s=":%d"; ma=2592000`, NextProtoH3, port)) } if s.Port != 0 { // if Port is specified, we must use it instead of the // listener addresses since there's a reason it's specified. addPort(s.Port) } else { // if we have some listeners assigned, try to find ports // which we can announce, otherwise nothing should be announced validPortsFound := false for _, info := range s.listeners { if info.port != 0 { addPort(info.port) validPortsFound = true } } if !validPortsFound { if port, err := extractPort(s.Addr); err == nil { addPort(port) } } } s.altSvcHeader = strings.Join(altSvc, ",") } func (s *Server) addListener(l *QUICListener, createdLocally bool) error { if s.closed { return http.ErrServerClosed } s.init() laddr := (*l).Addr() if port, err := extractPort(laddr.String()); err == nil { s.listeners = append(s.listeners, listener{ln: l, port: port, createdLocally: createdLocally}) } else { logger := s.Logger if logger == nil { logger = slog.Default() } logger.Error("Unable to extract port from listener, will not be announced using SetQUICHeaders", "local addr", laddr, "error", err) s.listeners = append(s.listeners, listener{ln: l, port: 0, createdLocally: createdLocally}) } s.generateAltSvcHeader() return nil } func (s *Server) removeListener(l *QUICListener) { s.mutex.Lock() defer s.mutex.Unlock() s.listeners = slices.DeleteFunc(s.listeners, func(info listener) bool { return info.ln == l }) s.generateAltSvcHeader() } func (s *Server) NewRawServerConn(conn *quic.Conn) (*RawServerConn, error) { hconn, _, _, err := s.newRawServerConn(conn) if err != nil { return nil, err } return hconn, nil } func (s *Server) newRawServerConn(conn *quic.Conn) (*RawServerConn, *quic.SendStream, qlogwriter.Recorder, error) { var qlogger qlogwriter.Recorder if qlogTrace := conn.QlogTrace(); qlogTrace != nil && 
qlogTrace.SupportsSchemas(qlog.EventSchema) { qlogger = qlogTrace.AddProducer() } connCtx := conn.Context() connCtx = context.WithValue(connCtx, ServerContextKey, s) connCtx = context.WithValue(connCtx, http.LocalAddrContextKey, conn.LocalAddr()) connCtx = context.WithValue(connCtx, RemoteAddrContextKey, conn.RemoteAddr()) if s.ConnContext != nil { connCtx = s.ConnContext(connCtx, conn) if connCtx == nil { panic("http3: ConnContext returned nil") } } hconn := newRawServerConn( conn, s.EnableDatagrams, s.IdleTimeout, qlogger, s.Logger, connCtx, s.Handler, s.maxHeaderBytes(), ) // open the control stream and send a SETTINGS frame, it's also used to send a GOAWAY frame later // when the server is gracefully closed ctrlStr, err := hconn.openControlStream(&settingsFrame{ MaxFieldSectionSize: int64(s.maxHeaderBytes()), Datagram: s.EnableDatagrams, ExtendedConnect: true, Other: s.AdditionalSettings, }) if err != nil { return nil, nil, nil, fmt.Errorf("opening the control stream failed: %w", err) } return hconn, ctrlStr, qlogger, nil } // handleConn handles the HTTP/3 exchange on a QUIC connection. // It blocks until all HTTP handlers for all streams have returned. func (s *Server) handleConn(conn *quic.Conn) error { hconn, ctrlStr, qlogger, err := s.newRawServerConn(conn) if err != nil { return err } var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() for { str, err := conn.AcceptUniStream(context.Background()) if err != nil { return } go hconn.HandleUnidirectionalStream(str) } }() var nextStreamID quic.StreamID var handleErr error var inGracefulShutdown bool // Process all requests immediately. // It's the client's responsibility to decide which requests are eligible for 0-RTT. ctx := s.graceCtx for { // The context used here is: // * before graceful shutdown: s.graceCtx // * after graceful shutdown: s.closeCtx // This allows us to keep accepting (and resetting) streams after graceful shutdown has started. 
str, err := conn.AcceptStream(ctx) if err != nil { // the underlying connection was closed (by either side) if conn.Context().Err() != nil { var appErr *quic.ApplicationError if !errors.As(err, &appErr) || appErr.ErrorCode != quic.ApplicationErrorCode(ErrCodeNoError) { handleErr = fmt.Errorf("accepting stream failed: %w", err) } break } // server (not gracefully) closed, close the connection immediately if s.closeCtx.Err() != nil { hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "") handleErr = http.ErrServerClosed break } inGracefulShutdown = s.graceCtx.Err() != nil if !inGracefulShutdown { var appErr *quic.ApplicationError if !errors.As(err, &appErr) || appErr.ErrorCode != quic.ApplicationErrorCode(ErrCodeNoError) { handleErr = fmt.Errorf("accepting stream failed: %w", err) } break } // gracefully closed, send GOAWAY frame and wait for requests to complete or grace period to end // new requests will be rejected and shouldn't be sent if qlogger != nil { qlogger.RecordEvent(qlog.FrameCreated{ StreamID: ctrlStr.StreamID(), Frame: qlog.Frame{Frame: qlog.GoAwayFrame{StreamID: nextStreamID}}, }) } wg.Add(1) // Send the GOAWAY frame in a separate Goroutine. // Sending might block if the peer didn't grant enough flow control credit. // Write is guaranteed to return once the connection is closed. go func() { defer wg.Done() _, _ = ctrlStr.Write((&goAwayFrame{StreamID: nextStreamID}).Append(nil)) }() ctx = s.closeCtx continue } if inGracefulShutdown { str.CancelRead(quic.StreamErrorCode(ErrCodeRequestRejected)) str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestRejected)) continue } nextStreamID = str.StreamID() + 4 wg.Add(1) go func() { // HandleRequestStream will return once the request has been handled, // or the underlying connection is closed. 
defer wg.Done() hconn.HandleRequestStream(str) }() } wg.Wait() return handleErr } func (s *Server) maxHeaderBytes() int { if s.MaxHeaderBytes <= 0 { return http.DefaultMaxHeaderBytes } return s.MaxHeaderBytes } // Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients. // Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established. // It is the caller's responsibility to close any connection passed to ServeQUICConn. func (s *Server) Close() error { s.mutex.Lock() defer s.mutex.Unlock() s.closed = true // server is never used if s.closeCtx == nil { return nil } s.closeCancel() var err error for _, l := range s.listeners { if l.createdLocally { if cerr := (*l.ln).Close(); cerr != nil && err == nil { err = cerr } } } if s.connCount.Load() == 0 { return err } // wait for all connections to be closed <-s.connHandlingDone return err } // Shutdown gracefully shuts down the server without interrupting any active connections. // The server sends a GOAWAY frame first, then or for all running requests to complete. // Shutdown in combination with ListenAndServe may race if it is called before a UDP socket is established. // It is recommended to use Serve instead. func (s *Server) Shutdown(ctx context.Context) error { s.mutex.Lock() s.closed = true // server was never used if s.closeCtx == nil { s.mutex.Unlock() return nil } s.graceCancel() // close all listeners var closeErrs []error for _, l := range s.listeners { if l.createdLocally { if err := (*l.ln).Close(); err != nil { closeErrs = append(closeErrs, err) } } } s.mutex.Unlock() if len(closeErrs) > 0 { return errors.Join(closeErrs...) } if s.connCount.Load() == 0 { return s.Close() } select { case <-s.connHandlingDone: // all connections were closed // When receiving a GOAWAY frame, HTTP/3 clients are expected to close the connection // once all requests were successfully handled... 
return s.Close() case <-ctx.Done(): // ... however, clients handling long-lived requests (and misbehaving clients), // might not do so before the context is cancelled. // In this case, we close the server, which closes all existing connections // (expect those passed to ServeQUICConn). _ = s.Close() return ctx.Err() } } // ErrNoAltSvcPort is the error returned by SetQUICHeaders when no port was found // for Alt-Svc to announce. This can happen if listening on a PacketConn without a port // (UNIX socket, for example) and no port is specified in Server.Port or Server.Addr. var ErrNoAltSvcPort = errors.New("no port can be announced, specify it explicitly using Server.Port or Server.Addr") // SetQUICHeaders can be used to set the proper headers that announce that this server supports HTTP/3. // The values set by default advertise all the ports the server is listening on, but can be // changed to a specific port by setting Server.Port before launching the server. // If no listener's Addr().String() returns an address with a valid port, Server.Addr will be used // to extract the port, if specified. // For example, a server launched using ListenAndServe on an address with port 443 would set: // // Alt-Svc: h3=":443"; ma=2592000 func (s *Server) SetQUICHeaders(hdr http.Header) error { s.mutex.RLock() defer s.mutex.RUnlock() if s.altSvcHeader == "" { return ErrNoAltSvcPort } // use the map directly to avoid constant canonicalization since the key is already canonicalized hdr["Alt-Svc"] = append(hdr["Alt-Svc"], s.altSvcHeader) return nil } // ListenAndServeQUIC listens on the UDP network address addr and calls the // handler for HTTP/3 requests on incoming connections. http.DefaultServeMux is // used when handler is nil. 
func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error { server := &Server{ Addr: addr, Handler: handler, } return server.ListenAndServeTLS(certFile, keyFile) } // ListenAndServeTLS listens on the given network address for both TLS/TCP and QUIC // connections in parallel. It returns if one of the two returns an error. // http.DefaultServeMux is used when handler is nil. // The correct Alt-Svc headers for QUIC are set. func ListenAndServeTLS(addr, certFile, keyFile string, handler http.Handler) error { // Load certs var err error certs := make([]tls.Certificate, 1) certs[0], err = tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return err } // We currently only use the cert-related stuff from tls.Config, // so we don't need to make a full copy. config := &tls.Config{ Certificates: certs, } if addr == "" { addr = ":https" } // Open the listeners udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return err } udpConn, err := net.ListenUDP("udp", udpAddr) if err != nil { return err } defer udpConn.Close() if handler == nil { handler = http.DefaultServeMux } // Start the servers quicServer := &Server{ TLSConfig: config, Handler: handler, } hErr := make(chan error, 1) qErr := make(chan error, 1) go func() { hErr <- http.ListenAndServeTLS(addr, certFile, keyFile, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { quicServer.SetQUICHeaders(w.Header()) handler.ServeHTTP(w, r) })) }() go func() { qErr <- quicServer.Serve(udpConn) }() select { case err := <-hErr: quicServer.Close() return err case err := <-qErr: // Cannot close the HTTP server or wait for requests to complete properly :/ return err } } quic-go-0.59.0/http3/server_conn.go000066400000000000000000000166551513066070600171060ustar00rootroot00000000000000package http3 import ( "context" "errors" "io" "log/slog" "net/http" "runtime" "strconv" "time" "github.com/quic-go/qpack" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/qlogwriter" ) // 
RawServerConn is an HTTP/3 server connection. // It can be used for advanced use cases where the application wants to manage the QUIC connection lifecycle. type RawServerConn struct { rawConn rawConn idleTimeout time.Duration idleTimer *time.Timer serverContext context.Context requestHandler http.Handler maxHeaderBytes int decoder *qpack.Decoder qlogger qlogwriter.Recorder logger *slog.Logger } func newRawServerConn( conn *quic.Conn, enableDatagrams bool, idleTimeout time.Duration, qlogger qlogwriter.Recorder, logger *slog.Logger, serverContext context.Context, requestHandler http.Handler, maxHeaderBytes int, ) *RawServerConn { c := &RawServerConn{ idleTimeout: idleTimeout, serverContext: serverContext, requestHandler: requestHandler, maxHeaderBytes: maxHeaderBytes, decoder: qpack.NewDecoder(), qlogger: qlogger, logger: logger, } c.rawConn = *newRawConn(conn, enableDatagrams, c.onStreamsEmpty, nil, qlogger, logger) if idleTimeout > 0 { c.idleTimer = time.AfterFunc(idleTimeout, c.onIdleTimer) } return c } func (c *RawServerConn) onStreamsEmpty() { if c.idleTimeout > 0 { c.idleTimer.Reset(c.idleTimeout) } } func (c *RawServerConn) onIdleTimer() { c.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "idle timeout") } // CloseWithError closes the connection with the given error code and message. func (c *RawServerConn) CloseWithError(code quic.ApplicationErrorCode, msg string) error { if c.idleTimer != nil { c.idleTimer.Stop() } return c.rawConn.CloseWithError(code, msg) } // HandleRequestStream handles an HTTP/3 request on a bidirectional request stream. // The stream can either be obtained by calling AcceptStream on the underlying QUIC connection, // or (internally) by using the server's stream accept loop. 
func (c *RawServerConn) HandleRequestStream(str *quic.Stream) {
	// TrackStream yields the *stateTrackingStream that the request handling code operates on.
	hstr := c.rawConn.TrackStream(str)
	c.handleRequestStream(hstr)
}

// requestMaxHeaderBytes returns the configured header size limit, or
// http.DefaultMaxHeaderBytes if the configured value is zero or negative.
func (c *RawServerConn) requestMaxHeaderBytes() int {
	if c.maxHeaderBytes <= 0 {
		return http.DefaultMaxHeaderBytes
	}
	return c.maxHeaderBytes
}

// openControlStream delegates to the underlying rawConn.
func (c *RawServerConn) openControlStream(settings *settingsFrame) (*quic.SendStream, error) {
	return c.rawConn.openControlStream(settings)
}

// handleRequestStream reads the request HEADERS frame from str, builds an *http.Request
// from it and runs the request handler (http.DefaultServeMux if none is configured).
//
// Error handling:
//   - a first frame that is not HEADERS closes the whole connection (ErrCodeFrameUnexpected);
//   - header blocks exceeding the size limit are answered with a 431 response;
//   - other parse/decode failures reset both directions of the stream;
//   - a panicking handler resets the stream with ErrCodeInternalError.
//
// If the handler hijacked the stream, the stream is left untouched afterwards.
func (c *RawServerConn) handleRequestStream(str *stateTrackingStream) {
	if c.idleTimeout > 0 {
		// This only applies if the stream is the first active stream,
		// but it's ok to stop a stopped timer.
		c.idleTimer.Stop()
	}

	conn := &c.rawConn
	qlogger := c.qlogger
	decoder := c.decoder
	connCtx := c.serverContext
	maxHeaderBytes := c.requestMaxHeaderBytes()

	fp := &frameParser{closeConn: conn.CloseWithError, r: str, streamID: str.StreamID()}
	frame, err := fp.ParseNext(qlogger)
	if err != nil {
		str.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
		str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
		return
	}
	hf, ok := frame.(*headersFrame)
	if !ok {
		conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "expected first frame to be a HEADERS frame")
		return
	}
	// Check the declared length before allocating headerBlock below, so the peer
	// cannot make us allocate a buffer larger than the limit.
	if hf.Length > uint64(maxHeaderBytes) {
		maybeQlogInvalidHeadersFrame(qlogger, str.StreamID(), hf.Length)
		// stop the client from sending more data
		str.CancelRead(quic.StreamErrorCode(ErrCodeExcessiveLoad))
		// send a 431 Response (Request Header Fields Too Large)
		c.rejectWithHeaderFieldsTooLarge(str)
		return
	}
	headerBlock := make([]byte, hf.Length)
	if _, err := io.ReadFull(str, headerBlock); err != nil {
		maybeQlogInvalidHeadersFrame(qlogger, str.StreamID(), hf.Length)
		str.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
		str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
		return
	}
	decodeFn := decoder.Decode(headerBlock)
	// the decoded header fields are only collected when qlog is enabled
	var hfs []qpack.HeaderField
	if qlogger != nil {
		hfs = make([]qpack.HeaderField, 0, 16)
	}
	req, err := requestFromHeaders(decodeFn, maxHeaderBytes, &hfs)
	if qlogger != nil {
		qlogParsedHeadersFrame(qlogger, str.StreamID(), hf, hfs)
	}
	if err != nil {
		if errors.Is(err, errHeaderTooLarge) {
			// stop the client from sending more data
			str.CancelRead(quic.StreamErrorCode(ErrCodeExcessiveLoad))
			// send a 431 Response (Request Header Fields Too Large)
			c.rejectWithHeaderFieldsTooLarge(str)
			return
		}

		errCode := ErrCodeMessageError
		var qpackErr *qpackError
		if errors.As(err, &qpackErr) {
			errCode = ErrCodeQPACKDecompressionFailed
		}
		str.CancelRead(quic.StreamErrorCode(errCode))
		str.CancelWrite(quic.StreamErrorCode(errCode))
		return
	}

	connState := conn.ConnectionState().TLS
	req.TLS = &connState
	req.RemoteAddr = conn.RemoteAddr().String()

	// Check that the client doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
	// See section 4.1.2 of RFC 9114.
	contentLength := int64(-1)
	if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 {
		contentLength = req.ContentLength
	}
	// the callback decodes a trailing HEADERS frame into req.Trailer
	hstr := newStream(str, conn, nil, func(r io.Reader, hf *headersFrame) error {
		trailers, err := decodeTrailers(r, hf, maxHeaderBytes, decoder, qlogger, str.StreamID())
		if err != nil {
			return err
		}
		req.Trailer = trailers
		return nil
	}, qlogger)
	body := newRequestBody(hstr, contentLength, connCtx, conn.ReceivedSettings(), conn.Settings)
	req.Body = body

	if c.logger != nil {
		c.logger.Debug("handling request", "method", req.Method, "host", req.Host, "uri", req.RequestURI)
	}

	// The request context is derived from the server context and is cancelled as
	// soon as the stream's context is done.
	ctx, cancel := context.WithCancel(connCtx)
	req = req.WithContext(ctx)
	context.AfterFunc(str.Context(), cancel)

	r := newResponseWriter(hstr, conn, req.Method == http.MethodHead, c.logger)
	handler := c.requestHandler
	if handler == nil {
		handler = http.DefaultServeMux
	}

	// It's the client's responsibility to decide which requests are eligible for 0-RTT.
	var panicked bool
	func() {
		defer func() {
			if p := recover(); p != nil {
				panicked = true
				// http.ErrAbortHandler aborts the request without logging
				if p == http.ErrAbortHandler {
					return
				}
				// Copied from net/http/server.go
				const size = 64 << 10
				buf := make([]byte, size)
				buf = buf[:runtime.Stack(buf, false)]
				logger := c.logger
				if logger == nil {
					logger = slog.Default()
				}
				logger.Error("http3: panic serving", "arg", p, "trace", string(buf))
			}
		}()
		handler.ServeHTTP(r, req)
	}()

	// a hijacked stream is owned by the handler; don't touch it
	if r.wasStreamHijacked() {
		return
	}

	// abort the stream when there is a panic
	if panicked {
		str.CancelRead(quic.StreamErrorCode(ErrCodeInternalError))
		str.CancelWrite(quic.StreamErrorCode(ErrCodeInternalError))
		return
	}

	// response not written to the client yet, set Content-Length
	if !r.headerWritten {
		if _, haveCL := r.header["Content-Length"]; !haveCL {
			r.header.Set("Content-Length", strconv.FormatInt(r.numWritten, 10))
		}
	}
	r.Flush()
	r.flushTrailers()

	// If the EOF was read by the handler, CancelRead() is a no-op.
	str.CancelRead(quic.StreamErrorCode(ErrCodeNoError))
	str.Close()
}

// rejectWithHeaderFieldsTooLarge writes a 431 (Request Header Fields Too Large)
// response on str and closes the stream's send side via the deferred Close.
func (c *RawServerConn) rejectWithHeaderFieldsTooLarge(str *stateTrackingStream) {
	hstr := newStream(str, &c.rawConn, nil, nil, c.qlogger)
	defer hstr.Close()
	r := newResponseWriter(hstr, &c.rawConn, false, c.logger)
	r.WriteHeader(http.StatusRequestHeaderFieldsTooLarge)
	r.Flush()
}

// HandleUnidirectionalStream handles an incoming unidirectional stream.
func (c *RawServerConn) HandleUnidirectionalStream(str *quic.ReceiveStream) { c.rawConn.handleUnidirectionalStream(str, true) } quic-go-0.59.0/http3/server_test.go000066400000000000000000000553271513066070600171270ustar00rootroot00000000000000package http3 import ( "bytes" "context" "crypto/tls" "fmt" "io" "log/slog" "net" "net/http" "net/http/httptest" "runtime" "testing" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3/internal/testdata" "github.com/quic-go/quic-go/http3/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/quic-go/quicvarint" "github.com/quic-go/quic-go/testutils/events" "github.com/stretchr/testify/require" ) func TestConfigureTLSConfig(t *testing.T) { t.Run("basic config", func(t *testing.T) { conf := ConfigureTLSConfig(&tls.Config{}) require.Equal(t, conf.NextProtos, []string{NextProtoH3}) }) t.Run("ALPN set", func(t *testing.T) { conf := ConfigureTLSConfig(&tls.Config{NextProtos: []string{"foo", "bar"}}) require.Equal(t, []string{NextProtoH3}, conf.NextProtos) }) // for configs that define GetConfigForClient, the ALPN is set to h3 t.Run("GetConfigForClient", func(t *testing.T) { staticConf := &tls.Config{NextProtos: []string{"foo", "bar"}} conf := ConfigureTLSConfig(&tls.Config{ GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { return staticConf, nil }, }) innerConf, err := conf.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "example.com"}) require.NoError(t, err) require.NotNil(t, innerConf) require.Equal(t, []string{NextProtoH3}, innerConf.NextProtos) // make sure the original config was not modified require.Equal(t, []string{"foo", "bar"}, staticConf.NextProtos) }) // GetConfigForClient might return a nil tls.Config t.Run("GetConfigForClient returns nil", func(t *testing.T) { conf := ConfigureTLSConfig(&tls.Config{ GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil }, }) innerConf, err := conf.GetConfigForClient(&tls.ClientHelloInfo{ServerName: 
"example.com"}) require.NoError(t, err) require.Nil(t, innerConf) }) } func TestServerSettings(t *testing.T) { t.Run("enable datagrams", func(t *testing.T) { testServerSettings(t, true, nil) }) t.Run("additional settings", func(t *testing.T) { testServerSettings(t, false, map[uint64]uint64{13: 37}) }) } func testServerSettings(t *testing.T, enableDatagrams bool, other map[uint64]uint64) { s := Server{ EnableDatagrams: enableDatagrams, AdditionalSettings: other, } s.init() testDone := make(chan struct{}) defer close(testDone) clientConn, serverConn := newConnPair(t) go s.handleConn(serverConn) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() settingsStr, err := clientConn.AcceptUniStream(ctx) require.NoError(t, err) settingsStr.SetReadDeadline(time.Now().Add(time.Second)) b := make([]byte, 1024) n, err := settingsStr.Read(b) require.NoError(t, err) b = b[:n] typ, l, err := quicvarint.Parse(b) require.NoError(t, err) require.EqualValues(t, streamTypeControlStream, typ) fp := (&frameParser{r: bytes.NewReader(b[l:])}) f, err := fp.ParseNext(nil) require.NoError(t, err) require.IsType(t, &settingsFrame{}, f) settingsFrame := f.(*settingsFrame) // Extended CONNECT is always supported require.True(t, settingsFrame.ExtendedConnect) require.Equal(t, settingsFrame.Datagram, enableDatagrams) require.Equal(t, settingsFrame.Other, other) } func TestServerRequestHandling(t *testing.T) { t.Run("200 with an empty handler", func(t *testing.T) { var eventRecorder events.Recorder hfs, body := testServerRequestHandling(t, func(w http.ResponseWriter, r *http.Request) {}, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil), &eventRecorder, ) require.Equal(t, hfs[":status"], []string{"200"}) require.Empty(t, body) require.Len(t, eventRecorder.Events(qlog.FrameParsed{}), 1) require.IsType(t, qlog.HeadersFrame{}, eventRecorder.Events(qlog.FrameParsed{})[0].(qlog.FrameParsed).Frame.Frame) fp := 
eventRecorder.Events(qlog.FrameParsed{})[0].(qlog.FrameParsed) require.Equal(t, quic.StreamID(0), fp.StreamID) require.NotZero(t, fp.Raw.PayloadLength) require.Contains(t, fp.Frame.Frame.(qlog.HeadersFrame).HeaderFields, qlog.HeaderField{Name: ":method", Value: "GET"}) require.Contains(t, fp.Frame.Frame.(qlog.HeadersFrame).HeaderFields, qlog.HeaderField{Name: ":authority", Value: "www.example.com"}) events := filterQlogEventsForFrame(eventRecorder.Events(qlog.FrameCreated{}), qlog.HeadersFrame{}) require.Len(t, events, 1) fc := events[0].(qlog.FrameCreated) require.Equal(t, quic.StreamID(0), fp.StreamID) require.NotZero(t, fc.Raw.PayloadLength) require.Contains(t, fc.Frame.Frame.(qlog.HeadersFrame).HeaderFields, qlog.HeaderField{Name: ":status", Value: "200"}) }) t.Run("content-length", func(t *testing.T) { hfs, body := testServerRequestHandling(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusTeapot) w.Write([]byte("foobar")) }, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil), nil, ) require.Equal(t, hfs[":status"], []string{"418"}) require.Equal(t, hfs["content-length"], []string{"6"}) require.Equal(t, body, []byte("foobar")) }) t.Run("no content-length when flushed", func(t *testing.T) { hfs, body := testServerRequestHandling(t, func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("foo")) w.(http.Flusher).Flush() w.Write([]byte("bar")) }, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil), nil, ) require.Equal(t, hfs[":status"], []string{"200"}) require.NotContains(t, hfs, "content-length") require.Equal(t, body, []byte("foobar")) }) t.Run("HEAD request", func(t *testing.T) { hfs, body := testServerRequestHandling(t, func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("foobar")) }, httptest.NewRequest(http.MethodHead, "https://www.example.com", nil), nil, ) require.Equal(t, hfs[":status"], []string{"200"}) require.Empty(t, body) }) t.Run("POST request", func(t *testing.T) { hfs, 
body := testServerRequestHandling(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusTeapot) data, _ := io.ReadAll(r.Body) w.Write(data) }, httptest.NewRequest(http.MethodPost, "https://www.example.com", bytes.NewBuffer([]byte("foobar"))), nil, ) require.Equal(t, hfs[":status"], []string{"418"}) require.Equal(t, []byte("foobar"), body) }) } func testServerRequestHandling(t *testing.T, handler http.HandlerFunc, req *http.Request, rec qlogwriter.Recorder, ) (responseHeaders map[string][]string, body []byte) { clientConn, serverConn := newConnPair(t, withServerRecorder(rec)) str, err := clientConn.OpenStream() require.NoError(t, err) _, err = str.Write(encodeRequest(t, req)) require.NoError(t, err) require.NoError(t, str.Close()) s := &Server{Handler: handler} go s.ServeQUICConn(serverConn) hfs := decodeHeader(t, str) fp := frameParser{r: str} var content []byte for { frame, err := fp.ParseNext(nil) if err == io.EOF { break } require.NoError(t, err) require.IsType(t, &dataFrame{}, frame) b := make([]byte, frame.(*dataFrame).Length) _, err = io.ReadFull(str, b) require.NoError(t, err) content = append(content, b...) 
} return hfs, content } func TestServerFirstFrameNotHeaders(t *testing.T) { clientConn, serverConn := newConnPair(t) str, err := clientConn.OpenStream() require.NoError(t, err) var buf bytes.Buffer buf.Write((&dataFrame{Length: 6}).Append(nil)) buf.Write([]byte("foobar")) _, err = str.Write(buf.Bytes()) require.NoError(t, err) require.NoError(t, str.Close()) s := &Server{} go s.ServeQUICConn(serverConn) select { case <-clientConn.Context().Done(): err := context.Cause(clientConn.Context()) var appErr *quic.ApplicationError require.ErrorAs(t, err, &appErr) require.Equal(t, quic.ApplicationErrorCode(ErrCodeFrameUnexpected), appErr.ErrorCode) case <-time.After(time.Second): t.Fatal("timeout") } } func TestServerHandlerBodyNotRead(t *testing.T) { t.Run("GET request with a body", func(t *testing.T) { testServerHandlerBodyNotRead(t, httptest.NewRequest(http.MethodGet, "https://www.example.com", bytes.NewBuffer([]byte("foobar"))), func(w http.ResponseWriter, r *http.Request) {}, ) }) t.Run("POST body not read", func(t *testing.T) { testServerHandlerBodyNotRead(t, httptest.NewRequest(http.MethodPost, "https://www.example.com", bytes.NewBuffer([]byte("foobar"))), func(w http.ResponseWriter, r *http.Request) {}, ) }) t.Run("POST request, with a replaced body", func(t *testing.T) { testServerHandlerBodyNotRead(t, httptest.NewRequest(http.MethodPost, "https://www.example.com", bytes.NewBuffer([]byte("foobar"))), func(w http.ResponseWriter, r *http.Request) { r.Body = struct { io.Reader io.Closer }{} }, ) }) } func testServerHandlerBodyNotRead(t *testing.T, req *http.Request, handler http.HandlerFunc) { clientConn, serverConn := newConnPair(t) str, err := clientConn.OpenStream() require.NoError(t, err) _, err = str.Write(encodeRequest(t, req)) require.NoError(t, err) done := make(chan struct{}) s := &Server{ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer close(done) handler(w, r) }), } go s.ServeQUICConn(serverConn) select { case <-done: case 
<-time.After(time.Second): t.Fatal("timeout") } } func TestServerStreamResetByClient(t *testing.T) { clientConn, serverConn := newConnPair(t) str, err := clientConn.OpenStream() require.NoError(t, err) str.CancelWrite(1337) var called bool s := &Server{ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true }), } go s.ServeQUICConn(serverConn) expectStreamReadReset(t, str, quic.StreamErrorCode(ErrCodeRequestIncomplete)) require.False(t, called) } func TestServerPanickingHandler(t *testing.T) { t.Run("panicking handler", func(t *testing.T) { logOutput := testServerPanickingHandler(t, func(w http.ResponseWriter, r *http.Request) { panic("foobar") }) require.Contains(t, logOutput, "http3: panic serving") require.Contains(t, logOutput, "foobar") }) t.Run("http.ErrAbortHandler", func(t *testing.T) { logOutput := testServerPanickingHandler(t, func(w http.ResponseWriter, r *http.Request) { panic(http.ErrAbortHandler) }) require.NotContains(t, logOutput, "http3: panic serving") require.NotContains(t, logOutput, "http.ErrAbortHandler") }) } func testServerPanickingHandler(t *testing.T, handler http.HandlerFunc) (logOutput string) { clientConn, serverConn := newConnPair(t) str, err := clientConn.OpenStream() require.NoError(t, err) _, err = str.Write(encodeRequest(t, httptest.NewRequest(http.MethodHead, "https://www.example.com", nil))) require.NoError(t, err) require.NoError(t, str.Close()) var logBuf bytes.Buffer s := &Server{ Handler: handler, Logger: slog.New(slog.NewTextHandler(&logBuf, nil)), } go s.ServeQUICConn(serverConn) expectStreamReadReset(t, str, quic.StreamErrorCode(ErrCodeInternalError)) s.Close() return logBuf.String() } func TestServerRequestHeaderTooLarge(t *testing.T) { t.Run("default value", func(t *testing.T) { var eventRecorder events.Recorder // use 2*DefaultMaxHeaderBytes here. qpack will compress the request, // but the request will still end up larger than DefaultMaxHeaderBytes. 
url := bytes.Repeat([]byte{'a'}, http.DefaultMaxHeaderBytes*2) testServerRequestHeaderTooLarge(t, httptest.NewRequest(http.MethodGet, "https://"+string(url), nil), 0, &eventRecorder, ) events := eventRecorder.Events(qlog.FrameParsed{}) require.Len(t, events, 1) require.Equal(t, qlog.HeadersFrame{}, events[0].(qlog.FrameParsed).Frame.Frame) // The request is QPACK-compressed, so it will be smaller than 2*http.DefaultMaxHeaderBytes require.Greater(t, events[0].(qlog.FrameParsed).Raw.PayloadLength, http.DefaultMaxHeaderBytes) require.Less(t, events[0].(qlog.FrameParsed).Raw.PayloadLength, http.DefaultMaxHeaderBytes*2) }) t.Run("custom value", func(t *testing.T) { var eventRecorder events.Recorder testServerRequestHeaderTooLarge(t, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil), 20, &eventRecorder, ) events := eventRecorder.Events(qlog.FrameParsed{}) require.Len(t, events, 1) require.Equal(t, qlog.HeadersFrame{}, events[0].(qlog.FrameParsed).Frame.Frame) require.Greater(t, events[0].(qlog.FrameParsed).Raw.PayloadLength, 20) require.Less(t, events[0].(qlog.FrameParsed).Raw.PayloadLength, 40) }) } func testServerRequestHeaderTooLarge(t *testing.T, req *http.Request, maxHeaderBytes int, rec qlogwriter.Recorder) { var called bool s := &Server{ MaxHeaderBytes: maxHeaderBytes, Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true }), } s.init() clientConn, serverConn := newConnPair(t, withServerRecorder(rec)) str, err := clientConn.OpenStream() require.NoError(t, err) _, err = str.Write(encodeRequest(t, req)) require.NoError(t, err) go s.ServeQUICConn(serverConn) hfs := decodeHeader(t, str) require.Equal(t, []string{"431"}, hfs[":status"]) expectStreamWriteReset(t, str, quic.StreamErrorCode(ErrCodeExcessiveLoad)) require.False(t, called) } func TestServerRequestContext(t *testing.T) { clientConn, serverConn := newConnPair(t) str, err := clientConn.OpenStream() require.NoError(t, err) _, err = str.Write(encodeRequest(t, 
httptest.NewRequest(http.MethodHead, "https://www.example.com", nil))) require.NoError(t, err) ctxChan := make(chan context.Context, 1) block := make(chan struct{}) s := &Server{ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctxChan <- r.Context() <-block }), } go s.ServeQUICConn(serverConn) var requestContext context.Context select { case requestContext = <-ctxChan: case <-time.After(time.Second): t.Fatal("timeout") } require.Equal(t, s, requestContext.Value(ServerContextKey)) require.Equal(t, serverConn.LocalAddr(), requestContext.Value(http.LocalAddrContextKey)) require.Equal(t, serverConn.RemoteAddr(), requestContext.Value(RemoteAddrContextKey)) select { case <-requestContext.Done(): t.Fatal("request context was canceled") case <-time.After(scaleDuration(10 * time.Millisecond)): } str.CancelRead(1337) select { case <-requestContext.Done(): case <-time.After(time.Second): t.Fatal("timeout") } require.Equal(t, context.Canceled, requestContext.Err()) close(block) } func TestServerHTTPStreamHijacking(t *testing.T) { clientConn, serverConn := newConnPair(t) str, err := clientConn.OpenStream() require.NoError(t, err) _, err = str.Write(encodeRequest(t, httptest.NewRequest(http.MethodHead, "https://www.example.com", nil))) require.NoError(t, err) require.NoError(t, str.Close()) s := &Server{ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { str := w.(HTTPStreamer).HTTPStream() str.Write([]byte("foobar")) str.Close() }), } go s.ServeQUICConn(serverConn) str.SetReadDeadline(time.Now().Add(time.Second)) rsp, err := io.ReadAll(str) require.NoError(t, err) r := bytes.NewReader(rsp) hfs := decodeHeader(t, r) require.Equal(t, hfs[":status"], []string{"200"}) fp := frameParser{r: r} frame, err := fp.ParseNext(nil) require.NoError(t, err) require.IsType(t, &dataFrame{}, frame) dataFrame := frame.(*dataFrame) require.Equal(t, uint64(6), dataFrame.Length) data, err := io.ReadAll(r) require.NoError(t, err) require.Equal(t, 
[]byte("foobar"), data) } func getAltSvc(s *Server) (string, bool) { hdr := http.Header{} s.SetQUICHeaders(hdr) if altSvc, ok := hdr["Alt-Svc"]; ok { return altSvc[0], true } return "", false } func TestServerAltSvcFromListenersAndConns(t *testing.T) { t.Run("default", func(t *testing.T) { testServerAltSvcFromListenersAndConns(t, []quic.Version{}) }) t.Run("v1", func(t *testing.T) { testServerAltSvcFromListenersAndConns(t, []quic.Version{quic.Version1}) }) t.Run("v1 and v2", func(t *testing.T) { testServerAltSvcFromListenersAndConns(t, []quic.Version{quic.Version1, quic.Version2}) }) } func testServerAltSvcFromListenersAndConns(t *testing.T, versions []quic.Version) { ln1, err := quic.ListenEarly(newUDPConnLocalhost(t), getTLSConfig(), nil) require.NoError(t, err) port1 := ln1.Addr().(*net.UDPAddr).Port s := &Server{ Addr: ":1337", // will be ignored since we're using listeners TLSConfig: getTLSConfig(), QUICConfig: &quic.Config{Versions: versions}, } done1 := make(chan struct{}) go func() { defer close(done1) s.ServeListener(ln1) }() time.Sleep(scaleDuration(10 * time.Millisecond)) altSvc, ok := getAltSvc(s) require.True(t, ok) require.Equal(t, fmt.Sprintf(`h3=":%d"; ma=2592000`, port1), altSvc) udpConn := newUDPConnLocalhost(t) port2 := udpConn.LocalAddr().(*net.UDPAddr).Port done2 := make(chan struct{}) go func() { defer close(done2) s.Serve(udpConn) }() time.Sleep(scaleDuration(10 * time.Millisecond)) altSvc, ok = getAltSvc(s) require.True(t, ok) require.Equal(t, fmt.Sprintf(`h3=":%d"; ma=2592000,h3=":%d"; ma=2592000`, port1, port2), altSvc) // Close the first listener. // This should remove the associated Alt-Svc entry. require.NoError(t, ln1.Close()) select { case <-done1: case <-time.After(time.Second): t.Fatal("timeout") } altSvc, ok = getAltSvc(s) require.True(t, ok) require.Equal(t, fmt.Sprintf(`h3=":%d"; ma=2592000`, port2), altSvc) // Close the second listener. // This should remove the Alt-Svc entry altogether. 
require.NoError(t, udpConn.Close()) select { case <-done2: case <-time.After(time.Second): t.Fatal("timeout") } _, ok = getAltSvc(s) require.False(t, ok) } func TestServerAltSvcFromPort(t *testing.T) { s := &Server{Port: 1337} _, ok := getAltSvc(s) require.False(t, ok) ln, err := quic.ListenEarly(newUDPConnLocalhost(t), getTLSConfig(), nil) require.NoError(t, err) done := make(chan struct{}) go func() { defer close(done) s.ServeListener(ln) }() time.Sleep(scaleDuration(10 * time.Millisecond)) altSvc, ok := getAltSvc(s) require.True(t, ok) require.Equal(t, `h3=":1337"; ma=2592000`, altSvc) require.NoError(t, ln.Close()) select { case <-done: case <-time.After(time.Second): t.Fatal("timeout") } _, ok = getAltSvc(s) require.False(t, ok) } type unixSocketListener struct { *quic.EarlyListener } func (l *unixSocketListener) Addr() net.Addr { return &net.UnixAddr{Net: "unix", Name: "/tmp/quic.sock"} } func TestServerAltSvcFromUnixSocket(t *testing.T) { t.Run("with Server.Addr not set", func(t *testing.T) { _, ok := testServerAltSvcFromUnixSocket(t, "") require.False(t, ok) }) t.Run("with Server.Addr set", func(t *testing.T) { altSvc, ok := testServerAltSvcFromUnixSocket(t, ":1337") require.True(t, ok) require.Equal(t, `h3=":1337"; ma=2592000`, altSvc) }) } func testServerAltSvcFromUnixSocket(t *testing.T, addr string) (altSvc string, ok bool) { ln, err := quic.ListenEarly(newUDPConnLocalhost(t), testdata.GetTLSConfig(), nil) require.NoError(t, err) var logBuf bytes.Buffer s := &Server{ Addr: addr, Logger: slog.New(slog.NewTextHandler(&logBuf, nil)), } done := make(chan struct{}) go func() { defer close(done) s.ServeListener(&unixSocketListener{EarlyListener: ln}) }() time.Sleep(scaleDuration(10 * time.Millisecond)) altSvc, ok = getAltSvc(s) require.NoError(t, ln.Close()) select { case <-done: case <-time.After(time.Second): t.Fatal("timeout") } require.Contains(t, logBuf.String(), "Unable to extract port from listener, will not be announced using SetQUICHeaders") return 
altSvc, ok } func TestServerListenAndServeErrors(t *testing.T) { require.EqualError(t, (&Server{}).ListenAndServe(), "use of http3.Server without TLSConfig") s := &Server{ Addr: ":123456", TLSConfig: testdata.GetTLSConfig(), } require.ErrorContains(t, s.ListenAndServe(), "invalid port") } func TestServerClosing(t *testing.T) { s := &Server{TLSConfig: getTLSConfig()} require.NoError(t, s.Close()) require.NoError(t, s.Close()) // duplicate calls are ok require.ErrorIs(t, s.ListenAndServe(), http.ErrServerClosed) require.ErrorIs(t, s.ListenAndServeTLS(testdata.GetCertificatePaths()), http.ErrServerClosed) require.ErrorIs(t, s.Serve(nil), http.ErrServerClosed) require.ErrorIs(t, s.ServeListener(nil), http.ErrServerClosed) require.ErrorIs(t, s.ServeQUICConn(nil), http.ErrServerClosed) } func TestServerConcurrentServeAndClose(t *testing.T) { addr, err := net.ResolveUDPAddr("udp", "localhost:0") require.NoError(t, err) c, err := net.ListenUDP("udp", addr) require.NoError(t, err) done := make(chan struct{}) s := &Server{TLSConfig: testdata.GetTLSConfig()} go func() { defer close(done) s.Serve(c) }() runtime.Gosched() s.Close() select { case <-done: case <-time.After(time.Second): t.Fatal("timeout") } } func TestServerImmediateGracefulShutdown(t *testing.T) { s := &Server{TLSConfig: testdata.GetTLSConfig()} errChan := make(chan error, 1) go func() { errChan <- s.Shutdown(context.Background()) }() select { case err := <-errChan: require.NoError(t, err) case <-time.After(time.Second): t.Fatal("timeout") } } func TestServerGracefulShutdown(t *testing.T) { requestChan := make(chan struct{}, 1) s := &Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requestChan <- struct{}{} })} clientConn, serverConn := newConnPair(t) go s.ServeQUICConn(serverConn) firstStream, err := clientConn.OpenStream() require.NoError(t, err) _, err = firstStream.Write(encodeRequest(t, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil))) require.NoError(t, 
err) select { case <-requestChan: case <-time.After(time.Second): t.Fatal("timeout") } ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() controlStr, err := clientConn.AcceptUniStream(ctx) require.NoError(t, err) typ, err := quicvarint.Read(quicvarint.NewReader(controlStr)) require.NoError(t, err) require.EqualValues(t, streamTypeControlStream, typ) fp := &frameParser{r: controlStr} f, err := fp.ParseNext(nil) require.NoError(t, err) require.IsType(t, &settingsFrame{}, f) shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) errChan := make(chan error) go func() { errChan <- s.Shutdown(shutdownCtx) }() f, err = fp.ParseNext(nil) require.NoError(t, err) require.Equal(t, &goAwayFrame{StreamID: 4}, f) select { case <-errChan: t.Fatal("didn't expect Shutdown to return") case <-time.After(scaleDuration(10 * time.Millisecond)): } // all further streams are getting rejected for range 3 { str, err := clientConn.OpenStream() require.NoError(t, err) _, _ = str.Write(encodeRequest(t, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil))) expectStreamReadReset(t, str, quic.StreamErrorCode(ErrCodeRequestRejected)) expectStreamWriteReset(t, str, quic.StreamErrorCode(ErrCodeRequestRejected)) } // cancel the context passed to Shutdown shutdownCancel() select { case err := <-errChan: require.ErrorIs(t, err, context.Canceled) case <-time.After(time.Second): t.Fatal("timeout") } } quic-go-0.59.0/http3/state_tracking_stream.go000066400000000000000000000072451513066070600211330ustar00rootroot00000000000000package http3 import ( "context" "errors" "os" "sync" "github.com/quic-go/quic-go" ) const streamDatagramQueueLen = 32 // stateTrackingStream is an implementation of quic.Stream that delegates // to an underlying stream // it takes care of proxying send and receive errors onto an implementation of // the errorSetter interface (intended to be occupied by a datagrammer) // it is also responsible for clearing the stream 
based on its ID from its // parent connection, this is done through the streamClearer interface when // both the send and receive sides are closed type stateTrackingStream struct { *quic.Stream sendDatagram func([]byte) error hasData chan struct{} queue [][]byte // TODO: use a ring buffer mx sync.Mutex sendErr error recvErr error clearer streamClearer } var _ datagramStream = &stateTrackingStream{} type streamClearer interface { clearStream(quic.StreamID) } func newStateTrackingStream(s *quic.Stream, clearer streamClearer, sendDatagram func([]byte) error) *stateTrackingStream { t := &stateTrackingStream{ Stream: s, clearer: clearer, sendDatagram: sendDatagram, hasData: make(chan struct{}, 1), } context.AfterFunc(s.Context(), func() { t.closeSend(context.Cause(s.Context())) }) return t } func (s *stateTrackingStream) closeSend(e error) { s.mx.Lock() defer s.mx.Unlock() // clear the stream the first time both the send // and receive are finished if s.sendErr == nil { if s.recvErr != nil { s.clearer.clearStream(s.StreamID()) } s.sendErr = e } } func (s *stateTrackingStream) closeReceive(e error) { s.mx.Lock() defer s.mx.Unlock() // clear the stream the first time both the send // and receive are finished if s.recvErr == nil { if s.sendErr != nil { s.clearer.clearStream(s.StreamID()) } s.recvErr = e s.signalHasDatagram() } } func (s *stateTrackingStream) Close() error { s.closeSend(errors.New("write on closed stream")) return s.Stream.Close() } func (s *stateTrackingStream) CancelWrite(e quic.StreamErrorCode) { s.closeSend(&quic.StreamError{StreamID: s.StreamID(), ErrorCode: e}) s.Stream.CancelWrite(e) } func (s *stateTrackingStream) Write(b []byte) (int, error) { n, err := s.Stream.Write(b) if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) { s.closeSend(err) } return n, err } func (s *stateTrackingStream) CancelRead(e quic.StreamErrorCode) { s.closeReceive(&quic.StreamError{StreamID: s.StreamID(), ErrorCode: e}) s.Stream.CancelRead(e) } func (s 
*stateTrackingStream) Read(b []byte) (int, error) { n, err := s.Stream.Read(b) if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) { s.closeReceive(err) } return n, err } func (s *stateTrackingStream) SendDatagram(b []byte) error { s.mx.Lock() sendErr := s.sendErr s.mx.Unlock() if sendErr != nil { return sendErr } return s.sendDatagram(b) } func (s *stateTrackingStream) signalHasDatagram() { select { case s.hasData <- struct{}{}: default: } } func (s *stateTrackingStream) enqueueDatagram(data []byte) { s.mx.Lock() defer s.mx.Unlock() if s.recvErr != nil { return } if len(s.queue) >= streamDatagramQueueLen { return } s.queue = append(s.queue, data) s.signalHasDatagram() } func (s *stateTrackingStream) ReceiveDatagram(ctx context.Context) ([]byte, error) { start: s.mx.Lock() if len(s.queue) > 0 { data := s.queue[0] s.queue = s.queue[1:] s.mx.Unlock() return data, nil } if receiveErr := s.recvErr; receiveErr != nil { s.mx.Unlock() return nil, receiveErr } s.mx.Unlock() select { case <-ctx.Done(): return nil, context.Cause(ctx) case <-s.hasData: } goto start } func (s *stateTrackingStream) QUICStream() *quic.Stream { return s.Stream } quic-go-0.59.0/http3/state_tracking_stream_test.go000066400000000000000000000212471513066070600221700ustar00rootroot00000000000000package http3 import ( "context" "io" "net" "os" "testing" "time" "github.com/quic-go/quic-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func newStreamPair(t *testing.T) (client, server *quic.Stream) { t.Helper() clientConn, serverConn := newConnPair(t) serverStr, err := serverConn.OpenStream() require.NoError(t, err) // need to send something to the client to make it accept the stream _, err = serverStr.Write([]byte{0}) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() clientStr, err := clientConn.AcceptStream(ctx) require.NoError(t, err) clientStr.SetReadDeadline(time.Now().Add(time.Second)) _, err = 
clientStr.Read([]byte{0}) require.NoError(t, err) clientStr.SetWriteDeadline(time.Time{}) return clientStr, serverStr } func canceledCtx() context.Context { ctx, cancel := context.WithCancel(context.Background()) cancel() return ctx } func checkDatagramReceive(t *testing.T, str *stateTrackingStream) { t.Helper() _, err := str.ReceiveDatagram(canceledCtx()) require.ErrorIs(t, err, context.Canceled) } func checkDatagramSend(t *testing.T, str *stateTrackingStream) { t.Helper() require.NoError(t, str.SendDatagram([]byte("test"))) } type mockStreamClearer struct { cleared *quic.StreamID } func (s *mockStreamClearer) clearStream(id quic.StreamID) { s.cleared = &id } func TestStateTrackingStreamRead(t *testing.T) { t.Run("io.EOF", func(t *testing.T) { testStateTrackingStreamRead(t, false) }) t.Run("remote stream reset", func(t *testing.T) { testStateTrackingStreamRead(t, true) }) } func testStateTrackingStreamRead(t *testing.T, reset bool) { client, server := newStreamPair(t) var clearer mockStreamClearer str := newStateTrackingStream(client, &clearer, func(b []byte) error { return nil }) // deadline errors are ignored client.SetReadDeadline(time.Now()) _, err := str.Read(make([]byte, 3)) require.ErrorIs(t, err, os.ErrDeadlineExceeded) require.Nil(t, clearer.cleared) client.SetReadDeadline(time.Time{}) _, err = server.Write([]byte("foobar")) require.NoError(t, err) if !reset { server.Close() for range 3 { _, err := str.Read([]byte{0}) require.NoError(t, err) require.Nil(t, clearer.cleared) checkDatagramReceive(t, str) } } else { server.CancelWrite(42) } var expectedErr error _, err = io.ReadAll(str) if !reset { require.NoError(t, err) expectedErr = io.EOF } else { expectedErr = &quic.StreamError{Remote: true, StreamID: server.StreamID(), ErrorCode: 42} require.ErrorIs(t, err, expectedErr) } require.Nil(t, clearer.cleared) // the receive side registered the error _, err = str.ReceiveDatagram(canceledCtx()) require.ErrorIs(t, err, expectedErr) // the send side is still open 
require.NoError(t, str.SendDatagram([]byte("foo"))) } func TestStateTrackingStreamRemoteCancelation(t *testing.T) { client, server := newStreamPair(t) var clearer mockStreamClearer str := newStateTrackingStream(client, &clearer, func(b []byte) error { return nil }) _, err := str.Write([]byte("foo")) require.NoError(t, err) require.Nil(t, clearer.cleared) checkDatagramReceive(t, str) checkDatagramSend(t, str) // deadline errors are ignored client.SetWriteDeadline(time.Now()) _, err = str.Write([]byte("baz")) require.ErrorIs(t, err, os.ErrDeadlineExceeded) require.Nil(t, clearer.cleared) checkDatagramReceive(t, str) checkDatagramSend(t, str) client.SetWriteDeadline(time.Time{}) server.CancelRead(123) var writeErr error require.Eventually(t, func() bool { _, writeErr = str.Write([]byte("bar")) return writeErr != nil }, time.Second, scaleDuration(time.Millisecond)) expectedErr := &quic.StreamError{Remote: true, StreamID: server.StreamID(), ErrorCode: 123} require.ErrorIs(t, writeErr, expectedErr) require.Nil(t, clearer.cleared) checkDatagramReceive(t, str) require.ErrorIs(t, str.SendDatagram([]byte("test")), expectedErr) } func TestStateTrackingStreamLocalCancelation(t *testing.T) { client, _ := newStreamPair(t) var clearer mockStreamClearer str := newStateTrackingStream(client, &clearer, func(b []byte) error { return nil }) _, err := str.Write([]byte("foobar")) require.NoError(t, err) require.Nil(t, clearer.cleared) checkDatagramReceive(t, str) checkDatagramSend(t, str) str.CancelWrite(1337) require.Nil(t, clearer.cleared) checkDatagramReceive(t, str) require.ErrorIs(t, str.SendDatagram([]byte("test")), &quic.StreamError{StreamID: client.StreamID(), ErrorCode: 1337}) } func TestStateTrackingStreamClose(t *testing.T) { client, _ := newStreamPair(t) var clearer mockStreamClearer str := newStateTrackingStream(client, &clearer, func(b []byte) error { return nil }) require.Nil(t, clearer.cleared) checkDatagramReceive(t, str) checkDatagramSend(t, str) require.NoError(t, 
client.Close()) require.Eventually(t, func() bool { err := str.SendDatagram([]byte("test")) if err == nil { return false } require.ErrorIs(t, err, context.Canceled) return true }, time.Second, scaleDuration(5*time.Millisecond)) checkDatagramReceive(t, str) require.Nil(t, clearer.cleared) } func TestStateTrackingStreamReceiveThenSend(t *testing.T) { client, server := newStreamPair(t) var clearer mockStreamClearer str := newStateTrackingStream(client, &clearer, func(b []byte) error { return nil }) _, err := server.Write([]byte("foobar")) require.NoError(t, err) require.NoError(t, server.Close()) _, err = io.ReadAll(str) require.NoError(t, err) require.Nil(t, clearer.cleared) _, err = str.ReceiveDatagram(canceledCtx()) require.ErrorIs(t, err, io.EOF) client.CancelWrite(123) id := client.StreamID() _, err = str.Write([]byte("bar")) require.ErrorIs(t, err, &quic.StreamError{StreamID: id, ErrorCode: 123}) require.ErrorIs(t, str.SendDatagram([]byte("test")), &quic.StreamError{StreamID: id, ErrorCode: 123}) require.Equal(t, &id, clearer.cleared) } func TestStateTrackingStreamSendThenReceive(t *testing.T) { client, server := newStreamPair(t) var clearer mockStreamClearer str := newStateTrackingStream(client, &clearer, func(b []byte) error { return nil }) server.CancelRead(1234) var writeErr error require.Eventually(t, func() bool { _, writeErr = str.Write([]byte("bar")) return writeErr != nil }, time.Second, scaleDuration(time.Millisecond)) id := server.StreamID() expectedErr := &quic.StreamError{Remote: true, StreamID: id, ErrorCode: 1234} require.ErrorIs(t, writeErr, expectedErr) require.Nil(t, clearer.cleared) require.ErrorIs(t, str.SendDatagram([]byte("test")), expectedErr) _, err := server.Write([]byte("foobar")) require.NoError(t, err) require.NoError(t, server.Close()) _, err = io.ReadAll(str) require.NoError(t, err) _, err = str.ReceiveDatagram(canceledCtx()) require.ErrorIs(t, err, io.EOF) require.Equal(t, &id, clearer.cleared) } func TestDatagramReceiving(t 
*testing.T) { client, _ := newStreamPair(t) str := newStateTrackingStream(client, nil, func(b []byte) error { return nil }) type result struct { data []byte err error } // Receive blocks until a datagram is received resultChan := make(chan result) go func() { defer close(resultChan) data, err := str.ReceiveDatagram(context.Background()) resultChan <- result{data: data, err: err} }() select { case <-time.After(scaleDuration(10 * time.Millisecond)): case <-resultChan: t.Fatal("should not have received a datagram") } str.enqueueDatagram([]byte("foobar")) select { case res := <-resultChan: require.NoError(t, res.err) require.Equal(t, []byte("foobar"), res.data) case <-time.After(time.Second): t.Fatal("should have received a datagram") } // up to 32 datagrams can be queued for i := range streamDatagramQueueLen + 1 { str.enqueueDatagram([]byte{uint8(i)}) } for i := range streamDatagramQueueLen { data, err := str.ReceiveDatagram(context.Background()) require.NoError(t, err) require.Equal(t, []byte{uint8(i)}, data) } // Receive respects the context ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := str.ReceiveDatagram(ctx) require.ErrorIs(t, err, context.Canceled) } func TestDatagramSending(t *testing.T) { var sendQueue [][]byte errors := []error{nil, nil, assert.AnError} client, _ := newStreamPair(t) str := newStateTrackingStream(client, nil, func(b []byte) error { sendQueue = append(sendQueue, b) err := errors[0] errors = errors[1:] return err }) require.NoError(t, str.SendDatagram([]byte("foo"))) require.NoError(t, str.SendDatagram([]byte("bar"))) require.ErrorIs(t, str.SendDatagram([]byte("baz")), assert.AnError) require.Equal(t, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, sendQueue) str.closeSend(net.ErrClosed) require.ErrorIs(t, str.SendDatagram([]byte("foobar")), net.ErrClosed) } quic-go-0.59.0/http3/stream.go000066400000000000000000000312331513066070600160430ustar00rootroot00000000000000package http3 import ( "context" "errors" 
"fmt" "io" "net/http" "net/http/httptrace" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/qpack" ) type datagramStream interface { io.ReadWriteCloser CancelRead(quic.StreamErrorCode) CancelWrite(quic.StreamErrorCode) StreamID() quic.StreamID Context() context.Context SetDeadline(time.Time) error SetReadDeadline(time.Time) error SetWriteDeadline(time.Time) error SendDatagram(b []byte) error ReceiveDatagram(ctx context.Context) ([]byte, error) QUICStream() *quic.Stream } // A Stream is an HTTP/3 stream. // // When writing to and reading from the stream, data is framed in HTTP/3 DATA frames. type Stream struct { datagramStream conn *rawConn frameParser *frameParser buf []byte // used as a temporary buffer when writing the HTTP/3 frame headers bytesRemainingInFrame uint64 qlogger qlogwriter.Recorder parseTrailer func(io.Reader, *headersFrame) error parsedTrailer bool } func newStream( str datagramStream, conn *rawConn, trace *httptrace.ClientTrace, parseTrailer func(io.Reader, *headersFrame) error, qlogger qlogwriter.Recorder, ) *Stream { return &Stream{ datagramStream: str, conn: conn, buf: make([]byte, 16), qlogger: qlogger, parseTrailer: parseTrailer, frameParser: &frameParser{ r: &tracingReader{Reader: str, trace: trace}, streamID: str.StreamID(), closeConn: conn.CloseWithError, }, } } func (s *Stream) Read(b []byte) (int, error) { if s.bytesRemainingInFrame == 0 { parseLoop: for { frame, err := s.frameParser.ParseNext(s.qlogger) if err != nil { return 0, err } switch f := frame.(type) { case *dataFrame: if s.parsedTrailer { return 0, errors.New("DATA frame received after trailers") } s.bytesRemainingInFrame = f.Length break parseLoop case *headersFrame: if s.parsedTrailer { maybeQlogInvalidHeadersFrame(s.qlogger, s.StreamID(), f.Length) return 0, errors.New("additional HEADERS frame received after trailers") } s.parsedTrailer = true return 0, s.parseTrailer(s.datagramStream, 
f) default: s.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") // parseNextFrame skips over unknown frame types // Therefore, this condition is only entered when we parsed another known frame type. return 0, fmt.Errorf("peer sent an unexpected frame: %T", f) } } } var n int var err error if s.bytesRemainingInFrame < uint64(len(b)) { n, err = s.datagramStream.Read(b[:s.bytesRemainingInFrame]) } else { n, err = s.datagramStream.Read(b) } s.bytesRemainingInFrame -= uint64(n) return n, err } func (s *Stream) hasMoreData() bool { return s.bytesRemainingInFrame > 0 } func (s *Stream) Write(b []byte) (int, error) { s.buf = s.buf[:0] s.buf = (&dataFrame{Length: uint64(len(b))}).Append(s.buf) if s.qlogger != nil { s.qlogger.RecordEvent(qlog.FrameCreated{ StreamID: s.StreamID(), Raw: qlog.RawInfo{ Length: len(s.buf) + len(b), PayloadLength: len(b), }, Frame: qlog.Frame{Frame: qlog.DataFrame{}}, }) } if _, err := s.datagramStream.Write(s.buf); err != nil { return 0, err } return s.datagramStream.Write(b) } func (s *Stream) writeUnframed(b []byte) (int, error) { return s.datagramStream.Write(b) } func (s *Stream) StreamID() quic.StreamID { return s.datagramStream.StreamID() } func (s *Stream) SendDatagram(b []byte) error { // TODO: reject if datagrams are not negotiated (yet) return s.datagramStream.SendDatagram(b) } func (s *Stream) ReceiveDatagram(ctx context.Context) ([]byte, error) { // TODO: reject if datagrams are not negotiated (yet) return s.datagramStream.ReceiveDatagram(ctx) } // A RequestStream is a low-level abstraction representing an HTTP/3 request stream. // It decouples sending of the HTTP request from reading the HTTP response, allowing // the application to optimistically use the stream (and, for example, send datagrams) // before receiving the response. // // This is only needed for advanced use case, e.g. WebTransport and the various // MASQUE proxying protocols. 
type RequestStream struct { str *Stream responseBody io.ReadCloser // set by ReadResponse decoder *qpack.Decoder requestWriter *requestWriter maxHeaderBytes int reqDone chan<- struct{} disableCompression bool response *http.Response sentRequest bool requestedGzip bool isConnect bool } func newRequestStream( str *Stream, requestWriter *requestWriter, reqDone chan<- struct{}, decoder *qpack.Decoder, disableCompression bool, maxHeaderBytes int, rsp *http.Response, ) *RequestStream { return &RequestStream{ str: str, requestWriter: requestWriter, reqDone: reqDone, decoder: decoder, disableCompression: disableCompression, maxHeaderBytes: maxHeaderBytes, response: rsp, } } // Read reads data from the underlying stream. // // It can only be used after the request has been sent (using SendRequestHeader) // and the response has been consumed (using ReadResponse). func (s *RequestStream) Read(b []byte) (int, error) { if s.responseBody == nil { return 0, errors.New("http3: invalid use of RequestStream.Read before ReadResponse") } return s.responseBody.Read(b) } // StreamID returns the QUIC stream ID of the underlying QUIC stream. func (s *RequestStream) StreamID() quic.StreamID { return s.str.StreamID() } // Write writes data to the stream. // // It can only be used after the request has been sent (using SendRequestHeader). func (s *RequestStream) Write(b []byte) (int, error) { if !s.sentRequest { return 0, errors.New("http3: invalid use of RequestStream.Write before SendRequestHeader") } return s.str.Write(b) } // Close closes the send-direction of the stream. // It does not close the receive-direction of the stream. func (s *RequestStream) Close() error { return s.str.Close() } // CancelRead aborts receiving on this stream. // See [quic.Stream.CancelRead] for more details. func (s *RequestStream) CancelRead(errorCode quic.StreamErrorCode) { s.str.CancelRead(errorCode) } // CancelWrite aborts sending on this stream. // See [quic.Stream.CancelWrite] for more details. 
func (s *RequestStream) CancelWrite(errorCode quic.StreamErrorCode) {
	s.str.CancelWrite(errorCode)
}

// Context returns a context derived from the underlying QUIC stream's context.
// See [quic.Stream.Context] for more details.
func (s *RequestStream) Context() context.Context {
	return s.str.Context()
}

// SetReadDeadline sets the deadline for Read calls.
func (s *RequestStream) SetReadDeadline(t time.Time) error {
	return s.str.SetReadDeadline(t)
}

// SetWriteDeadline sets the deadline for Write calls.
func (s *RequestStream) SetWriteDeadline(t time.Time) error {
	return s.str.SetWriteDeadline(t)
}

// SetDeadline sets the read and write deadlines associated with the stream.
// It is equivalent to calling both SetReadDeadline and SetWriteDeadline.
func (s *RequestStream) SetDeadline(t time.Time) error {
	return s.str.SetDeadline(t)
}

// SendDatagram sends a new HTTP Datagram (RFC 9297).
//
// It is only possible to send datagrams if the server enabled support for this extension.
// It is recommended (though not required) to send the request before calling this method,
// as the server might drop datagrams which it can't associate with an existing request.
func (s *RequestStream) SendDatagram(b []byte) error {
	return s.str.SendDatagram(b)
}

// ReceiveDatagram receives HTTP Datagrams (RFC 9297).
//
// It is only possible if support for HTTP Datagrams was enabled, using the EnableDatagrams
// option on the [Transport].
func (s *RequestStream) ReceiveDatagram(ctx context.Context) ([]byte, error) {
	return s.str.ReceiveDatagram(ctx)
}

// SendRequestHeader sends the HTTP request.
//
// It can only be used for requests that don't have a request body.
// It is invalid to call it more than once.
// It is invalid to call it after Write has been called.
func (s *RequestStream) SendRequestHeader(req *http.Request) error { if req.Body != nil && req.Body != http.NoBody { return errors.New("http3: invalid use of RequestStream.SendRequestHeader with a request that has a request body") } return s.sendRequestHeader(req) } func (s *RequestStream) sendRequestHeader(req *http.Request) error { if s.sentRequest { return errors.New("http3: invalid duplicate use of RequestStream.SendRequestHeader") } if !s.disableCompression && req.Method != http.MethodHead && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" { s.requestedGzip = true } s.isConnect = req.Method == http.MethodConnect s.sentRequest = true return s.requestWriter.WriteRequestHeader(s.str.datagramStream, req, s.requestedGzip, s.str.StreamID(), s.str.qlogger) } // sendRequestTrailer sends request trailers to the stream. // It should be called after the request body has been fully written. func (s *RequestStream) sendRequestTrailer(req *http.Request) error { return s.requestWriter.WriteRequestTrailer(s.str.datagramStream, req, s.str.StreamID(), s.str.qlogger) } // ReadResponse reads the HTTP response from the stream. // // It must be called after sending the request (using SendRequestHeader). // It is invalid to call it more than once. // It doesn't set Response.Request and Response.TLS. // It is invalid to call it after Read has been called. 
func (s *RequestStream) ReadResponse() (*http.Response, error) { if !s.sentRequest { return nil, errors.New("http3: invalid use of RequestStream.ReadResponse before SendRequestHeader") } frame, err := s.str.frameParser.ParseNext(s.str.qlogger) if err != nil { s.str.CancelRead(quic.StreamErrorCode(ErrCodeFrameError)) s.str.CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)) return nil, fmt.Errorf("http3: parsing frame failed: %w", err) } hf, ok := frame.(*headersFrame) if !ok { s.str.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "expected first frame to be a HEADERS frame") return nil, errors.New("http3: expected first frame to be a HEADERS frame") } if hf.Length > uint64(s.maxHeaderBytes) { maybeQlogInvalidHeadersFrame(s.str.qlogger, s.str.StreamID(), hf.Length) s.str.CancelRead(quic.StreamErrorCode(ErrCodeFrameError)) s.str.CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)) return nil, fmt.Errorf("http3: HEADERS frame too large: %d bytes (max: %d)", hf.Length, s.maxHeaderBytes) } headerBlock := make([]byte, hf.Length) if _, err := io.ReadFull(s.str.datagramStream, headerBlock); err != nil { maybeQlogInvalidHeadersFrame(s.str.qlogger, s.str.StreamID(), hf.Length) s.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete)) s.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)) return nil, fmt.Errorf("http3: failed to read response headers: %w", err) } decodeFn := s.decoder.Decode(headerBlock) var hfs []qpack.HeaderField if s.str.qlogger != nil { hfs = make([]qpack.HeaderField, 0, 16) } res := s.response err = updateResponseFromHeaders(res, decodeFn, s.maxHeaderBytes, &hfs) if s.str.qlogger != nil { qlogParsedHeadersFrame(s.str.qlogger, s.str.StreamID(), hf, hfs) } if err != nil { errCode := ErrCodeMessageError var qpackErr *qpackError if errors.As(err, &qpackErr) { errCode = ErrCodeQPACKDecompressionFailed } s.str.CancelRead(quic.StreamErrorCode(errCode)) s.str.CancelWrite(quic.StreamErrorCode(errCode)) return nil, 
fmt.Errorf("http3: invalid response: %w", err) } // Check that the server doesn't send more data in DATA frames than indicated by the Content-Length header (if set). // See section 4.1.2 of RFC 9114. respBody := newResponseBody(s.str, res.ContentLength, s.reqDone) // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2. isInformational := res.StatusCode >= 100 && res.StatusCode < 200 isNoContent := res.StatusCode == http.StatusNoContent isSuccessfulConnect := s.isConnect && res.StatusCode >= 200 && res.StatusCode < 300 if (isInformational || isNoContent || isSuccessfulConnect) && res.ContentLength == -1 { res.ContentLength = 0 } if s.requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { res.Header.Del("Content-Encoding") res.Header.Del("Content-Length") res.ContentLength = -1 s.responseBody = newGzipReader(respBody) res.Uncompressed = true } else { s.responseBody = respBody } res.Body = s.responseBody return res, nil } type tracingReader struct { io.Reader readFirst bool trace *httptrace.ClientTrace } func (r *tracingReader) Read(b []byte) (int, error) { n, err := r.Reader.Read(b) if n > 0 && !r.readFirst { traceGotFirstResponseByte(r.trace) r.readFirst = true } return n, err } quic-go-0.59.0/http3/stream_test.go000066400000000000000000000155511513066070600171070ustar00rootroot00000000000000package http3 import ( "bytes" "context" "io" "math" "net/http" "net/http/httptest" "net/http/httptrace" "strings" "testing" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/quic-go/testutils/events" "github.com/quic-go/qpack" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) func getDataFrame(data []byte) []byte { b := (&dataFrame{Length: uint64(len(data))}).Append(nil) return append(b, data...) 
} func TestStreamReadDataFrames(t *testing.T) { var buf bytes.Buffer mockCtrl := gomock.NewController(t) qstr := NewMockDatagramStream(mockCtrl) qstr.EXPECT().StreamID().Return(quic.StreamID(42)).AnyTimes() qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() var eventRecorder events.Recorder clientConn, _ := newConnPair(t, withClientRecorder(&eventRecorder)) str := newStream( qstr, newRawConn(clientConn, false, nil, nil, &eventRecorder, nil), nil, func(io.Reader, *headersFrame) error { return nil }, &eventRecorder, ) buf.Write(getDataFrame([]byte("foobar"))) b := make([]byte, 3) n, err := str.Read(b) require.NoError(t, err) require.Equal(t, 3, n) require.Equal(t, []byte("foo"), b) n, err = str.Read(b) require.NoError(t, err) require.Equal(t, 3, n) require.Equal(t, []byte("bar"), b) expectedLen, _ := expectedFrameLength(t, &dataFrame{Length: 6}) require.Equal(t, []qlogwriter.Event{ qlog.FrameParsed{ StreamID: 42, Raw: qlog.RawInfo{Length: expectedLen, PayloadLength: 6}, Frame: qlog.Frame{Frame: qlog.DataFrame{}}, }, }, eventRecorder.Events(qlog.FrameParsed{}), ) eventRecorder.Clear() buf.Write(getDataFrame([]byte("baz"))) b = make([]byte, 10) n, err = str.Read(b) require.NoError(t, err) require.Equal(t, 3, n) require.Equal(t, []byte("baz"), b[:n]) require.Len(t, eventRecorder.Events(qlog.FrameParsed{}), 1) eventRecorder.Clear() buf.Write(getDataFrame([]byte("lorem"))) buf.Write(getDataFrame([]byte("ipsum"))) data, err := io.ReadAll(str) require.NoError(t, err) require.Equal(t, "loremipsum", string(data)) require.Len(t, eventRecorder.Events(qlog.FrameParsed{}), 2) eventRecorder.Clear() // invalid frame buf.Write([]byte("invalid")) _, err = str.Read([]byte{0}) require.Error(t, err) } func TestStreamInvalidFrame(t *testing.T) { var buf bytes.Buffer b := (&settingsFrame{}).Append(nil) buf.Write(b) mockCtrl := gomock.NewController(t) qstr := NewMockDatagramStream(mockCtrl) 
qstr.EXPECT().StreamID().Return(quic.StreamID(42)).AnyTimes() qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() clientConn, serverConn := newConnPair(t) str := newStream( qstr, newRawConn(clientConn, false, nil, nil, nil, nil), nil, func(io.Reader, *headersFrame) error { return nil }, nil, ) _, err := str.Read([]byte{0}) require.ErrorContains(t, err, "peer sent an unexpected frame") select { case <-serverConn.Context().Done(): var appErr *quic.ApplicationError require.ErrorAs(t, context.Cause(serverConn.Context()), &appErr) require.Equal(t, quic.ApplicationErrorCode(ErrCodeFrameUnexpected), appErr.ErrorCode) case <-time.After(time.Second): t.Fatal("timeout") } } func TestStreamWrite(t *testing.T) { var buf bytes.Buffer mockCtrl := gomock.NewController(t) qstr := NewMockDatagramStream(mockCtrl) qstr.EXPECT().StreamID().Return(quic.StreamID(42)).AnyTimes() qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() var eventRecorder events.Recorder str := newStream(qstr, nil, nil, func(io.Reader, *headersFrame) error { return nil }, &eventRecorder) str.Write([]byte("foo")) str.Write([]byte("foobar")) fp := frameParser{r: &buf} f, err := fp.ParseNext(nil) require.NoError(t, err) f1Len, f1PayloadLen := expectedFrameLength(t, &dataFrame{Length: 3}) require.Equal(t, &dataFrame{Length: 3}, f) b := make([]byte, 3) _, err = io.ReadFull(&buf, b) require.NoError(t, err) require.Equal(t, []byte("foo"), b) fp = frameParser{r: &buf} f, err = fp.ParseNext(nil) require.NoError(t, err) f2Len, f2PayloadLen := expectedFrameLength(t, &dataFrame{Length: 6}) require.Equal(t, &dataFrame{Length: 6}, f) b = make([]byte, 6) _, err = io.ReadFull(&buf, b) require.NoError(t, err) require.Equal(t, []byte("foobar"), b) require.Equal(t, []qlogwriter.Event{ qlog.FrameCreated{ StreamID: 42, Raw: qlog.RawInfo{Length: f1Len, PayloadLength: f1PayloadLen}, Frame: qlog.Frame{Frame: qlog.DataFrame{}}, }, 
qlog.FrameCreated{ StreamID: 42, Raw: qlog.RawInfo{Length: f2Len, PayloadLength: f2PayloadLen}, Frame: qlog.Frame{Frame: qlog.DataFrame{}}, }, }, eventRecorder.Events(qlog.FrameCreated{}), ) } func TestRequestStream(t *testing.T) { mockCtrl := gomock.NewController(t) qstr := NewMockDatagramStream(mockCtrl) qstr.EXPECT().StreamID().Return(quic.StreamID(42)).AnyTimes() requestWriter := newRequestWriter() clientConn, _ := newConnPair(t) str := newRequestStream( newStream( qstr, newRawConn(clientConn, false, nil, nil, nil, nil), &httptrace.ClientTrace{}, func(io.Reader, *headersFrame) error { return nil }, nil, ), requestWriter, make(chan struct{}), qpack.NewDecoder(), true, math.MaxInt, &http.Response{}, ) _, err := str.Read([]byte{0}) require.EqualError(t, err, "http3: invalid use of RequestStream.Read before ReadResponse") _, err = str.Write([]byte{0}) require.EqualError(t, err, "http3: invalid use of RequestStream.Write before SendRequestHeader") // calling ReadResponse before SendRequestHeader is not valid _, err = str.ReadResponse() require.EqualError(t, err, "http3: invalid use of RequestStream.ReadResponse before SendRequestHeader") // SendRequestHeader can't be used for requests that have a request body require.EqualError(t, str.SendRequestHeader( httptest.NewRequest(http.MethodGet, "https://quic-go.net", strings.NewReader("foobar")), ), "http3: invalid use of RequestStream.SendRequestHeader with a request that has a request body", ) req := httptest.NewRequest(http.MethodGet, "https://quic-go.net", nil) qstr.EXPECT().Write(gomock.Any()).AnyTimes() require.NoError(t, str.SendRequestHeader(req)) // duplicate calls are not allowed require.EqualError(t, str.SendRequestHeader(req), "http3: invalid duplicate use of RequestStream.SendRequestHeader") buf := bytes.NewBuffer(encodeResponse(t, http.StatusOK)) buf.Write((&dataFrame{Length: 6}).Append(nil)) buf.Write([]byte("foobar")) qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() rsp, err := 
str.ReadResponse() require.NoError(t, err) require.Equal(t, http.StatusOK, rsp.StatusCode) b := make([]byte, 10) n, err := str.Read(b) require.NoError(t, err) require.Equal(t, 6, n) require.Equal(t, []byte("foobar"), b[:n]) } quic-go-0.59.0/http3/trace.go000066400000000000000000000061711513066070600156510ustar00rootroot00000000000000package http3 import ( "crypto/tls" "net" "net/http/httptrace" "net/textproto" "time" "github.com/quic-go/quic-go" ) func traceGetConn(trace *httptrace.ClientTrace, hostPort string) { if trace != nil && trace.GetConn != nil { trace.GetConn(hostPort) } } // fakeConn is a wrapper for quic.EarlyConnection // because the quic connection does not implement net.Conn. type fakeConn struct { conn *quic.Conn } func (c *fakeConn) Close() error { panic("connection operation prohibited") } func (c *fakeConn) Read(p []byte) (int, error) { panic("connection operation prohibited") } func (c *fakeConn) Write(p []byte) (int, error) { panic("connection operation prohibited") } func (c *fakeConn) SetDeadline(t time.Time) error { panic("connection operation prohibited") } func (c *fakeConn) SetReadDeadline(t time.Time) error { panic("connection operation prohibited") } func (c *fakeConn) SetWriteDeadline(t time.Time) error { panic("connection operation prohibited") } func (c *fakeConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } func (c *fakeConn) LocalAddr() net.Addr { return c.conn.LocalAddr() } func traceGotConn(trace *httptrace.ClientTrace, conn *quic.Conn, reused bool) { if trace != nil && trace.GotConn != nil { trace.GotConn(httptrace.GotConnInfo{ Conn: &fakeConn{conn: conn}, Reused: reused, }) } } func traceGotFirstResponseByte(trace *httptrace.ClientTrace) { if trace != nil && trace.GotFirstResponseByte != nil { trace.GotFirstResponseByte() } } func traceGot1xxResponse(trace *httptrace.ClientTrace, code int, header textproto.MIMEHeader) { if trace != nil && trace.Got1xxResponse != nil { trace.Got1xxResponse(code, header) } } func 
traceGot100Continue(trace *httptrace.ClientTrace) { if trace != nil && trace.Got100Continue != nil { trace.Got100Continue() } } func traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool { return trace != nil && trace.WroteHeaderField != nil } func traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) { if trace != nil && trace.WroteHeaderField != nil { trace.WroteHeaderField(k, []string{v}) } } func traceWroteHeaders(trace *httptrace.ClientTrace) { if trace != nil && trace.WroteHeaders != nil { trace.WroteHeaders() } } func traceWroteRequest(trace *httptrace.ClientTrace, err error) { if trace != nil && trace.WroteRequest != nil { trace.WroteRequest(httptrace.WroteRequestInfo{Err: err}) } } func traceConnectStart(trace *httptrace.ClientTrace, network, addr string) { if trace != nil && trace.ConnectStart != nil { trace.ConnectStart(network, addr) } } func traceConnectDone(trace *httptrace.ClientTrace, network, addr string, err error) { if trace != nil && trace.ConnectDone != nil { trace.ConnectDone(network, addr, err) } } func traceTLSHandshakeStart(trace *httptrace.ClientTrace) { if trace != nil && trace.TLSHandshakeStart != nil { trace.TLSHandshakeStart() } } func traceTLSHandshakeDone(trace *httptrace.ClientTrace, state tls.ConnectionState, err error) { if trace != nil && trace.TLSHandshakeDone != nil { trace.TLSHandshakeDone(state, err) } } quic-go-0.59.0/http3/transport.go000066400000000000000000000355151513066070600166130ustar00rootroot00000000000000package http3 import ( "context" "crypto/tls" "errors" "fmt" "io" "log/slog" "net" "net/http" "net/http/httptrace" "net/url" "strings" "sync" "sync/atomic" "golang.org/x/net/http/httpguts" "github.com/quic-go/quic-go" ) // Settings are HTTP/3 settings that apply to the underlying connection. 
type Settings struct { // Support for HTTP/3 datagrams (RFC 9297) EnableDatagrams bool // Extended CONNECT, RFC 9220 EnableExtendedConnect bool // Other settings, defined by the application Other map[uint64]uint64 } // RoundTripOpt are options for the Transport.RoundTripOpt method. type RoundTripOpt struct { // OnlyCachedConn controls whether the Transport may create a new QUIC connection. // If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn. OnlyCachedConn bool } type clientConn interface { OpenRequestStream(context.Context) (*RequestStream, error) RoundTrip(*http.Request) (*http.Response, error) handleUnidirectionalStream(*quic.ReceiveStream) } type roundTripperWithCount struct { cancel context.CancelFunc dialing chan struct{} // closed as soon as quic.Dial(Early) returned dialErr error conn *quic.Conn clientConn clientConn useCount atomic.Int64 } func (r *roundTripperWithCount) Close() error { r.cancel() <-r.dialing if r.conn != nil { return r.conn.CloseWithError(0, "") } return nil } // Transport implements the http.RoundTripper interface type Transport struct { // TLSClientConfig specifies the TLS configuration to use with // tls.Client. If nil, the default configuration is used. TLSClientConfig *tls.Config // QUICConfig is the quic.Config used for dialing new connections. // If nil, reasonable default values will be used. QUICConfig *quic.Config // Dial specifies an optional dial function for creating QUIC // connections for requests. // If Dial is nil, a UDPConn will be created at the first request // and will be reused for subsequent connections to other servers. Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) // Enable support for HTTP/3 datagrams (RFC 9297). // If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams. EnableDatagrams bool // Additional HTTP/3 settings. 
// It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams). AdditionalSettings map[uint64]uint64 // MaxResponseHeaderBytes specifies a limit on how many response bytes are // allowed in the server's response header. // Zero means to use a default limit. MaxResponseHeaderBytes int // DisableCompression, if true, prevents the Transport from requesting compression with an // "Accept-Encoding: gzip" request header when the Request contains no existing Accept-Encoding value. // If the Transport requests gzip on its own and gets a gzipped response, it's transparently // decoded in the Response.Body. // However, if the user explicitly requested gzip it is not automatically uncompressed. DisableCompression bool Logger *slog.Logger mutex sync.Mutex initOnce sync.Once initErr error newClientConn func(*quic.Conn) clientConn clients map[string]*roundTripperWithCount transport *quic.Transport closed bool } var ( _ http.RoundTripper = &Transport{} _ io.Closer = &Transport{} ) var ( // ErrNoCachedConn is returned when Transport.OnlyCachedConn is set ErrNoCachedConn = errors.New("http3: no cached connection was available") // ErrTransportClosed is returned when attempting to use a closed Transport ErrTransportClosed = errors.New("http3: transport is closed") ) func (t *Transport) init() error { if t.newClientConn == nil { t.newClientConn = func(conn *quic.Conn) clientConn { return newClientConn( conn, t.EnableDatagrams, t.AdditionalSettings, t.MaxResponseHeaderBytes, t.DisableCompression, t.Logger, ) } } if t.QUICConfig == nil { t.QUICConfig = defaultQuicConfig.Clone() t.QUICConfig.EnableDatagrams = t.EnableDatagrams } if t.EnableDatagrams && !t.QUICConfig.EnableDatagrams { return errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled") } if len(t.QUICConfig.Versions) == 0 { t.QUICConfig = t.QUICConfig.Clone() t.QUICConfig.Versions = []quic.Version{quic.SupportedVersions()[0]} } if len(t.QUICConfig.Versions) != 1 { return 
errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") } if t.QUICConfig.MaxIncomingStreams == 0 { t.QUICConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams } if t.Dial == nil { udpConn, err := net.ListenUDP("udp", nil) if err != nil { return err } t.transport = &quic.Transport{Conn: udpConn} } return nil } // RoundTripOpt is like RoundTrip, but takes options. func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { rsp, err := t.roundTripOpt(req, opt) if err != nil { if req.Body != nil { req.Body.Close() } return nil, err } return rsp, nil } func (t *Transport) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { t.initOnce.Do(func() { t.initErr = t.init() }) if t.initErr != nil { return nil, t.initErr } if req.URL == nil { return nil, errors.New("http3: nil Request.URL") } if req.URL.Scheme != "https" { return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme) } if req.URL.Host == "" { return nil, errors.New("http3: no Host in request URL") } if req.Header == nil { return nil, errors.New("http3: nil Request.Header") } if req.Method != "" && !validMethod(req.Method) { return nil, fmt.Errorf("http3: invalid method %q", req.Method) } for k, vv := range req.Header { if !httpguts.ValidHeaderFieldName(k) { return nil, fmt.Errorf("http3: invalid http header field name %q", k) } for _, v := range vv { if !httpguts.ValidHeaderFieldValue(v) { return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k) } } } return t.doRoundTripOpt(req, opt, false) } func (t *Transport) doRoundTripOpt(req *http.Request, opt RoundTripOpt, isRetried bool) (*http.Response, error) { hostname := authorityAddr(hostnameFromURL(req.URL)) trace := httptrace.ContextClientTrace(req.Context()) traceGetConn(trace, hostname) cl, isReused, err := t.getClient(req.Context(), hostname, opt.OnlyCachedConn) if err != nil { return nil, err } select { case 
<-cl.dialing: case <-req.Context().Done(): return nil, context.Cause(req.Context()) } if cl.dialErr != nil { t.removeClient(hostname) return nil, cl.dialErr } defer cl.useCount.Add(-1) traceGotConn(trace, cl.conn, isReused) rsp, err := cl.clientConn.RoundTrip(req) if err != nil { // request aborted due to context cancellation select { case <-req.Context().Done(): return nil, err default: } if isRetried { return nil, err } t.removeClient(hostname) req, err = canRetryRequest(err, req) if err != nil { return nil, err } return t.doRoundTripOpt(req, opt, true) } return rsp, nil } func canRetryRequest(err error, req *http.Request) (*http.Request, error) { // error occurred while opening the stream, we can be sure that the request wasn't sent out var connErr *errConnUnusable if errors.As(err, &connErr) { return req, nil } // If the request stream is reset, we can only be sure that the request wasn't processed // if the error code is H3_REQUEST_REJECTED. var e *Error if !errors.As(err, &e) || e.ErrorCode != ErrCodeRequestRejected { return nil, err } // if the body is nil (or http.NoBody), it's safe to reuse this request and its body if req.Body == nil || req.Body == http.NoBody { return req, nil } // if the request body can be reset back to its original state via req.GetBody, do that if req.GetBody != nil { newBody, err := req.GetBody() if err != nil { return nil, err } reqCopy := *req reqCopy.Body = newBody req = &reqCopy return &reqCopy, nil } return nil, fmt.Errorf("http3: Transport: cannot retry err [%w] after Request.Body was written; define Request.GetBody to avoid this error", err) } // RoundTrip does a round trip. 
// It is equivalent to calling RoundTripOpt with a zero RoundTripOpt.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
	return t.RoundTripOpt(req, RoundTripOpt{})
}

// getClient returns the cached client for hostname. If there is none (and
// onlyCached is false), it creates one and starts dialing it in a background
// goroutine. The dial context is derived from ctx, so cancelling ctx (or
// closing the client) aborts the dial.
//
// On success the client's useCount is incremented; the caller is responsible for
// decrementing it. isReused is true only if the client had already finished
// dialing and its QUIC handshake was complete at the time of the call.
// A client whose dial is known to have failed is removed from the cache.
func (t *Transport) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) {
	t.mutex.Lock()
	defer t.mutex.Unlock()

	if t.closed {
		return nil, false, ErrTransportClosed
	}

	if t.clients == nil {
		t.clients = make(map[string]*roundTripperWithCount)
	}

	cl, ok := t.clients[hostname]
	if !ok {
		if onlyCached {
			return nil, false, ErrNoCachedConn
		}
		ctx, cancel := context.WithCancel(ctx)
		cl = &roundTripperWithCount{
			dialing: make(chan struct{}),
			cancel:  cancel,
		}
		go func() {
			// Deferred calls run in reverse order: cancel() runs first, then
			// dialing is closed. The results (dialErr / conn / clientConn) are
			// written before dialing is closed, so waiters on dialing see them.
			defer close(cl.dialing)
			defer cancel()
			conn, rt, err := t.dial(ctx, hostname)
			if err != nil {
				cl.dialErr = err
				return
			}
			cl.conn = conn
			cl.clientConn = rt
		}()
		t.clients[hostname] = cl
	}
	// Non-blocking check: only inspect the dial result if dialing already finished.
	select {
	case <-cl.dialing:
		if cl.dialErr != nil {
			delete(t.clients, hostname)
			return nil, false, cl.dialErr
		}
		select {
		case <-cl.conn.HandshakeComplete():
			isReused = true
		default:
		}
	default:
	}
	cl.useCount.Add(1)
	return cl, isReused, nil
}

// dial establishes a QUIC connection to hostname (host:port) and wraps it in a
// clientConn.
//
// The TLS config is cloned from TLSClientConfig (or empty). ServerName defaults
// to the host part of hostname, and NextProtos is always replaced by HTTP/3.
// If t.Dial is nil, the address is resolved and dialed on the Transport's shared
// quic.Transport, reporting ConnectStart/ConnectDone and TLSHandshakeStart/Done
// to the httptrace.ClientTrace found in ctx.
// On success a goroutine accepts incoming unidirectional streams until
// AcceptUniStream returns an error.
func (t *Transport) dial(ctx context.Context, hostname string) (*quic.Conn, clientConn, error) {
	var tlsConf *tls.Config
	if t.TLSClientConfig == nil {
		tlsConf = &tls.Config{}
	} else {
		tlsConf = t.TLSClientConfig.Clone()
	}
	if tlsConf.ServerName == "" {
		sni, _, err := net.SplitHostPort(hostname)
		if err != nil {
			// It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port.
			sni = hostname
		}
		tlsConf.ServerName = sni
	}
	// Replace existing ALPNs by H3
	tlsConf.NextProtos = []string{NextProtoH3}

	dial := t.Dial
	if dial == nil {
		dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
			network := "udp"
			udpAddr, err := t.resolveUDPAddr(ctx, network, addr)
			if err != nil {
				return nil, err
			}
			trace := httptrace.ContextClientTrace(ctx)
			traceConnectStart(trace, network, udpAddr.String())
			traceTLSHandshakeStart(trace)
			conn, err := t.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
			// conn may be nil on error; report an empty TLS state in that case.
			var state tls.ConnectionState
			if conn != nil {
				state = conn.ConnectionState().TLS
			}
			traceTLSHandshakeDone(trace, state, err)
			traceConnectDone(trace, network, udpAddr.String(), err)
			return conn, err
		}
	}
	conn, err := dial(ctx, hostname, tlsConf, t.QUICConfig)
	if err != nil {
		return nil, nil, err
	}
	clientConn := t.newClientConn(conn)
	// The accept loop uses context.Background(), so it is not tied to the
	// (request-derived) dial context; it stops once AcceptUniStream errors.
	go func() {
		for {
			str, err := conn.AcceptUniStream(context.Background())
			if err != nil {
				return
			}
			go clientConn.handleUnidirectionalStream(str)
		}
	}()
	return conn, clientConn, nil
}

// resolveUDPAddr resolves addr (host:port) to a single UDP address: the port via
// net.LookupPort, the host via net.DefaultResolver. The choice among multiple IP
// addresses is delegated to addrList.forResolve (defined elsewhere in this package).
func (t *Transport) resolveUDPAddr(ctx context.Context, network, addr string) (*net.UDPAddr, error) {
	host, portStr, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, err
	}
	port, err := net.LookupPort(network, portStr)
	if err != nil {
		return nil, err
	}
	resolver := net.DefaultResolver
	ipAddrs, err := resolver.LookupIPAddr(ctx, host)
	if err != nil {
		return nil, err
	}
	addrs := addrList(ipAddrs)
	ip := addrs.forResolve(network, addr)
	return &net.UDPAddr{IP: ip.IP, Port: port, Zone: ip.Zone}, nil
}

// removeClient drops the cached client for hostname, if any. The connection
// itself is not closed here.
// NOTE(review): this deletes whatever entry is currently stored for hostname,
// which may be a newer client created concurrently - confirm this is acceptable.
func (t *Transport) removeClient(hostname string) {
	t.mutex.Lock()
	defer t.mutex.Unlock()
	if t.clients == nil {
		return
	}
	delete(t.clients, hostname)
}

// NewClientConn creates a new HTTP/3 client connection on top of a QUIC connection.
// Most users should use RoundTrip instead of creating a connection directly.
// Specifically, it is not needed to perform GET, POST, HEAD and CONNECT requests.
// // Obtaining a ClientConn is only needed for more advanced use cases, such as // using Extended CONNECT for WebTransport or the various MASQUE protocols. func (t *Transport) NewClientConn(conn *quic.Conn) *ClientConn { c := newClientConn( conn, t.EnableDatagrams, t.AdditionalSettings, t.MaxResponseHeaderBytes, t.DisableCompression, t.Logger, ) go func() { for { str, err := conn.AcceptUniStream(context.Background()) if err != nil { return } go c.handleUnidirectionalStream(str) } }() return c } // NewRawClientConn creates a new low-level HTTP/3 client connection on top of a QUIC connection. // Unlike NewClientConn, the returned RawClientConn allows the application to take control // of the stream accept loops, by calling HandleUnidirectionalStream for incoming unidirectional // streams and HandleBidirectionalStream for incoming bidirectional streams. func (t *Transport) NewRawClientConn(conn *quic.Conn) *RawClientConn { return &RawClientConn{ ClientConn: newClientConn( conn, t.EnableDatagrams, t.AdditionalSettings, t.MaxResponseHeaderBytes, t.DisableCompression, t.Logger, ), } } // Close closes the QUIC connections that this Transport has used. // A Transport cannot be used after it has been closed. 
func (t *Transport) Close() error { t.mutex.Lock() defer t.mutex.Unlock() for _, cl := range t.clients { if err := cl.Close(); err != nil { return err } } t.clients = nil if t.transport != nil { if err := t.transport.Close(); err != nil { return err } if err := t.transport.Conn.Close(); err != nil { return err } t.transport = nil } t.closed = true return nil } func hostnameFromURL(url *url.URL) string { if url != nil { return url.Host } return "" } func validMethod(method string) bool { /* Method = "OPTIONS" ; Section 9.2 | "GET" ; Section 9.3 | "HEAD" ; Section 9.4 | "POST" ; Section 9.5 | "PUT" ; Section 9.6 | "DELETE" ; Section 9.7 | "TRACE" ; Section 9.8 | "CONNECT" ; Section 9.9 | extension-method extension-method = token token = 1* */ return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 } // copied from net/http/http.go func isNotToken(r rune) bool { return !httpguts.IsTokenRune(r) } // CloseIdleConnections closes any QUIC connections in the transport's pool that are currently idle. // An idle connection is one that was previously used for requests but is now sitting unused. // This method does not interrupt any connections currently in use. // It also does not affect connections obtained via NewClientConn. 
func (t *Transport) CloseIdleConnections() { t.mutex.Lock() defer t.mutex.Unlock() for hostname, cl := range t.clients { if cl.useCount.Load() == 0 { cl.Close() delete(t.clients, hostname) } } } quic-go-0.59.0/http3/transport_test.go000066400000000000000000000406101513066070600176420ustar00rootroot00000000000000package http3 import ( "bytes" "context" "crypto/tls" "errors" "fmt" "io" "net/http" "net/http/httptest" "strings" "testing" "time" "github.com/quic-go/quic-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) type mockBody struct { reader bytes.Reader readErr error closeErr error closed bool } // make sure the mockBody can be used as a http.Request.Body var _ io.ReadCloser = &mockBody{} func (m *mockBody) Read(p []byte) (int, error) { if m.readErr != nil { return 0, m.readErr } return m.reader.Read(p) } func (m *mockBody) SetData(data []byte) { m.reader = *bytes.NewReader(data) } func (m *mockBody) Close() error { m.closed = true return m.closeErr } func TestRequestValidation(t *testing.T) { var tr Transport for _, tt := range []struct { name string req *http.Request expectedErr string expectedErrContains string }{ { name: "plain HTTP", req: httptest.NewRequest(http.MethodGet, "http://www.example.org/", nil), expectedErr: "http3: unsupported protocol scheme: http", }, { name: "missing URL", req: func() *http.Request { r := httptest.NewRequest(http.MethodGet, "https://www.example.org/", nil) r.URL = nil return r }(), expectedErr: "http3: nil Request.URL", }, { name: "missing URL Host", req: func() *http.Request { r := httptest.NewRequest(http.MethodGet, "https://www.example.org/", nil) r.URL.Host = "" return r }(), expectedErr: "http3: no Host in request URL", }, { name: "missing header", req: func() *http.Request { r := httptest.NewRequest(http.MethodGet, "https://www.example.org/", nil) r.Header = nil return r }(), expectedErr: "http3: nil Request.Header", }, { name: "invalid header name", req: func() 
*http.Request { r := httptest.NewRequest(http.MethodGet, "https://www.example.org/", nil) r.Header.Add("foobär", "value") return r }(), expectedErr: "http3: invalid http header field name \"foobär\"", }, { name: "invalid header value", req: func() *http.Request { r := httptest.NewRequest(http.MethodGet, "https://www.example.org/", nil) r.Header.Add("foo", string([]byte{0x7})) return r }(), expectedErrContains: "http3: invalid http header field value", }, { name: "invalid method", req: func() *http.Request { r := httptest.NewRequest(http.MethodGet, "https://www.example.org/", nil) r.Method = "foobär" return r }(), expectedErr: "http3: invalid method \"foobär\"", }, } { t.Run(tt.name, func(t *testing.T) { tt.req.Body = &mockBody{} _, err := tr.RoundTrip(tt.req) if tt.expectedErr != "" { require.EqualError(t, err, tt.expectedErr) } if tt.expectedErrContains != "" { require.Error(t, err) require.Contains(t, err.Error(), tt.expectedErrContains) } require.True(t, tt.req.Body.(*mockBody).closed) }) } } func TestTransportDialHostname(t *testing.T) { type hostnameConfig struct { dialHostname string tlsServerName string } hostnameChan := make(chan hostnameConfig, 1) tr := &Transport{ Dial: func(_ context.Context, hostname string, tlsConf *tls.Config, _ *quic.Config) (*quic.Conn, error) { hostnameChan <- hostnameConfig{ dialHostname: hostname, tlsServerName: tlsConf.ServerName, } return nil, errors.New("test done") }, } t.Run("port set", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "https://quic-go.net:1234", nil) _, err := tr.RoundTripOpt(req, RoundTripOpt{}) require.EqualError(t, err, "test done") select { case c := <-hostnameChan: require.Equal(t, "quic-go.net:1234", c.dialHostname) require.Equal(t, "quic-go.net", c.tlsServerName) case <-time.After(1 * time.Second): t.Fatal("timeout") } }) // if the request doesn't have a port, the default port is used t.Run("port not set", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, 
"https://quic-go.net", nil) _, err := tr.RoundTripOpt(req, RoundTripOpt{}) require.EqualError(t, err, "test done") select { case c := <-hostnameChan: require.Equal(t, "quic-go.net:443", c.dialHostname) require.Equal(t, "quic-go.net", c.tlsServerName) case <-time.After(1 * time.Second): t.Fatal("timeout") } }) } func TestTransportDatagrams(t *testing.T) { // if the default quic.Config is used, the transport automatically enables QUIC datagrams t.Run("default quic.Config", func(t *testing.T) { tr := &Transport{ EnableDatagrams: true, Dial: func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (*quic.Conn, error) { require.True(t, quicConf.EnableDatagrams) return nil, assert.AnError }, } req := httptest.NewRequest(http.MethodGet, "https://example.com", nil) _, err := tr.RoundTripOpt(req, RoundTripOpt{}) require.ErrorIs(t, err, assert.AnError) }) // if a custom quic.Config is used, the transport just checks that QUIC datagrams are enabled t.Run("custom quic.Config", func(t *testing.T) { tr := &Transport{ EnableDatagrams: true, QUICConfig: &quic.Config{EnableDatagrams: false}, Dial: func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (*quic.Conn, error) { t.Fatal("dial should not be called") return nil, nil }, } req := httptest.NewRequest(http.MethodGet, "https://example.com", nil) _, err := tr.RoundTripOpt(req, RoundTripOpt{}) require.EqualError(t, err, "HTTP Datagrams enabled, but QUIC Datagrams disabled") }) } func TestTransportMultipleQUICVersions(t *testing.T) { qconf := &quic.Config{ Versions: []quic.Version{quic.Version2, quic.Version1}, } tr := &Transport{QUICConfig: qconf} req := httptest.NewRequest(http.MethodGet, "https://example.com", nil) _, err := tr.RoundTrip(req) require.EqualError(t, err, "can only use a single QUIC version for dialing a HTTP/3 connection") } func TestTransportConnectionReuse(t *testing.T) { conn, _ := newConnPair(t) mockCtrl := gomock.NewController(t) cl := NewMockClientConn(mockCtrl) var dialCount 
int tr := &Transport{ Dial: func(context.Context, string, *tls.Config, *quic.Config) (*quic.Conn, error) { dialCount++ return conn, nil }, newClientConn: func(*quic.Conn) clientConn { return cl }, } req1 := httptest.NewRequest(http.MethodGet, "https://quic-go.net/file1.html", nil) // if OnlyCachedConn is set, no connection is dialed _, err := tr.RoundTripOpt(req1, RoundTripOpt{OnlyCachedConn: true}) require.ErrorIs(t, err, ErrNoCachedConn) require.Zero(t, dialCount) // the first request establishes the connection... cl.EXPECT().RoundTrip(req1).Return(&http.Response{Request: req1}, nil) rsp, err := tr.RoundTrip(req1) require.NoError(t, err) require.Equal(t, req1, rsp.Request) require.Equal(t, 1, dialCount) // ... which is then used for the second request req2 := httptest.NewRequest(http.MethodGet, "https://quic-go.net/file2.html", nil) cl.EXPECT().RoundTrip(req2).Return(&http.Response{Request: req2}, nil) rsp, err = tr.RoundTrip(req2) require.NoError(t, err) require.Equal(t, req2, rsp.Request) require.Equal(t, 1, dialCount) } // Requests reuse the same underlying QUIC connection. // If a request experiences an error, the behavior depends on the nature of that error. func TestTransportConnectionRedial(t *testing.T) { nonRetryableReq := httptest.NewRequest( http.MethodGet, "https://quic-go.org", strings.NewReader("foobar"), ) require.Nil(t, nonRetryableReq.GetBody) retryableReq := nonRetryableReq.Clone(context.Background()) retryableReq.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(strings.NewReader("foobaz")), nil } // If the error occurs when opening the stream, it is safe to retry the request: // We can be certain that it wasn't sent out (not even partially). 
t.Run("error when opening the stream", func(t *testing.T) { require.NoError(t, testTransportConnectionRedial(t, nonRetryableReq, &errConnUnusable{errors.New("test")}, "foobar", true), ) }) // If the error occurs when opening the stream, it is safe to retry the request: // We can be certain that it wasn't sent out (not even partially). t.Run("non-retryable request error after opening the stream", func(t *testing.T) { require.ErrorIs(t, testTransportConnectionRedial(t, nonRetryableReq, assert.AnError, "foobar", false), assert.AnError, ) }) t.Run("retryable request after opening the stream", func(t *testing.T) { require.ErrorIs(t, testTransportConnectionRedial(t, retryableReq, assert.AnError, "", false), assert.AnError, ) }) t.Run("retryable request after H3_REQUEST_REJECTED", func(t *testing.T) { require.NoError(t, testTransportConnectionRedial(t, retryableReq, &Error{ErrorCode: ErrCodeRequestRejected}, "foobaz", true, ), ) }) t.Run("retryable request where GetBody returns an error", func(t *testing.T) { req := nonRetryableReq.Clone(context.Background()) req.GetBody = func() (io.ReadCloser, error) { return nil, assert.AnError } require.ErrorIs(t, testTransportConnectionRedial(t, req, &Error{ErrorCode: ErrCodeRequestRejected}, "", false), assert.AnError, ) }) } func testTransportConnectionRedial(t *testing.T, req *http.Request, roundtripErr error, expectedBody string, expectRedial bool) error { conn, _ := newConnPair(t) mockCtrl := gomock.NewController(t) cl := NewMockClientConn(mockCtrl) var dialCount int tr := &Transport{ Dial: func(context.Context, string, *tls.Config, *quic.Config) (*quic.Conn, error) { dialCount++ return conn, nil }, newClientConn: func(*quic.Conn) clientConn { return cl }, } var body string cl.EXPECT().RoundTrip(req).Return(nil, roundtripErr) if expectRedial { cl.EXPECT().RoundTrip(gomock.Any()).DoAndReturn(func(r *http.Request) (*http.Response, error) { b, err := io.ReadAll(r.Body) if err != nil { panic(fmt.Sprintf("reading body failed: %v", 
err)) } body = string(b) return &http.Response{Request: req}, nil }) } _, err := tr.RoundTrip(req) if !expectRedial { assert.Equal(t, 1, dialCount) } else { assert.Equal(t, 2, dialCount) assert.Equal(t, expectedBody, body) } return err } func TestTransportRequestContextCancellation(t *testing.T) { mockCtrl := gomock.NewController(t) cl := NewMockClientConn(mockCtrl) conn, _ := newConnPair(t) var dialCount int tr := &Transport{ Dial: func(context.Context, string, *tls.Config, *quic.Config) (*quic.Conn, error) { dialCount++ return conn, nil }, newClientConn: func(*quic.Conn) clientConn { return cl }, } // the first request succeeds req1 := httptest.NewRequest(http.MethodGet, "https://quic-go.net/file1.html", nil) cl.EXPECT().RoundTrip(req1).Return(&http.Response{Request: req1}, nil) rsp, err := tr.RoundTrip(req1) require.NoError(t, err) require.Equal(t, req1, rsp.Request) require.Equal(t, 1, dialCount) // the second request reuses the QUIC connection, and runs into the cancelled context req2 := httptest.NewRequest(http.MethodGet, "https://quic-go.net/file2.html", nil) ctx, cancel := context.WithCancel(context.Background()) req2 = req2.WithContext(ctx) cl.EXPECT().RoundTrip(req2).DoAndReturn( func(r *http.Request) (*http.Response, error) { cancel() return nil, context.Canceled }, ) _, err = tr.RoundTrip(req2) require.ErrorIs(t, err, context.Canceled) require.Equal(t, 1, dialCount) // the next request reuses the QUIC connection req3 := httptest.NewRequest(http.MethodGet, "https://quic-go.net/file2.html", nil) cl.EXPECT().RoundTrip(req3).Return(&http.Response{Request: req3}, nil) rsp, err = tr.RoundTrip(req3) require.NoError(t, err) require.Equal(t, req3, rsp.Request) require.Equal(t, 1, dialCount) } func TestTransportConnetionRedialHandshakeError(t *testing.T) { mockCtrl := gomock.NewController(t) cl := NewMockClientConn(mockCtrl) conn, _ := newConnPair(t) var dialCount int tr := &Transport{ Dial: func(context.Context, string, *tls.Config, *quic.Config) (*quic.Conn, 
error) { dialCount++ if dialCount == 1 { return nil, assert.AnError } return conn, nil }, newClientConn: func(*quic.Conn) clientConn { return cl }, } req1 := httptest.NewRequest(http.MethodGet, "https://quic-go.net/file1.html", nil) _, err := tr.RoundTrip(req1) require.ErrorIs(t, err, assert.AnError) require.Equal(t, 1, dialCount) req2 := httptest.NewRequest(http.MethodGet, "https://quic-go.net/file2.html", nil) cl.EXPECT().RoundTrip(req2).Return(&http.Response{Request: req2}, nil) rsp, err := tr.RoundTrip(req2) require.NoError(t, err) require.Equal(t, req2, rsp.Request) require.Equal(t, 2, dialCount) } func TestTransportCloseEstablishedConnections(t *testing.T) { mockCtrl := gomock.NewController(t) conn, _ := newConnPair(t) tr := &Transport{ Dial: func(context.Context, string, *tls.Config, *quic.Config) (*quic.Conn, error) { return conn, nil }, newClientConn: func(*quic.Conn) clientConn { cl := NewMockClientConn(mockCtrl) cl.EXPECT().RoundTrip(gomock.Any()).Return(&http.Response{}, nil) return cl }, } req := httptest.NewRequest(http.MethodGet, "https://quic-go.net/foobar.html", nil) _, err := tr.RoundTrip(req) require.NoError(t, err) require.NoError(t, tr.Close()) select { case <-conn.Context().Done(): case <-time.After(time.Second): t.Fatal("timeout") } } func TestTransportCloseInFlightDials(t *testing.T) { tr := &Transport{ Dial: func(ctx context.Context, _ string, _ *tls.Config, _ *quic.Config) (*quic.Conn, error) { var err error select { case <-ctx.Done(): err = ctx.Err() case <-time.After(time.Second): err = errors.New("timeout") } return nil, err }, } req := httptest.NewRequest(http.MethodGet, "https://quic-go.net/foobar.html", nil) errChan := make(chan error, 1) go func() { _, err := tr.RoundTrip(req) errChan <- err }() select { case err := <-errChan: t.Fatalf("received unexpected error: %v", err) case <-time.After(scaleDuration(10 * time.Millisecond)): } require.NoError(t, tr.Close()) select { case err := <-errChan: require.ErrorIs(t, err, 
context.Canceled) case <-time.After(time.Second): t.Fatal("timeout") } } func TestTransportCloseIdleConnections(t *testing.T) { mockCtrl := gomock.NewController(t) conn1, _ := newConnPair(t) conn2, _ := newConnPair(t) roundTripCalled := make(chan struct{}) tr := &Transport{ Dial: func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (*quic.Conn, error) { switch hostname { case "site1.com:443": return conn1, nil case "site2.com:443": return conn2, nil default: t.Fatal("unexpected hostname") return nil, errors.New("unexpected hostname") } }, newClientConn: func(*quic.Conn) clientConn { cl := NewMockClientConn(mockCtrl) cl.EXPECT().RoundTrip(gomock.Any()).DoAndReturn(func(r *http.Request) (*http.Response, error) { roundTripCalled <- struct{}{} <-r.Context().Done() return nil, nil }) return cl }, } req1 := httptest.NewRequest(http.MethodGet, "https://site1.com", nil) req2 := httptest.NewRequest(http.MethodGet, "https://site2.com", nil) require.NotEqual(t, req1.Host, req2.Host) ctx1, cancel1 := context.WithCancel(context.Background()) ctx2, cancel2 := context.WithCancel(context.Background()) req1 = req1.WithContext(ctx1) req2 = req2.WithContext(ctx2) reqFinished := make(chan struct{}) go func() { tr.RoundTrip(req1) reqFinished <- struct{}{} }() go func() { tr.RoundTrip(req2) reqFinished <- struct{}{} }() <-roundTripCalled <-roundTripCalled // Both two requests are started. 
cancel1() <-reqFinished // req1 is finished tr.CloseIdleConnections() select { case <-conn1.Context().Done(): case <-time.After(time.Second): t.Fatal("timeout") } cancel2() <-reqFinished // all requests are finished tr.CloseIdleConnections() select { case <-conn2.Context().Done(): case <-time.After(time.Second): t.Fatal("timeout") } } func TestTransportClose(t *testing.T) { mockCtrl := gomock.NewController(t) conn, _ := newConnPair(t) tr := &Transport{ Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) { return conn, nil }, newClientConn: func(*quic.Conn) clientConn { cl := NewMockClientConn(mockCtrl) cl.EXPECT().RoundTrip(gomock.Any()).Return(nil, nil) return cl }, } req, err := http.NewRequest(http.MethodGet, "https://example.com", nil) require.NoError(t, err) _, err = tr.RoundTrip(req) require.NoError(t, err) require.NoError(t, tr.Close()) _, err = tr.RoundTrip(req) require.ErrorIs(t, err, ErrTransportClosed) } quic-go-0.59.0/integrationtests/000077500000000000000000000000001513066070600165635ustar00rootroot00000000000000quic-go-0.59.0/integrationtests/gomodvendor/000077500000000000000000000000001513066070600211065ustar00rootroot00000000000000quic-go-0.59.0/integrationtests/gomodvendor/.gitignore000066400000000000000000000000101513066070600230650ustar00rootroot00000000000000vendor/ quic-go-0.59.0/integrationtests/gomodvendor/go.mod000066400000000000000000000006531513066070600222200ustar00rootroot00000000000000module test go 1.24 // The version doesn't matter here, as we're replacing it with the currently checked out code anyway. 
require github.com/quic-go/quic-go v0.21.0

require (
	github.com/quic-go/qpack v0.6.0 // indirect
	golang.org/x/crypto v0.41.0 // indirect
	golang.org/x/net v0.43.0 // indirect
	golang.org/x/sys v0.35.0 // indirect
	golang.org/x/text v0.28.0 // indirect
)

replace github.com/quic-go/quic-go => ../../
quic-go-0.59.0/integrationtests/gomodvendor/go.sum000066400000000000000000000031141513066070600222400ustar00rootroot00000000000000github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko=
go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o=
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
quic-go-0.59.0/integrationtests/gomodvendor/main.go000066400000000000000000000003011513066070600223530ustar00rootroot00000000000000package main

import "github.com/quic-go/quic-go/http3"

// The contents of this script don't matter.
// We just need to make sure that quic-go is imported.
func main() {
	_ = http3.Server{}
}
quic-go-0.59.0/integrationtests/self/000077500000000000000000000000001513066070600175145ustar00rootroot00000000000000quic-go-0.59.0/integrationtests/self/benchmark_test.go000066400000000000000000000065401513066070600230410ustar00rootroot00000000000000package self_test

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"testing"
	"time"

	"github.com/quic-go/quic-go"

	"github.com/stretchr/testify/require"
)

// BenchmarkHandshake measures connection establishment. Each iteration dials the
// listener from a shared client quic.Transport, waits until the server has
// accepted the connection, and then closes both ends.
func BenchmarkHandshake(b *testing.B) {
	b.ReportAllocs()

	ln, err := quic.Listen(newUDPConnLocalhost(b), tlsConfig, nil)
	require.NoError(b, err)
	defer ln.Close()

	// Accept loop: hands every server-side connection to the benchmark loop.
	// It exits once Accept fails (when ln is closed).
	connChan := make(chan *quic.Conn, 1)
	go func() {
		for {
			conn, err := ln.Accept(context.Background())
			if err != nil {
				return
			}
			connChan <- conn
		}
	}()

	tr := &quic.Transport{Conn: newUDPConnLocalhost(b)}
	defer tr.Close()

	for b.Loop() {
		c, err := tr.Dial(context.Background(), ln.Addr(), tlsClientConfig, nil)
		if err != nil {
			b.Fatalf("error dialing: %v", err)
		}
		serverConn := <-connChan
		serverConn.CloseWithError(0, "")
		c.CloseWithError(0, "")
	}
}

// BenchmarkStreamChurn measures opening and closing bidirectional streams on a
// single, already established connection. The server accepts every stream and
// closes it right away. MaxIncomingStreams is set very high (1e10), presumably so
// that the server's stream limit does not throttle the loop.
func BenchmarkStreamChurn(b *testing.B) {
	b.ReportAllocs()

	ln, err := quic.Listen(newUDPConnLocalhost(b), tlsConfig, &quic.Config{MaxIncomingStreams: 1e10})
	require.NoError(b, err)
	defer ln.Close()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	conn, err := quic.Dial(ctx, newUDPConnLocalhost(b), ln.Addr(), tlsClientConfig, nil)
	require.NoError(b, err)
	defer conn.CloseWithError(0, "")

	serverConn, err := ln.Accept(context.Background())
	require.NoError(b, err)
	defer serverConn.CloseWithError(0, "")

	go func() {
		for {
			str, err := serverConn.AcceptStream(context.Background())
			if err != nil {
				return
			}
			str.Close()
		}
	}()

	for b.Loop() {
		str, err := conn.OpenStreamSync(context.Background())
		if err != nil {
			b.Fatalf("error opening stream: %v", err)
		}
		if err := str.Close(); err != nil {
			b.Fatalf("error closing stream: %v", err)
		}
	}
}

// BenchmarkTransfer runs benchmarkTransfer for PRData and PRDataLong. Each
// sub-benchmark is named after the payload size, computed as len/1024.
func BenchmarkTransfer(b *testing.B) {
	b.Run(fmt.Sprintf("%d kb", len(PRData)/1024), func(b *testing.B) { benchmarkTransfer(b, PRData) })
	b.Run(fmt.Sprintf("%d kb", len(PRDataLong)/1024), func(b *testing.B) { benchmarkTransfer(b, PRDataLong) })
}

// benchmarkTransfer measures a fresh connection plus a one-shot transfer of data.
// For every accepted connection the server opens a unidirectional stream, writes
// data and closes the stream. The client reads exactly len(data) bytes, expects
// io.EOF afterwards, and verifies the payload.
func benchmarkTransfer(b *testing.B, data []byte) {
	b.ReportAllocs()

	ln, err := quic.Listen(newUDPConnLocalhost(b), tlsConfig, nil)
	require.NoError(b, err)
	defer ln.Close()

	// Server side. Errors are reported with b.Logf rather than b.Fatalf,
	// presumably because FailNow must not be called outside the benchmark goroutine.
	connChan := make(chan *quic.Conn, 1)
	go func() {
		for {
			conn, err := ln.Accept(context.Background())
			if err != nil {
				return
			}
			connChan <- conn
			str, err := conn.OpenUniStream()
			if err != nil {
				b.Logf("error opening stream: %v", err)
				return
			}
			if _, err := str.Write(data); err != nil {
				b.Logf("error writing data: %v", err)
				return
			}
			if err := str.Close(); err != nil {
				b.Logf("error closing stream: %v", err)
				return
			}
		}
	}()

	tr := &quic.Transport{Conn: newUDPConnLocalhost(b)}
	defer tr.Close()

	// buf is allocated once and reused across iterations.
	buf := make([]byte, len(data))
	for b.Loop() {
		c, err := tr.Dial(context.Background(), ln.Addr(), tlsClientConfig, nil)
		if err != nil {
			b.Fatalf("error dialing: %v", err)
		}
		str, err := c.AcceptUniStream(context.Background())
		if err != nil {
			b.Fatalf("error accepting stream: %v", err)
		}
		if _, err := io.ReadFull(str, buf); err != nil {
			b.Fatalf("error reading data: %v", err)
		}
		if _, err := str.Read([]byte{0}); err != io.EOF {
			b.Fatalf("error reading EOF: %v", err)
		}
		if !bytes.Equal(buf, data) {
			b.Fatalf("data mismatch: got %x, expected %x", buf, data)
		}
		serverConn := <-connChan
		serverConn.CloseWithError(0, "")
		c.CloseWithError(0, "")
	}
}
quic-go-0.59.0/integrationtests/self/cancelation_test.go000066400000000000000000000364151513066070600233730ustar00rootroot00000000000000package self_test

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
"math/rand/v2" "sync" "sync/atomic" "testing" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestStreamReadCancellation(t *testing.T) { t.Run("immediate", func(t *testing.T) { testStreamCancellation(t, func(str *quic.ReceiveStream) error { str.CancelRead(quic.StreamErrorCode(str.StreamID())) _, err := str.Read([]byte{0}) return err }, nil) }) t.Run("after reading some data", func(t *testing.T) { testStreamCancellation(t, func(str *quic.ReceiveStream) error { length := rand.IntN(len(PRData) - 1) if _, err := io.ReadAll(io.LimitReader(str, int64(length))); err != nil { return fmt.Errorf("reading stream data failed: %w", err) } str.CancelRead(quic.StreamErrorCode(str.StreamID())) _, err := str.Read([]byte{0}) return err }, nil) }) // This test is especially valuable when run with race detector, // see https://github.com/quic-go/quic-go/issues/3239. t.Run("concurrent", func(t *testing.T) { testStreamCancellation(t, func(str *quic.ReceiveStream) error { errChan := make(chan error, 1) go func() { for { if _, err := str.Read(make([]byte, 16)); err != nil { errChan <- err return } time.Sleep(time.Millisecond) } }() done := make(chan struct{}) go func() { defer close(done) str.CancelRead(quic.StreamErrorCode(str.StreamID())) }() timeout := time.After(time.Second) select { case <-done: case <-timeout: return fmt.Errorf("timeout canceling") } select { case err := <-errChan: return err case <-timeout: return fmt.Errorf("timeout canceling") } }, nil) }) } func TestStreamWriteCancellation(t *testing.T) { t.Run("immediate", func(t *testing.T) { testStreamCancellation(t, nil, func(str *quic.SendStream) error { str.CancelWrite(quic.StreamErrorCode(str.StreamID())) _, err := str.Write([]byte{0}) return err }) }) t.Run("after writing some data", func(t *testing.T) { testStreamCancellation(t, nil, func(str *quic.SendStream) error { length := rand.IntN(len(PRData) - 1) 
if _, err := str.Write(PRData[:length]); err != nil { return fmt.Errorf("writing stream data failed: %w", err) } str.CancelWrite(quic.StreamErrorCode(str.StreamID())) _, err := str.Write([]byte{0}) return err }) }) // This test is especially valuable when run with race detector, // see https://github.com/quic-go/quic-go/issues/3239. t.Run("concurrent", func(t *testing.T) { testStreamCancellation(t, nil, func(str *quic.SendStream) error { errChan := make(chan error, 1) go func() { var offset int for { n, err := str.Write(PRData[offset : offset+128]) if err != nil { errChan <- err return } offset += n time.Sleep(time.Millisecond) } }() done := make(chan struct{}) go func() { defer close(done) str.CancelWrite(quic.StreamErrorCode(str.StreamID())) }() timeout := time.After(time.Second) select { case <-done: case <-timeout: return fmt.Errorf("timeout canceling") } select { case err := <-errChan: return err case <-timeout: return fmt.Errorf("timeout canceling") } }) }) } func TestStreamReadWriteCancellation(t *testing.T) { t.Run("immediate", func(t *testing.T) { testStreamCancellation(t, func(str *quic.ReceiveStream) error { str.CancelRead(quic.StreamErrorCode(str.StreamID())) _, err := str.Read([]byte{0}) return err }, func(str *quic.SendStream) error { str.CancelWrite(quic.StreamErrorCode(str.StreamID())) _, err := str.Write([]byte{0}) return err }, ) }) t.Run("after writing some data", func(t *testing.T) { testStreamCancellation(t, func(str *quic.ReceiveStream) error { length := rand.IntN(len(PRData) - 1) if _, err := io.ReadAll(io.LimitReader(str, int64(length))); err != nil { return fmt.Errorf("reading stream data failed: %w", err) } str.CancelRead(quic.StreamErrorCode(str.StreamID())) _, err := str.Read([]byte{0}) return err }, func(str *quic.SendStream) error { length := rand.IntN(len(PRData) - 1) if _, err := str.Write(PRData[:length]); err != nil { return fmt.Errorf("writing stream data failed: %w", err) } str.CancelWrite(quic.StreamErrorCode(str.StreamID())) _, 
err := str.Write([]byte{0}) return err }, ) }) } // If readFunc is set, the read side is canceled for 50% of the streams. // If writeFunc is set, the write side is canceled for 50% of the streams. func testStreamCancellation( t *testing.T, readFunc func(str *quic.ReceiveStream) error, writeFunc func(str *quic.SendStream) error, ) { const numStreams = 80 server, err := quic.Listen(newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer server.Close() ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(2*time.Second)) defer cancel() conn, err := quic.Dial( ctx, newUDPConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}), ) require.NoError(t, err) defer conn.CloseWithError(0, "") serverConn, err := server.Accept(ctx) require.NoError(t, err) type cancellationErr struct { StreamID quic.StreamID Err error } var numCancellations int actions := make([]bool, numStreams) for i := range actions { actions[i] = rand.IntN(2) == 0 if actions[i] { numCancellations++ } } // The server accepts a single connection, and then opens numStreams unidirectional streams. // On each of these streams, it (tries to) write PRData. 
serverErrChan := make(chan *cancellationErr, numStreams) go func() { for _, doCancel := range actions { str, err := serverConn.OpenUniStreamSync(ctx) if err != nil { serverErrChan <- &cancellationErr{StreamID: protocol.InvalidStreamID, Err: fmt.Errorf("opening stream failed: %w", err)} return } go func() { if writeFunc != nil && doCancel { if err := writeFunc(str); err != nil { serverErrChan <- &cancellationErr{StreamID: str.StreamID(), Err: err} return } serverErrChan <- nil return } defer str.Close() if _, err := str.Write(PRData); err != nil { serverErrChan <- &cancellationErr{StreamID: str.StreamID(), Err: err} return } serverErrChan <- nil }() } }() clientErrChan := make(chan *cancellationErr, numStreams) for _, doCancel := range actions { str, err := conn.AcceptUniStream(ctx) require.NoError(t, err) go func(str *quic.ReceiveStream) { if readFunc != nil && doCancel { if err := readFunc(str); err != nil { clientErrChan <- &cancellationErr{StreamID: str.StreamID(), Err: err} return } } data, err := io.ReadAll(str) if err != nil { clientErrChan <- &cancellationErr{StreamID: str.StreamID(), Err: fmt.Errorf("reading stream data failed: %w", err)} return } if !bytes.Equal(data, PRData) { clientErrChan <- &cancellationErr{StreamID: str.StreamID(), Err: fmt.Errorf("received data mismatch")} return } clientErrChan <- nil }(str) } timeout := time.After(time.Second) var clientErrs, serverErrs int for range numStreams { select { case err := <-serverErrChan: if err != nil { if err.StreamID == protocol.InvalidStreamID { // failed opening a stream require.NoError(t, err.Err) continue } var streamErr *quic.StreamError require.ErrorAs(t, err.Err, &streamErr) assert.Equal(t, streamErr.StreamID, err.StreamID) assert.Equal(t, streamErr.ErrorCode, quic.StreamErrorCode(err.StreamID)) if readFunc != nil && writeFunc == nil { assert.Equal(t, streamErr.Remote, readFunc != nil) } serverErrs++ } case <-timeout: t.Fatalf("timeout") } select { case err := <-clientErrChan: if err != nil { 
if err.StreamID == protocol.InvalidStreamID { // failed accepting a stream require.NoError(t, err.Err) continue } var streamErr *quic.StreamError require.ErrorAs(t, err.Err, &streamErr) assert.Equal(t, streamErr.StreamID, err.StreamID) assert.Equal(t, streamErr.ErrorCode, quic.StreamErrorCode(err.StreamID)) if readFunc != nil && writeFunc == nil { assert.Equal(t, streamErr.Remote, writeFunc != nil) } clientErrs++ } case <-timeout: t.Fatalf("timeout") } } assert.Equal(t, numCancellations, clientErrs, "client canceled streams") // The server will only count a stream as being reset if it learns about the cancellation // before it finished writing all data. assert.LessOrEqual(t, serverErrs, numCancellations, "server-observed canceled streams") assert.NotZero(t, serverErrs, "server-observed canceled streams") } func TestCancelAcceptStream(t *testing.T) { const numStreams = 30 server, err := quic.Listen(newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer server.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.Dial( ctx, newUDPConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 3}), ) require.NoError(t, err) defer conn.CloseWithError(0, "") serverConn, err := server.Accept(ctx) require.NoError(t, err) defer conn.CloseWithError(0, "") serverErrChan := make(chan error, 1) go func() { defer close(serverErrChan) ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(2*time.Second)) defer cancel() ticker := time.NewTicker(5 * time.Millisecond) defer ticker.Stop() for i := 0; i < numStreams; i++ { <-ticker.C str, err := serverConn.OpenUniStreamSync(ctx) if err != nil { serverErrChan <- err return } if _, err := str.Write(PRData); err != nil { serverErrChan <- err return } str.Close() } }() var numToAccept int var counter atomic.Int32 var wg sync.WaitGroup wg.Add(numStreams) for numToAccept < 
numStreams { ctx, cancel := context.WithCancel(context.Background()) // cancel accepting half of the streams if rand.Int()%2 == 0 { cancel() } else { numToAccept++ defer cancel() } go func() { str, err := conn.AcceptUniStream(ctx) if err != nil { if errors.Is(err, context.Canceled) { counter.Add(1) } return } go func() { data, err := io.ReadAll(str) if err != nil { t.Errorf("ReadAll failed: %v", err) return } if !bytes.Equal(data, PRData) { t.Errorf("received data mismatch") return } wg.Done() }() }() } wg.Wait() count := counter.Load() t.Logf("canceled AcceptStream %d times", count) require.Greater(t, count, int32(numStreams/4)) require.NoError(t, conn.CloseWithError(0, "")) require.NoError(t, server.Close()) require.NoError(t, <-serverErrChan) } func TestCancelOpenStreamSync(t *testing.T) { const ( numStreams = 16 maxIncomingStreams = 4 ) server, err := quic.Listen(newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer server.Close() conn, err := quic.Dial( context.Background(), newUDPConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIncomingUniStreams: maxIncomingStreams}), ) require.NoError(t, err) msg := make(chan struct{}, 1) serverErrChan := make(chan error, numStreams+1) var numCanceled int serverConn, err := server.Accept(context.Background()) require.NoError(t, err) go func() { defer close(msg) var numOpened int for numOpened < numStreams { ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(10*time.Millisecond)) defer cancel() str, err := serverConn.OpenUniStreamSync(ctx) if err != nil { if !errors.Is(err, context.DeadlineExceeded) { serverErrChan <- err return } numCanceled++ select { case msg <- struct{}{}: default: } continue } numOpened++ go func(str *quic.SendStream) { defer str.Close() if _, err := str.Write(PRData); err != nil { serverErrChan <- err } }(str) } }() clientErrChan := make(chan error, numStreams) for range numStreams { <-msg str, err := 
conn.AcceptUniStream(context.Background()) require.NoError(t, err) go func(str *quic.ReceiveStream) { data, err := io.ReadAll(str) if err != nil { clientErrChan <- err return } if !bytes.Equal(data, PRData) { clientErrChan <- fmt.Errorf("received data mismatch") return } clientErrChan <- nil }(str) } timeout := time.After(scaleDuration(2 * time.Second)) for range numStreams { select { case err := <-clientErrChan: require.NoError(t, err) case err := <-serverErrChan: require.NoError(t, err) case <-timeout: t.Fatalf("timeout") } } count := numCanceled t.Logf("Canceled OpenStreamSync %d times", count) require.GreaterOrEqual(t, count, numStreams-maxIncomingStreams) require.NoError(t, conn.CloseWithError(0, "")) require.NoError(t, server.Close()) } func TestHeavyStreamCancellation(t *testing.T) { const maxIncomingStreams = 500 server, err := quic.Listen( newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(&quic.Config{MaxIncomingStreams: maxIncomingStreams, MaxIdleTimeout: 10 * time.Second}), ) require.NoError(t, err) defer server.Close() var wg sync.WaitGroup wg.Add(2 * 4 * maxIncomingStreams) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.Dial(ctx, newUDPConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) serverConn, err := server.Accept(context.Background()) require.NoError(t, err) handleStream := func(str *quic.Stream) { str.SetDeadline(time.Now().Add(time.Second)) go func() { defer wg.Done() if rand.Int()%2 == 0 { io.ReadAll(str) } }() go func() { defer wg.Done() if rand.Int()%2 == 0 { str.Write([]byte("foobar")) if rand.Int()%2 == 0 { str.Close() } } }() go func() { defer wg.Done() // Make sure we at least send out *something* for the last stream, // otherwise the peer might never receive this anything for this stream. 
if rand.Int()%2 == 0 || str.StreamID() == 4*(maxIncomingStreams-1) { str.CancelWrite(1234) } }() go func() { defer wg.Done() if rand.Int()%2 == 0 { str.CancelRead(1234) } }() } serverErrChan := make(chan error, 1) go func() { defer close(serverErrChan) for { str, err := serverConn.AcceptStream(context.Background()) if err != nil { serverErrChan <- err return } handleStream(str) } }() for i := 0; i < maxIncomingStreams; i++ { str, err := conn.OpenStreamSync(context.Background()) require.NoError(t, err) handleStream(str) } // We don't expect to accept any stream here. // We're just making sure the connection stays open and there's no error. ctx, cancel = context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond)) defer cancel() _, err = conn.AcceptStream(ctx) require.ErrorIs(t, err, context.DeadlineExceeded) wg.Wait() require.NoError(t, conn.CloseWithError(0, "")) select { case err := <-serverErrChan: require.IsType(t, &quic.ApplicationError{}, err) case <-time.After(scaleDuration(time.Second)): t.Fatal("timeout waiting for server to stop") } } quic-go-0.59.0/integrationtests/self/close_test.go000066400000000000000000000140101513066070600222030ustar00rootroot00000000000000package self_test import ( "context" "crypto/tls" "net" "sync" "sync/atomic" "testing" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/testutils/simnet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestConnectionCloseRetransmission(t *testing.T) { synctest.Test(t, func(t *testing.T) { const rtt = 10 * time.Millisecond serverAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.2"), Port: 9002} var drop atomic.Bool var mx sync.Mutex var dropped [][]byte n := &simnet.Simnet{ Router: &droppingRouter{Drop: func(p simnet.Packet) bool { shouldDrop := drop.Load() && p.From.String() == serverAddr.String() if shouldDrop { mx.Lock() dropped = 
append(dropped, p.Data) mx.Unlock() } return shouldDrop }}, } settings := simnet.NodeBiDiLinkSettings{Latency: rtt / 2} clientConn := n.NewEndpoint(&net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001}, settings) serverConn := n.NewEndpoint(serverAddr, settings) require.NoError(t, n.Start()) defer n.Close() tr := &quic.Transport{Conn: serverConn} defer tr.Close() server, err := tr.Listen( getTLSConfig(), getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}), ) require.NoError(t, err) defer server.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.Dial(ctx, clientConn, server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) defer conn.CloseWithError(0, "") sconn, err := server.Accept(ctx) require.NoError(t, err) time.Sleep(rtt) drop.Store(true) sconn.CloseWithError(1337, "closing") // send 100 packets for range 100 { str, err := conn.OpenStream() require.NoError(t, err) _, err = str.Write([]byte("foobar")) require.NoError(t, err) // A closed connection will drop packets if a very short queue overflows. // Waiting for one nanosecond makes synctest process the packet before advancing // the synthetic clock. 
time.Sleep(time.Nanosecond) } time.Sleep(rtt) mx.Lock() defer mx.Unlock() // Expect retransmissions of the CONNECTION_CLOSE for the // 1st, 2nd, 4th, 8th, 16th, 32th, 64th packet: 7 in total (+1 for the original packet) require.Len(t, dropped, 8) // verify all retransmitted packets were identical for i := 1; i < len(dropped); i++ { require.Equal(t, dropped[0], dropped[i]) } }) } func TestDrainServerAcceptQueue(t *testing.T) { server, err := quic.Listen(newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer server.Close() dialer := &quic.Transport{Conn: newUDPConnLocalhost(t)} defer dialer.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() // fill up the accept queue conns := make([]*quic.Conn, 0, protocol.MaxAcceptQueueSize) for range protocol.MaxAcceptQueueSize { conn, err := dialer.Dial(ctx, server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) conns = append(conns, conn) } time.Sleep(scaleDuration(25 * time.Millisecond)) // wait for connections to be queued server.Close() for i := range protocol.MaxAcceptQueueSize { c, err := server.Accept(ctx) require.NoError(t, err) // make sure the connection is not closed require.NoError(t, context.Cause(conns[i].Context()), "client connection closed") require.NoError(t, context.Cause(c.Context()), "server connection closed") c.CloseWithError(0, "") } _, err = server.Accept(ctx) require.ErrorIs(t, err, quic.ErrServerClosed) } type brokenConn struct { net.PacketConn broken chan struct{} breakErr atomic.Pointer[error] } func newBrokenConn(conn net.PacketConn) *brokenConn { c := &brokenConn{ PacketConn: conn, broken: make(chan struct{}), } go func() { <-c.broken // make calls to ReadFrom return c.SetDeadline(time.Now()) }() return c } func (c *brokenConn) ReadFrom(b []byte) (int, net.Addr, error) { if err := c.breakErr.Load(); err != nil { return 0, nil, *err } n, addr, err := c.PacketConn.ReadFrom(b) if err != nil { select { 
case <-c.broken: err = *c.breakErr.Load() default: } } return n, addr, err } func (c *brokenConn) Break(e error) { c.breakErr.Store(&e) close(c.broken) } func TestTransportClose(t *testing.T) { t.Run("Close", func(t *testing.T) { conn := newUDPConnLocalhost(t) testTransportClose(t, conn, func() { conn.Close() }, nil) }) t.Run("connection error", func(t *testing.T) { t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true") bc := newBrokenConn(newUDPConnLocalhost(t)) testTransportClose(t, bc, func() { bc.Break(assert.AnError) }, assert.AnError) }) } func testTransportClose(t *testing.T, conn net.PacketConn, closeFn func(), expectedErr error) { server := newUDPConnLocalhost(t) tr := &quic.Transport{Conn: conn} errChan := make(chan error, 1) go func() { _, err := tr.Dial(context.Background(), server.LocalAddr(), &tls.Config{}, getQuicConfig(nil)) errChan <- err }() select { case <-errChan: t.Fatal("didn't expect Dial to return yet") case <-time.After(scaleDuration(10 * time.Millisecond)): } closeFn() select { case err := <-errChan: require.Error(t, err) require.ErrorIs(t, err, quic.ErrTransportClosed) if expectedErr != nil { require.ErrorIs(t, err, expectedErr) } case <-time.After(time.Second): t.Fatal("timeout") } // it's not possible to dial new connections ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond)) defer cancel() _, err := tr.Dial(ctx, server.LocalAddr(), &tls.Config{}, getQuicConfig(nil)) require.Error(t, err) require.ErrorIs(t, err, quic.ErrTransportClosed) if expectedErr != nil { require.ErrorIs(t, err, expectedErr) } // it's not possible to create new listeners _, err = tr.Listen(&tls.Config{}, nil) require.Error(t, err) require.ErrorIs(t, err, quic.ErrTransportClosed) if expectedErr != nil { require.ErrorIs(t, err, expectedErr) } } quic-go-0.59.0/integrationtests/self/conn_id_test.go000066400000000000000000000104221513066070600225120ustar00rootroot00000000000000package self_test import ( "context" 
"crypto/rand" "fmt" "io" mrand "math/rand/v2" "testing" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/qlogwriter" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type connIDGenerator struct { Length int } var _ quic.ConnectionIDGenerator = &connIDGenerator{} func (c *connIDGenerator) GenerateConnectionID() (quic.ConnectionID, error) { b := make([]byte, c.Length) if _, err := rand.Read(b); err != nil { return quic.ConnectionID{}, fmt.Errorf("generating conn ID failed: %w", err) } return protocol.ParseConnectionID(b), nil } func (c *connIDGenerator) ConnectionIDLen() int { return c.Length } func randomConnIDLen() int { return 2 + mrand.IntN(19) } func TestConnectionIDsZeroLength(t *testing.T) { testTransferWithConnectionIDs(t, randomConnIDLen(), 0, nil, nil) } func TestConnectionIDsRandomLengths(t *testing.T) { testTransferWithConnectionIDs(t, randomConnIDLen(), randomConnIDLen(), nil, nil) } func TestConnectionIDsCustomGenerator(t *testing.T) { testTransferWithConnectionIDs(t, 0, 0, &connIDGenerator{Length: randomConnIDLen()}, &connIDGenerator{Length: randomConnIDLen()}, ) } // connIDLen is ignored when connIDGenerator is set func testTransferWithConnectionIDs( t *testing.T, serverConnIDLen, clientConnIDLen int, serverConnIDGenerator, clientConnIDGenerator quic.ConnectionIDGenerator, ) { t.Helper() if serverConnIDGenerator != nil { t.Logf("using %d byte connection ID generator for the server", serverConnIDGenerator.ConnectionIDLen()) } else { t.Logf("issuing %d byte connection ID from the server", serverConnIDLen) } if clientConnIDGenerator != nil { t.Logf("using %d byte connection ID generator for the client", clientConnIDGenerator.ConnectionIDLen()) } else { t.Logf("issuing %d byte connection ID from the client", clientConnIDLen) } // setup server serverTr := &quic.Transport{ Conn: newUDPConnLocalhost(t), ConnectionIDLength: serverConnIDLen, ConnectionIDGenerator: 
serverConnIDGenerator, } defer serverTr.Close() addTracer(serverTr) serverCounter, serverTracer := newPacketTracer() ln, err := serverTr.Listen( getTLSConfig(), getQuicConfig(&quic.Config{ Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return serverTracer }, }), ) require.NoError(t, err) // setup client ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() var conn *quic.Conn clientCounter, clientTracer := newPacketTracer() clientQUICConf := getQuicConfig(&quic.Config{ Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return clientTracer }, }) if clientConnIDGenerator == nil && clientConnIDLen == 0 { conn, err = quic.Dial(ctx, newUDPConnLocalhost(t), ln.Addr(), getTLSClientConfig(), clientQUICConf) require.NoError(t, err) } else { clientTr := &quic.Transport{ Conn: newUDPConnLocalhost(t), ConnectionIDLength: clientConnIDLen, ConnectionIDGenerator: clientConnIDGenerator, } defer clientTr.Close() addTracer(clientTr) conn, err = clientTr.Dial(ctx, ln.Addr(), getTLSClientConfig(), clientQUICConf) require.NoError(t, err) } serverConn, err := ln.Accept(context.Background()) require.NoError(t, err) serverStr, err := serverConn.OpenStream() require.NoError(t, err) go func() { serverStr.Write(PRData) serverStr.Close() }() str, err := conn.AcceptStream(context.Background()) require.NoError(t, err) data, err := io.ReadAll(str) require.NoError(t, err) require.Equal(t, PRData, data) conn.CloseWithError(0, "") serverConn.CloseWithError(0, "") for _, p := range serverCounter.getRcvdShortHeaderPackets() { expectedLen := serverConnIDLen if serverConnIDGenerator != nil { expectedLen = serverConnIDGenerator.ConnectionIDLen() } if !assert.Equal(t, expectedLen, p.hdr.DestConnectionID.Len(), "server conn length mismatch") { break } } for _, p := range clientCounter.getRcvdShortHeaderPackets() { expectedLen := clientConnIDLen if clientConnIDGenerator != nil { expectedLen = 
clientConnIDGenerator.ConnectionIDLen() } if !assert.Equal(t, expectedLen, p.hdr.DestConnectionID.Len(), "client conn length mismatch") { break } } } quic-go-0.59.0/integrationtests/self/connection_migration_test.go000066400000000000000000000075431513066070600253230ustar00rootroot00000000000000package self_test import ( "bytes" "context" "errors" "fmt" "io" "net" "sync/atomic" "testing" "time" "github.com/quic-go/quic-go" quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" "github.com/stretchr/testify/require" ) func TestConnectionMigration(t *testing.T) { ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer ln.Close() tr1 := &quic.Transport{Conn: newUDPConnLocalhost(t)} defer tr1.Close() tr2 := &quic.Transport{Conn: newUDPConnLocalhost(t)} defer tr2.Close() var packetsPath1, packetsPath2 atomic.Int64 const rtt = 5 * time.Millisecond proxy := quicproxy.Proxy{ Conn: newUDPConnLocalhost(t), ServerAddr: ln.Addr().(*net.UDPAddr), DelayPacket: func(dir quicproxy.Direction, from, to net.Addr, _ []byte) time.Duration { var port int switch dir { case quicproxy.DirectionIncoming: port = from.(*net.UDPAddr).Port case quicproxy.DirectionOutgoing: port = to.(*net.UDPAddr).Port } switch port { case tr1.Conn.LocalAddr().(*net.UDPAddr).Port: packetsPath1.Add(1) case tr2.Conn.LocalAddr().(*net.UDPAddr).Port: packetsPath2.Add(1) default: fmt.Println("address not found", from) } return rtt / 2 }, } require.NoError(t, proxy.Start()) defer proxy.Close() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() conn, err := tr1.Dial(ctx, proxy.LocalAddr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) defer conn.CloseWithError(0, "") sconn, err := ln.Accept(ctx) require.NoError(t, err) defer sconn.CloseWithError(0, "") sendAndReceiveFile := func(t *testing.T) { t.Helper() str, err := conn.OpenUniStream() require.NoError(t, err) errChan := make(chan error, 1) go func() 
{ defer close(errChan) sstr, err := sconn.AcceptUniStream(ctx) if err != nil { errChan <- fmt.Errorf("accepting stream: %w", err) return } data, err := io.ReadAll(sstr) if err != nil { errChan <- fmt.Errorf("reading stream data: %w", err) return } if !bytes.Equal(data, PRData) { errChan <- errors.New("unexpected data") } }() _, err = str.Write(PRData) require.NoError(t, err) require.NoError(t, str.Close()) select { case err := <-errChan: require.NoError(t, err) case <-time.After(time.Second): t.Fatal("timed out waiting for data") } } sendAndReceiveFile(t) // stream 2 require.NotZero(t, packetsPath1.Load()) require.Zero(t, packetsPath2.Load()) // probing the path causes a few packets to be sent on path 2 path, err := conn.AddPath(tr2) require.NoError(t, err) require.ErrorIs(t, path.Switch(), quic.ErrPathNotValidated) require.NoError(t, path.Probe(ctx)) require.Less(t, int(packetsPath2.Load()), 5) // make sure that no more packets are sent on path 2 before switching to the path c2 := packetsPath2.Load() sendAndReceiveFile(t) // stream 6 require.Equal(t, packetsPath2.Load(), c2) time.Sleep(3 * rtt) // wait for ACKs // now switch and make sure that no packets are sent on path 1 require.NoError(t, path.Switch()) sendAndReceiveFile(t) // stream 10 c1 := packetsPath1.Load() require.Equal(t, c1, packetsPath1.Load()) require.Greater(t, packetsPath2.Load(), c2) require.Equal(t, tr2.Conn.LocalAddr(), conn.LocalAddr()) // switch back to the handshake path time.Sleep(3 * rtt) // wait for ACKs c1BeforeSwitch := packetsPath1.Load() c2BeforeSwitch := packetsPath2.Load() path2, err := conn.AddPath(tr1) require.NoError(t, err) require.NoError(t, path2.Probe(ctx)) time.Sleep(3 * rtt) // wait for ACKs require.NoError(t, path2.Switch()) sendAndReceiveFile(t) // stream 14 require.Greater(t, packetsPath1.Load(), c1BeforeSwitch) // some path probing might have happened require.Less(t, int(packetsPath2.Load()-c2BeforeSwitch), 20) require.Equal(t, tr1.Conn.LocalAddr(), conn.LocalAddr()) } 
quic-go-0.59.0/integrationtests/self/datagram_test.go000066400000000000000000000175621513066070600226750ustar00rootroot00000000000000package self_test import ( "bytes" "context" mrand "math/rand/v2" "net" "sync/atomic" "testing" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/testutils/simnet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestDatagramNegotiation(t *testing.T) { t.Run("server enable, client enable", func(t *testing.T) { testDatagramNegotiation(t, true, true) }) t.Run("server enable, client disable", func(t *testing.T) { testDatagramNegotiation(t, true, false) }) t.Run("server disable, client enable", func(t *testing.T) { testDatagramNegotiation(t, false, true) }) t.Run("server disable, client disable", func(t *testing.T) { testDatagramNegotiation(t, false, false) }) } func testDatagramNegotiation(t *testing.T, serverEnableDatagram, clientEnableDatagram bool) { server, err := quic.Listen( newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(&quic.Config{EnableDatagrams: serverEnableDatagram}), ) require.NoError(t, err) defer server.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() clientConn, err := quic.Dial( ctx, newUDPConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(&quic.Config{EnableDatagrams: clientEnableDatagram}), ) require.NoError(t, err) defer clientConn.CloseWithError(0, "") serverConn, err := server.Accept(ctx) require.NoError(t, err) defer serverConn.CloseWithError(0, "") serverState := serverConn.ConnectionState().SupportsDatagrams clientState := clientConn.ConnectionState().SupportsDatagrams require.Equal(t, serverEnableDatagram, serverState.Local, "server local datagram support") require.Equal(t, clientEnableDatagram, serverState.Remote, "server view of client datagram support") require.Equal(t, clientEnableDatagram, 
clientState.Local, "client local datagram support") require.Equal(t, serverEnableDatagram, clientState.Remote, "client view of server datagram support") if clientEnableDatagram { require.NoError(t, serverConn.SendDatagram([]byte("foo"))) datagram, err := clientConn.ReceiveDatagram(ctx) require.NoError(t, err) require.Equal(t, []byte("foo"), datagram) } else { require.Error(t, serverConn.SendDatagram([]byte("foo"))) } if serverEnableDatagram { require.NoError(t, clientConn.SendDatagram([]byte("bar"))) datagram, err := serverConn.ReceiveDatagram(ctx) require.NoError(t, err) require.Equal(t, []byte("bar"), datagram) } else { require.Error(t, clientConn.SendDatagram([]byte("bar"))) } } func TestDatagramSizeLimit(t *testing.T) { const maxDatagramSize = 456 originalMaxDatagramSize := wire.MaxDatagramSize wire.MaxDatagramSize = maxDatagramSize t.Cleanup(func() { wire.MaxDatagramSize = originalMaxDatagramSize }) server, err := quic.Listen( newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(&quic.Config{EnableDatagrams: true}), ) require.NoError(t, err) defer server.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() clientConn, err := quic.Dial( ctx, newUDPConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(&quic.Config{EnableDatagrams: true}), ) require.NoError(t, err) defer clientConn.CloseWithError(0, "") err = clientConn.SendDatagram(bytes.Repeat([]byte("a"), maxDatagramSize+100)) // definitely too large require.Error(t, err) var sizeErr *quic.DatagramTooLargeError require.ErrorAs(t, err, &sizeErr) require.InDelta(t, sizeErr.MaxDatagramPayloadSize, maxDatagramSize, 10) require.NoError(t, clientConn.SendDatagram(bytes.Repeat([]byte("b"), int(sizeErr.MaxDatagramPayloadSize)))) require.Error(t, clientConn.SendDatagram(bytes.Repeat([]byte("c"), int(sizeErr.MaxDatagramPayloadSize+1)))) serverConn, err := server.Accept(ctx) require.NoError(t, err) defer serverConn.CloseWithError(0, "") datagram, err := 
serverConn.ReceiveDatagram(ctx) require.NoError(t, err) require.Equal(t, bytes.Repeat([]byte("b"), int(sizeErr.MaxDatagramPayloadSize)), datagram) } func TestDatagramLoss(t *testing.T) { synctest.Test(t, func(t *testing.T) { const rtt = 100 * time.Millisecond const numDatagrams = 100 const datagramSize = 500 clientAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001} serverAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.2"), Port: 9002} var droppedToClient, droppedToServer, total atomic.Int32 n := &simnet.Simnet{ Router: &directionAwareDroppingRouter{ ClientAddr: clientAddr, ServerAddr: serverAddr, Drop: func(d direction, p simnet.Packet) bool { if wire.IsLongHeaderPacket(p.Data[0]) { // don't drop Long Header packets return false } if len(p.Data) < datagramSize { // don't drop ACK-only packets return false } total.Add(1) // drop about 20% of Short Header packets with DATAGRAM frames if mrand.Int()%5 == 0 { switch d { case directionToClient: droppedToClient.Add(1) case directionToServer: droppedToServer.Add(1) } return true } return false }, }, } settings := simnet.NodeBiDiLinkSettings{Latency: rtt / 2} clientPacketConn := n.NewEndpoint(clientAddr, settings) defer clientPacketConn.Close() serverPacketConn := n.NewEndpoint(serverAddr, settings) defer serverPacketConn.Close() require.NoError(t, n.Start()) defer n.Close() server, err := quic.Listen( serverPacketConn, getTLSConfig(), getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}), ) require.NoError(t, err) defer server.Close() const sendInterval = time.Second // send a datagram every second ctx, cancel := context.WithTimeout(context.Background(), (numDatagrams+10)*sendInterval) defer cancel() clientConn, err := quic.Dial( ctx, clientPacketConn, serverPacketConn.LocalAddr(), getTLSClientConfig(), getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}), ) require.NoError(t, err) defer clientConn.CloseWithError(0, "") serverConn, err := server.Accept(ctx) 
require.NoError(t, err) defer serverConn.CloseWithError(0, "") var clientDatagrams, serverDatagrams int clientErrChan := make(chan error, 1) go func() { defer close(clientErrChan) for { if _, err := clientConn.ReceiveDatagram(ctx); err != nil { clientErrChan <- err return } clientDatagrams++ } }() for i := range numDatagrams { payload := bytes.Repeat([]byte{uint8(i)}, datagramSize) require.NoError(t, clientConn.SendDatagram(payload)) require.NoError(t, serverConn.SendDatagram(payload)) time.Sleep(sendInterval) } serverErrChan := make(chan error, 1) go func() { defer close(serverErrChan) for { if _, err := serverConn.ReceiveDatagram(ctx); err != nil { serverErrChan <- err return } serverDatagrams++ } }() select { case err := <-clientErrChan: require.ErrorIs(t, err, context.DeadlineExceeded) case <-time.After(5 * numDatagrams * sendInterval): t.Fatal("timeout") } select { case err := <-serverErrChan: require.ErrorIs(t, err, context.DeadlineExceeded) case <-time.After(5 * numDatagrams * sendInterval): t.Fatal("timeout") } numDroppedToClient := droppedToClient.Load() numDroppedToServer := droppedToServer.Load() t.Logf("dropped %d to client and %d to server out of %d packets", numDroppedToClient, numDroppedToServer, total.Load()) assert.NotZero(t, numDroppedToClient) assert.NotZero(t, numDroppedToServer) t.Logf("server received %d out of %d sent datagrams", serverDatagrams, numDatagrams) assert.EqualValues(t, numDatagrams-numDroppedToServer, serverDatagrams, "datagrams received by the server") t.Logf("client received %d out of %d sent datagrams", clientDatagrams, numDatagrams) assert.EqualValues(t, numDatagrams-numDroppedToClient, clientDatagrams, "datagrams received by the client") }) } quic-go-0.59.0/integrationtests/self/deadline_test.go000066400000000000000000000132241513066070600226510ustar00rootroot00000000000000package self_test import ( "bytes" "context" "fmt" "io" "net" "testing" "time" "github.com/quic-go/quic-go" "github.com/stretchr/testify/require" ) func 
setupDeadlineTest(t *testing.T) (serverStr, clientStr *quic.Stream) { t.Helper() server, err := quic.Listen(newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) t.Cleanup(func() { server.Close() }) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.Dial(ctx, newUDPConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) t.Cleanup(func() { conn.CloseWithError(0, "") }) clientStr, err = conn.OpenStream() require.NoError(t, err) _, err = clientStr.Write([]byte{0}) // need to write one byte so the server learns about the stream require.NoError(t, err) serverConn, err := server.Accept(ctx) require.NoError(t, err) t.Cleanup(func() { serverConn.CloseWithError(0, "") }) serverStr, err = serverConn.AcceptStream(ctx) require.NoError(t, err) _, err = serverStr.Read([]byte{0}) require.NoError(t, err) return serverStr, clientStr } func TestReadDeadlineSync(t *testing.T) { serverStr, clientStr := setupDeadlineTest(t) const timeout = time.Millisecond errChan := make(chan error, 1) go func() { _, err := serverStr.Write(PRDataLong) errChan <- err }() var bytesRead int var timeoutCounter int buf := make([]byte, 1<<10) data := make([]byte, len(PRDataLong)) clientStr.SetReadDeadline(time.Now().Add(timeout)) for bytesRead < len(PRDataLong) { n, err := clientStr.Read(buf) if nerr, ok := err.(net.Error); ok && nerr.Timeout() { timeoutCounter++ clientStr.SetReadDeadline(time.Now().Add(timeout)) } else { require.NoError(t, err) } copy(data[bytesRead:], buf[:n]) bytesRead += n } require.Equal(t, PRDataLong, data) // make sure the test actually worked and Read actually ran into the deadline a few times t.Logf("ran into deadline %d times", timeoutCounter) require.GreaterOrEqual(t, timeoutCounter, 10) select { case err := <-errChan: require.NoError(t, err) case <-time.After(time.Second): t.Fatal("timeout") } } func TestReadDeadlineAsync(t *testing.T) { serverStr, 
clientStr := setupDeadlineTest(t) const timeout = time.Millisecond errChan := make(chan error, 1) go func() { _, err := serverStr.Write(PRDataLong) errChan <- err }() var bytesRead int var timeoutCounter int buf := make([]byte, 1<<10) data := make([]byte, len(PRDataLong)) received := make(chan struct{}) go func() { for { select { case <-received: return default: time.Sleep(timeout) } clientStr.SetReadDeadline(time.Now().Add(timeout)) } }() for bytesRead < len(PRDataLong) { n, err := clientStr.Read(buf) if nerr, ok := err.(net.Error); ok && nerr.Timeout() { timeoutCounter++ } else { require.NoError(t, err) } copy(data[bytesRead:], buf[:n]) bytesRead += n } require.Equal(t, PRDataLong, data) close(received) // make sure the test actually worked and Read actually ran into the deadline a few times t.Logf("ran into deadline %d times", timeoutCounter) require.GreaterOrEqual(t, timeoutCounter, 10) select { case err := <-errChan: require.NoError(t, err) case <-time.After(time.Second): t.Fatal("timeout") } } func TestWriteDeadlineSync(t *testing.T) { serverStr, clientStr := setupDeadlineTest(t) const timeout = time.Millisecond errChan := make(chan error, 1) go func() { defer close(errChan) data, err := io.ReadAll(serverStr) if err != nil { errChan <- err } if !bytes.Equal(PRDataLong, data) { errChan <- fmt.Errorf("data mismatch") } }() var bytesWritten int var timeoutCounter int clientStr.SetWriteDeadline(time.Now().Add(timeout)) for bytesWritten < len(PRDataLong) { n, err := clientStr.Write(PRDataLong[bytesWritten:]) if nerr, ok := err.(net.Error); ok && nerr.Timeout() { timeoutCounter++ clientStr.SetWriteDeadline(time.Now().Add(timeout)) } else { require.NoError(t, err) } bytesWritten += n } clientStr.Close() // make sure the test actually worked and Write actually ran into the deadline a few times t.Logf("ran into deadline %d times", timeoutCounter) require.GreaterOrEqual(t, timeoutCounter, 10) select { case err := <-errChan: require.NoError(t, err) case 
<-time.After(time.Second): t.Fatal("timeout") } } func TestWriteDeadlineAsync(t *testing.T) { serverStr, clientStr := setupDeadlineTest(t) const timeout = time.Millisecond errChan := make(chan error, 1) go func() { defer close(errChan) data, err := io.ReadAll(serverStr) if err != nil { errChan <- err } if !bytes.Equal(PRDataLong, data) { errChan <- fmt.Errorf("data mismatch") } }() clientStr.SetWriteDeadline(time.Now().Add(timeout)) readDone := make(chan struct{}) deadlineDone := make(chan struct{}) go func() { defer close(deadlineDone) for { select { case <-readDone: return default: time.Sleep(timeout) } clientStr.SetWriteDeadline(time.Now().Add(timeout)) } }() var bytesWritten int var timeoutCounter int clientStr.SetWriteDeadline(time.Now().Add(timeout)) for bytesWritten < len(PRDataLong) { n, err := clientStr.Write(PRDataLong[bytesWritten:]) if nerr, ok := err.(net.Error); ok && nerr.Timeout() { timeoutCounter++ } else { require.NoError(t, err) } bytesWritten += n } clientStr.Close() close(readDone) // make sure the test actually worked and Write actually ran into the deadline a few times t.Logf("ran into deadline %d times", timeoutCounter) require.GreaterOrEqual(t, timeoutCounter, 10) select { case err := <-errChan: require.NoError(t, err) case <-time.After(time.Second): t.Fatal("timeout") } } quic-go-0.59.0/integrationtests/self/drop_test.go000066400000000000000000000065731513066070600220610ustar00rootroot00000000000000package self_test import ( "context" "fmt" "net" "sync/atomic" "testing" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/testutils/simnet" "github.com/stretchr/testify/require" ) func TestPacketDrops(t *testing.T) { for _, direction := range []protocol.Perspective{protocol.PerspectiveClient, protocol.PerspectiveServer} { t.Run(fmt.Sprintf("from %s", direction), func(t *testing.T) { 
testPacketDrops(t, direction) }) } } func testPacketDrops(t *testing.T, direction protocol.Perspective) { synctest.Test(t, func(t *testing.T) { const numMessages = 50 const rtt = 10 * time.Millisecond addrClient := &net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001} addrServer := &net.UDPAddr{IP: net.ParseIP("1.0.0.2"), Port: 9002} var numDroppedPackets atomic.Int32 messageInterval := randomDuration(10*time.Millisecond, 100*time.Millisecond) dropDuration := randomDuration(messageInterval*3, 2*time.Second) dropDelay := randomDuration(25*time.Millisecond, numMessages*messageInterval/2) startTime := time.Now() n := &simnet.Simnet{ Router: &droppingRouter{ Drop: func(p simnet.Packet) bool { switch p.To { case addrClient: if direction == protocol.PerspectiveClient { return false } case addrServer: if direction == protocol.PerspectiveServer { return false } } if wire.IsLongHeaderPacket(p.Data[0]) { // don't interfere with the handshake return false } drop := time.Now().After(startTime.Add(dropDelay)) && time.Now().Before(startTime.Add(dropDelay).Add(dropDuration)) if drop { numDroppedPackets.Add(1) } return drop }, }, } settings := simnet.NodeBiDiLinkSettings{Latency: rtt / 2} clientPacketConn := n.NewEndpoint(addrClient, settings) defer clientPacketConn.Close() serverPacketConn := n.NewEndpoint(addrServer, settings) defer serverPacketConn.Close() require.NoError(t, n.Start()) defer n.Close() t.Logf("sending a message every %s, %d times", messageInterval, numMessages) t.Logf("dropping packets for %s, after a delay of %s", dropDuration, dropDelay) ln, err := quic.Listen(serverPacketConn, getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer ln.Close() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() conn, err := quic.Dial(ctx, clientPacketConn, ln.Addr().(*net.UDPAddr), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) defer conn.CloseWithError(0, "") serverConn, err := ln.Accept(ctx) require.NoError(t, 
err) defer serverConn.CloseWithError(0, "") serverStr, err := serverConn.OpenUniStream() require.NoError(t, err) errChan := make(chan error, 1) go func() { for i := range numMessages { time.Sleep(messageInterval) if _, err := serverStr.Write([]byte{uint8(i + 1)}); err != nil { errChan <- err return } } }() str, err := conn.AcceptUniStream(ctx) require.NoError(t, err) for i := range numMessages { b := []byte{0} n, err := str.Read(b) require.NoError(t, err) require.Equal(t, 1, n) require.Equal(t, byte(i+1), b[0]) } numDropped := numDroppedPackets.Load() t.Logf("dropped %d packets", numDropped) require.NotZero(t, numDropped) }) } quic-go-0.59.0/integrationtests/self/early_data_test.go000066400000000000000000000035151513066070600232130ustar00rootroot00000000000000package self_test import ( "context" "io" "net" "testing" "time" "github.com/quic-go/quic-go" quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" "github.com/stretchr/testify/require" ) func TestEarlyData(t *testing.T) { const rtt = 80 * time.Millisecond ln, err := quic.ListenEarly(newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer ln.Close() proxy := &quicproxy.Proxy{ Conn: newUDPConnLocalhost(t), ServerAddr: ln.Addr().(*net.UDPAddr), DelayPacket: func(quicproxy.Direction, net.Addr, net.Addr, []byte) time.Duration { return rtt / 2 }, } require.NoError(t, proxy.Start()) defer proxy.Close() connChan := make(chan *quic.Conn) errChan := make(chan error) go func() { conn, err := ln.Accept(context.Background()) if err != nil { errChan <- err return } connChan <- conn }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() clientConn, err := quic.Dial(ctx, newUDPConnLocalhost(t), proxy.LocalAddr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) var serverConn *quic.Conn select { case serverConn = <-connChan: case err := <-errChan: t.Fatalf("error accepting connection: %s", err) } str, err := 
serverConn.OpenUniStream() require.NoError(t, err) _, err = str.Write([]byte("early data")) require.NoError(t, err) require.NoError(t, str.Close()) // the write should have completed before the handshake select { case <-serverConn.HandshakeComplete(): t.Fatal("handshake shouldn't be completed yet") default: } clientStr, err := clientConn.AcceptUniStream(context.Background()) require.NoError(t, err) data, err := io.ReadAll(clientStr) require.NoError(t, err) require.Equal(t, []byte("early data"), data) clientConn.CloseWithError(0, "") <-serverConn.Context().Done() } quic-go-0.59.0/integrationtests/self/handshake_context_test.go000066400000000000000000000200731513066070600245760ustar00rootroot00000000000000package self_test import ( "context" "crypto/tls" "errors" "net" "testing" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/qlogwriter" "github.com/stretchr/testify/require" ) func TestHandshakeContextTimeout(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(20*time.Millisecond)) defer cancel() conn := newUDPConnLocalhost(t) errChan := make(chan error, 1) go func() { _, err := quic.Dial(ctx, newUDPConnLocalhost(t), conn.LocalAddr(), getTLSClientConfig(), getQuicConfig(nil)) errChan <- err }() require.ErrorIs(t, <-errChan, context.DeadlineExceeded) } func TestHandshakeCancellationError(t *testing.T) { ctx, cancel := context.WithCancelCause(context.Background()) errChan := make(chan error, 1) conn := newUDPConnLocalhost(t) go func() { _, err := quic.Dial(ctx, newUDPConnLocalhost(t), conn.LocalAddr(), getTLSClientConfig(), getQuicConfig(nil)) errChan <- err }() cancel(errors.New("application cancelled")) require.EqualError(t, <-errChan, "application cancelled") } func TestConnContextOnServerSide(t *testing.T) { tlsGetConfigForClientContextChan := make(chan context.Context, 1) tlsGetCertificateContextChan := make(chan context.Context, 1) tracerContextChan := make(chan context.Context, 1) connContextChan := make(chan 
context.Context, 1) streamContextChan := make(chan context.Context, 1) tr := &quic.Transport{ Conn: newUDPConnLocalhost(t), ConnContext: func(ctx context.Context, _ *quic.ClientInfo) (context.Context, error) { return context.WithValue(ctx, "foo", "bar"), nil }, } defer tr.Close() server, err := tr.Listen( &tls.Config{ GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { tlsGetConfigForClientContextChan <- info.Context() tlsConf := getTLSConfig() tlsConf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { tlsGetCertificateContextChan <- info.Context() return &tlsConf.Certificates[0], nil } return tlsConf, nil }, }, getQuicConfig(&quic.Config{ Tracer: func(ctx context.Context, _ bool, _ quic.ConnectionID) qlogwriter.Trace { tracerContextChan <- ctx return nil }, }), ) require.NoError(t, err) defer server.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() c, err := quic.Dial(ctx, newUDPConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) serverConn, err := server.Accept(ctx) require.NoError(t, err) connContextChan <- serverConn.Context() str, err := serverConn.OpenUniStream() require.NoError(t, err) streamContextChan <- str.Context() str.Write([]byte{1, 2, 3}) _, err = c.AcceptUniStream(ctx) require.NoError(t, err) c.CloseWithError(1337, "bye") checkContext := func(c <-chan context.Context, checkCancellationCause bool) { t.Helper() var ctx context.Context select { case ctx = <-c: case <-time.After(time.Second): t.Fatal("timeout waiting for context") } val := ctx.Value("foo") require.NotNil(t, val) v := val.(string) require.Equal(t, "bar", v) select { case <-ctx.Done(): case <-time.After(time.Second): t.Fatal("timeout waiting for context to be done") } if !checkCancellationCause { return } ctxErr := context.Cause(ctx) var appErr *quic.ApplicationError require.ErrorAs(t, ctxErr, &appErr) require.Equal(t, quic.ApplicationErrorCode(1337), 
appErr.ErrorCode) } checkContext(connContextChan, true) checkContext(tracerContextChan, true) checkContext(streamContextChan, true) // crypto/tls cancels the context when the TLS handshake completes. checkContext(tlsGetConfigForClientContextChan, false) checkContext(tlsGetCertificateContextChan, false) } func TestConnContextRejection(t *testing.T) { t.Run("rejecting", func(t *testing.T) { testConnContextRejection(t, true) }) t.Run("not rejecting", func(t *testing.T) { testConnContextRejection(t, false) }) } func testConnContextRejection(t *testing.T, reject bool) { tr := &quic.Transport{ Conn: newUDPConnLocalhost(t), ConnContext: func(ctx context.Context, ci *quic.ClientInfo) (context.Context, error) { if reject { return nil, errors.New("rejecting connection") } return context.WithValue(ctx, "addr", ci.RemoteAddr), nil }, } defer tr.Close() server, err := tr.Listen( getTLSConfig(), getQuicConfig(nil), ) require.NoError(t, err) defer server.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() pc := newUDPConnLocalhost(t) c, err := quic.Dial(ctx, pc, server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) if reject { require.ErrorIs(t, err, &quic.TransportError{Remote: true, ErrorCode: quic.ConnectionRefused}) return } require.NoError(t, err) defer c.CloseWithError(0, "") conn, err := server.Accept(ctx) require.NoError(t, err) require.Equal(t, pc.LocalAddr().String(), conn.Context().Value("addr").(net.Addr).String()) conn.CloseWithError(0, "") } // Users are not supposed to return a fresh context from ConnContext, but we should handle it gracefully. 
func TestConnContextFreshContext(t *testing.T) { tr := &quic.Transport{ Conn: newUDPConnLocalhost(t), ConnContext: func(ctx context.Context, _ *quic.ClientInfo) (context.Context, error) { return context.Background(), nil }, } defer tr.Close() server, err := tr.Listen(getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer server.Close() errChan := make(chan error, 1) go func() { conn, err := server.Accept(context.Background()) if err != nil { errChan <- err return } conn.CloseWithError(1337, "bye") }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() c, err := quic.Dial(ctx, newUDPConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) select { case <-c.Context().Done(): case err := <-errChan: t.Fatalf("accept failed: %v", err) case <-time.After(time.Second): t.Fatal("timeout") } } func TestContextOnClientSide(t *testing.T) { tlsServerConf := getTLSConfig() tlsServerConf.ClientAuth = tls.RequestClientCert server, err := quic.Listen(newUDPConnLocalhost(t), tlsServerConf, getQuicConfig(nil)) require.NoError(t, err) defer server.Close() tlsContextChan := make(chan context.Context, 1) tracerContextChan := make(chan context.Context, 1) tlsConf := getTLSClientConfig() tlsConf.GetClientCertificate = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { tlsContextChan <- info.Context() return &tlsServerConf.Certificates[0], nil } ctx, cancel := context.WithCancel(context.WithValue(context.Background(), "foo", "bar")) conn, err := quic.Dial( ctx, newUDPConnLocalhost(t), server.Addr(), tlsConf, getQuicConfig(&quic.Config{ Tracer: func(ctx context.Context, _ bool, _ quic.ConnectionID) qlogwriter.Trace { tracerContextChan <- ctx return nil }, }), ) require.NoError(t, err) cancel() // Make sure the connection context is not cancelled (even though derived from the ctx passed to Dial) select { case <-conn.Context().Done(): t.Fatal("context should not be cancelled") default: } 
checkContext := func(ctx context.Context, checkCancellationCause bool) { t.Helper() val := ctx.Value("foo") require.NotNil(t, val) require.Equal(t, "bar", val.(string)) if !checkCancellationCause { return } ctxErr := context.Cause(ctx) var appErr *quic.ApplicationError require.ErrorAs(t, ctxErr, &appErr) require.EqualValues(t, 1337, appErr.ErrorCode) } checkContextFromChan := func(c <-chan context.Context, checkCancellationCause bool) { t.Helper() var ctx context.Context select { case ctx = <-c: case <-time.After(time.Second): t.Fatal("timeout waiting for context") } checkContext(ctx, checkCancellationCause) } str, err := conn.OpenUniStream() require.NoError(t, err) conn.CloseWithError(1337, "bye") checkContext(conn.Context(), true) checkContext(str.Context(), true) // crypto/tls cancels the context when the TLS handshake completes checkContextFromChan(tlsContextChan, false) checkContextFromChan(tracerContextChan, false) } quic-go-0.59.0/integrationtests/self/handshake_drop_test.go000066400000000000000000000270251513066070600240620ustar00rootroot00000000000000package self_test import ( "bytes" "context" "crypto/tls" "fmt" "io" mrand "math/rand/v2" "net" "runtime" "slices" "strings" "sync" "sync/atomic" "testing" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/testutils/events" "github.com/quic-go/quic-go/testutils/simnet" "github.com/stretchr/testify/require" ) func dropTestProtocolClientSpeaksFirst(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, data []byte) *quic.Conn { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() conn, err := quic.Dial( ctx, clientConn, ln.Addr(), clientConf, getQuicConfig(&quic.Config{ MaxIdleTimeout: timeout, HandshakeIdleTimeout: timeout, DisablePathMTUDiscovery: true, }), ) require.NoError(t, err) defer 
conn.CloseWithError(0, "") str, err := conn.OpenUniStream() require.NoError(t, err) errChan := make(chan error, 1) go func() { defer str.Close() _, err := str.Write(data) errChan <- err }() serverConn, err := ln.Accept(ctx) require.NoError(t, err) serverStr, err := serverConn.AcceptUniStream(ctx) require.NoError(t, err) b, err := io.ReadAll(&readerWithTimeout{Reader: serverStr, Timeout: timeout}) require.NoError(t, err) require.Equal(t, b, data) serverConn.CloseWithError(0, "") return conn } func dropTestProtocolServerSpeaksFirst(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, data []byte) *quic.Conn { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() conn, err := quic.Dial( ctx, clientConn, ln.Addr(), clientConf, getQuicConfig(&quic.Config{ MaxIdleTimeout: timeout, HandshakeIdleTimeout: timeout, DisablePathMTUDiscovery: true, }), ) require.NoError(t, err) errChan := make(chan error, 1) go func() { defer close(errChan) defer conn.CloseWithError(0, "") str, err := conn.AcceptUniStream(ctx) if err != nil { errChan <- err return } b, err := io.ReadAll(&readerWithTimeout{Reader: str, Timeout: timeout}) if err != nil { errChan <- err return } if !bytes.Equal(b, data) { errChan <- fmt.Errorf("data mismatch: %x != %x", b, data) return } }() serverConn, err := ln.Accept(ctx) require.NoError(t, err) serverStr, err := serverConn.OpenUniStream() require.NoError(t, err) _, err = serverStr.Write(data) require.NoError(t, err) require.NoError(t, serverStr.Close()) select { case err := <-errChan: require.NoError(t, err) case <-time.After(timeout): t.Fatal("server connection not closed") } select { case <-conn.Context().Done(): case <-time.After(timeout): t.Fatal("server connection not closed") } return conn } func dropTestProtocolNobodySpeaks(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, _ []byte) *quic.Conn { ctx, cancel := 
context.WithTimeout(context.Background(), timeout) defer cancel() conn, err := quic.Dial( ctx, clientConn, ln.Addr(), clientConf, getQuicConfig(&quic.Config{ MaxIdleTimeout: timeout, HandshakeIdleTimeout: timeout, DisablePathMTUDiscovery: true, }), ) require.NoError(t, err) defer conn.CloseWithError(0, "") serverConn, err := ln.Accept(ctx) require.NoError(t, err) serverConn.CloseWithError(0, "") return conn } func dropCallbackDropNthPacket(dir direction, ns ...int) func(direction, simnet.Packet) bool { var toClient, toServer atomic.Int32 return func(d direction, p simnet.Packet) bool { switch d { case directionToClient: c := toClient.Add(1) if d == dir || dir == directionBoth { return slices.Contains(ns, int(c)) } case directionToServer: c := toServer.Add(1) if dir == d || dir == directionBoth { return slices.Contains(ns, int(c)) } } return false } } func dropCallbackDropOneThird(_ direction) func(direction, simnet.Packet) bool { const maxSequentiallyDropped = 10 var mx sync.Mutex var toClient, toServer int return func(d direction, p simnet.Packet) bool { drop := mrand.IntN(3) == 0 mx.Lock() defer mx.Unlock() // never drop more than 10 consecutive packets if d == directionToClient || d == directionBoth { if drop { toClient++ if toClient > maxSequentiallyDropped { drop = false } } if !drop { toClient = 0 } } if d == directionToServer || d == directionBoth { if drop { toServer++ if toServer > maxSequentiallyDropped { drop = false } } if !drop { toServer = 0 } } return drop } } func TestHandshakeWithPacketLoss(t *testing.T) { data := GeneratePRData(5000) const timeout = 2 * time.Minute const rtt = 20 * time.Millisecond type dropPattern string const ( dropPatternDrop1stPacket dropPattern = "drop 1st packet" dropPatternDropFirst3Packets dropPattern = "drop first 3 packets" dropPatternDropOneThirdOfPackets dropPattern = "drop 1/3 of packets" ) type testConfig struct { postQuantum bool longCertChain bool doRetry bool } for _, dir := range []direction{directionToClient, 
directionToServer, directionBoth} { for _, pattern := range []dropPattern{ dropPatternDrop1stPacket, dropPatternDropFirst3Packets, dropPatternDropOneThirdOfPackets, } { t.Run(fmt.Sprintf("%s in direction %s", pattern, dir), func(t *testing.T) { for _, conf := range []testConfig{ {postQuantum: false, longCertChain: false, doRetry: true}, {postQuantum: false, longCertChain: false, doRetry: false}, {postQuantum: false, longCertChain: true, doRetry: false}, {postQuantum: true, longCertChain: false, doRetry: false}, {postQuantum: true, longCertChain: true, doRetry: false}, } { for _, test := range []struct { name string fn func(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, data []byte) *quic.Conn }{ {"client speaks first", dropTestProtocolClientSpeaksFirst}, {"server speaks first", dropTestProtocolServerSpeaksFirst}, {"nobody speaks", dropTestProtocolNobodySpeaks}, } { t.Run(fmt.Sprintf("retry: %t/%s", conf.doRetry, test.name), func(t *testing.T) { synctest.Test(t, func(t *testing.T) { clientAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001} serverAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.2"), Port: 9002} var fn func(direction, simnet.Packet) bool switch pattern { case dropPatternDrop1stPacket: fn = dropCallbackDropNthPacket(dir, 1) case dropPatternDropFirst3Packets: fn = dropCallbackDropNthPacket(dir, 1, 2, 3) case dropPatternDropOneThirdOfPackets: fn = dropCallbackDropOneThird(dir) } var numDropped atomic.Int32 n := &simnet.Simnet{ Router: &directionAwareDroppingRouter{ ClientAddr: clientAddr, ServerAddr: serverAddr, Drop: func(d direction, p simnet.Packet) bool { drop := fn(d, p) if drop { numDropped.Add(1) } return drop }, }, } settings := simnet.NodeBiDiLinkSettings{Latency: rtt / 2} clientConn := n.NewEndpoint(clientAddr, settings) defer clientConn.Close() serverConn := n.NewEndpoint(serverAddr, settings) defer serverConn.Close() require.NoError(t, n.Start()) defer n.Close() var tlsConf 
*tls.Config if conf.longCertChain { tlsConf = getTLSConfigWithLongCertChain() } else { tlsConf = getTLSConfig() } clientConf := getTLSClientConfig() if !conf.postQuantum { clientConf.CurvePreferences = []tls.CurveID{tls.CurveP384} } tr := &quic.Transport{ Conn: serverConn, VerifySourceAddress: func(net.Addr) bool { return conf.doRetry }, } defer tr.Close() ln, err := tr.Listen( tlsConf, getQuicConfig(&quic.Config{ MaxIdleTimeout: timeout, HandshakeIdleTimeout: timeout, DisablePathMTUDiscovery: true, }), ) require.NoError(t, err) defer ln.Close() conn := test.fn(t, ln, clientConn, clientConf, timeout, data) if !strings.HasPrefix(runtime.Version(), "go1.24") { curveID := getCurveID(conn.ConnectionState().TLS) if conf.postQuantum { require.Equal(t, tls.X25519MLKEM768, curveID) } else { require.Equal(t, tls.CurveP384, curveID) } } if pattern != dropPatternDropOneThirdOfPackets { require.NotZero(t, numDropped.Load()) } t.Logf("dropped %d packets", numDropped.Load()) }) }) } } }) } } } func TestHandshakePacketBuffering(t *testing.T) { synctest.Test(t, func(t *testing.T) { const rtt = 20 * time.Millisecond clientAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001} serverAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.2"), Port: 9002} var droppedFirst atomic.Bool n := &simnet.Simnet{ Router: &directionAwareDroppingRouter{ ClientAddr: clientAddr, ServerAddr: serverAddr, Drop: func(d direction, p simnet.Packet) bool { if droppedFirst.Load() { return false } if d == directionToClient && containsPacketType(p.Data, protocol.PacketTypeInitial) { droppedFirst.Store(true) return true } return false }, }, } settings := simnet.NodeBiDiLinkSettings{Latency: rtt / 2} clientConn := n.NewEndpoint(clientAddr, settings) defer clientConn.Close() serverConn := n.NewEndpoint(serverAddr, settings) defer serverConn.Close() require.NoError(t, n.Start()) defer n.Close() var serverEventRecorder events.Recorder ln, err := quic.Listen( serverConn, getTLSConfig(), 
getQuicConfig(&quic.Config{Tracer: newTracer(&serverEventRecorder)}), ) require.NoError(t, err) defer ln.Close() var clientEventRecorder events.Recorder conn, err := quic.Dial( context.Background(), clientConn, ln.Addr(), getTLSClientConfig(), getQuicConfig(&quic.Config{Tracer: newTracer(&clientEventRecorder)}), ) require.NoError(t, err) defer conn.CloseWithError(0, "") str, err := conn.OpenUniStream() require.NoError(t, err) data := []byte("foobar") _, err = str.Write(data) require.NoError(t, err) require.NoError(t, str.Close()) require.Empty(t, serverEventRecorder.Events(qlog.PacketBuffered{})) buffered := clientEventRecorder.Events(qlog.PacketBuffered{}) t.Logf("buffered packets: %d", len(buffered)) require.NotEmpty(t, buffered) receivedPackets := make(map[qlog.DatagramID][]qlog.PacketType) for _, ev := range clientEventRecorder.Events(qlog.PacketReceived{}) { id := ev.(qlog.PacketReceived).DatagramID receivedPackets[id] = append(receivedPackets[id], ev.(qlog.PacketReceived).Header.PacketType) } for _, ev := range buffered { id := ev.(qlog.PacketBuffered).DatagramID require.Contains(t, receivedPackets, id) require.Contains(t, receivedPackets[id], qlog.PacketTypeHandshake) } sconn, err := ln.Accept(context.Background()) require.NoError(t, err) defer sconn.CloseWithError(0, "") sstr, err := sconn.AcceptUniStream(context.Background()) require.NoError(t, err) b, err := io.ReadAll(sstr) require.NoError(t, err) require.Equal(t, data, b) require.Equal(t, rtt, sconn.ConnectionStats().SmoothedRTT) }) } quic-go-0.59.0/integrationtests/self/handshake_rtt_test.go000066400000000000000000000117751513066070600237340ustar00rootroot00000000000000package self_test import ( "context" "crypto/tls" "io" "net" "testing" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/synctest" "github.com/stretchr/testify/require" ) func TestHandshakeRTTRetry(t *testing.T) { t.Run("retry", func(t *testing.T) { testHandshakeRTTRetry(t, true) }) t.Run("no retry", func(t 
*testing.T) { testHandshakeRTTRetry(t, false) }) } func testHandshakeRTTRetry(t *testing.T, doRetry bool) { var addrVerified bool rtts := testHandshakeMeasureHandshake(t, func(net.Addr) bool { return doRetry }, getTLSConfig(), getQuicConfig(&quic.Config{ GetConfigForClient: func(info *quic.ClientInfo) (*quic.Config, error) { addrVerified = info.AddrVerified return nil, nil }, }), ) if doRetry { require.True(t, addrVerified, "should have verified address") require.GreaterOrEqual(t, rtts, float64(2)) require.Less(t, rtts, float64(2.1)) } else { require.False(t, addrVerified, "should not have verified address") require.GreaterOrEqual(t, rtts, float64(1)) require.Less(t, rtts, float64(1.1)) } } func TestHandshakeRTTHelloRetryRequest(t *testing.T) { tlsConf := getTLSConfig() tlsConf.CurvePreferences = []tls.CurveID{tls.CurveP384} rtts := testHandshakeMeasureHandshake(t, nil, tlsConf, getQuicConfig(nil)) require.GreaterOrEqual(t, rtts, float64(2)) require.Less(t, rtts, float64(2.1)) } func testHandshakeMeasureHandshake(t *testing.T, verifySourceAddress func(net.Addr) bool, tlsConf *tls.Config, quicConf *quic.Config) float64 { var rtts float64 synctest.Test(t, func(t *testing.T) { const rtt = 100 * time.Millisecond clientPacketConn, serverPacketConn, close := newSimnetLink(t, rtt) defer close(t) tr := &quic.Transport{ Conn: serverPacketConn, VerifySourceAddress: verifySourceAddress, } addTracer(tr) defer tr.Close() ln, err := tr.Listen(tlsConf, quicConf) require.NoError(t, err) defer ln.Close() clientConfig := getQuicConfig(nil) start := time.Now() ctx, cancel := context.WithTimeout(context.Background(), 10*rtt) defer cancel() conn, err := quic.Dial( ctx, clientPacketConn, serverPacketConn.LocalAddr(), getTLSClientConfig(), clientConfig, ) require.NoError(t, err) defer conn.CloseWithError(0, "") rtts = time.Since(start).Seconds() / rtt.Seconds() }) return rtts } func TestHandshake05RTT(t *testing.T) { t.Run("using ListenEarly", func(t *testing.T) { testHandshake05RTT(t, 
true) }) t.Run("using Listen", func(t *testing.T) { testHandshake05RTT(t, false) }) } func testHandshake05RTT(t *testing.T, use05RTT bool) { synctest.Test(t, func(t *testing.T) { type accepter interface { Accept(context.Context) (*quic.Conn, error) } const rtt = 100 * time.Millisecond clientPacketConn, serverPacketConn, close := newSimnetLink(t, rtt) defer close(t) var ln accepter if use05RTT { var err error server, err := quic.ListenEarly(serverPacketConn, getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer server.Close() ln = server } else { var err error server, err := quic.Listen(serverPacketConn, getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer server.Close() ln = server } connChan := make(chan *quic.Conn, 1) errChan := make(chan error, 1) go func() { conn, err := ln.Accept(context.Background()) if err != nil { errChan <- err return } str, err := conn.OpenUniStream() if err != nil { errChan <- err return } if _, err := str.Write([]byte("foobar")); err != nil { errChan <- err return } if err := str.Close(); err != nil { errChan <- err return } connChan <- conn }() start := time.Now() ctx, cancel := context.WithTimeout(context.Background(), 10*rtt) defer cancel() conn, err := quic.Dial( ctx, clientPacketConn, serverPacketConn.LocalAddr(), getTLSClientConfig(), getQuicConfig(nil), ) require.NoError(t, err) defer conn.CloseWithError(0, "") rtts := time.Since(start).Seconds() / rtt.Seconds() require.GreaterOrEqual(t, rtts, float64(1)) require.Less(t, rtts, float64(1.1)) start = time.Now() select { case err := <-errChan: t.Fatal("failed to accept connection:", err) case conn := <-connChan: if !use05RTT { // the server finishes the handshake 0.5 RTTs later rtts = time.Since(start).Seconds() / rtt.Seconds() require.GreaterOrEqual(t, rtts, float64(0.5)) require.Less(t, rtts, float64(0.6)) } defer conn.CloseWithError(0, "") } // If 0.5 RTT was used, the message should be received immediately, // otherwise it should take 1 RTT. 
str, err := conn.AcceptUniStream(ctx) require.NoError(t, err) data, err := io.ReadAll(str) require.NoError(t, err) require.Equal(t, []byte("foobar"), data) rtts = time.Since(start).Seconds() / rtt.Seconds() if use05RTT { require.GreaterOrEqual(t, rtts, float64(0)) require.Less(t, rtts, float64(0.1)) } else { require.GreaterOrEqual(t, rtts, float64(1)) require.Less(t, rtts, float64(1.1)) } }) } quic-go-0.59.0/integrationtests/self/handshake_test.go000066400000000000000000000616171513066070600230430ustar00rootroot00000000000000package self_test import ( "context" "crypto/tls" "errors" "fmt" "io" "net" "runtime" "strings" "sync/atomic" "testing" "time" "github.com/quic-go/quic-go" quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/qtls" "github.com/stretchr/testify/require" ) type tokenStore struct { store quic.TokenStore gets chan<- string puts chan<- string } var _ quic.TokenStore = &tokenStore{} func newTokenStore(gets, puts chan<- string) quic.TokenStore { return &tokenStore{ store: quic.NewLRUTokenStore(10, 4), gets: gets, puts: puts, } } func (c *tokenStore) Put(key string, token *quic.ClientToken) { c.puts <- key c.store.Put(key, token) } func (c *tokenStore) Pop(key string) *quic.ClientToken { c.gets <- key return c.store.Pop(key) } func TestHandshakeAddrResolutionHelpers(t *testing.T) { server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer server.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.DialAddr( ctx, fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), ) require.NoError(t, err) defer conn.CloseWithError(0, "") serverConn, err := server.Accept(ctx) require.NoError(t, err) defer serverConn.CloseWithError(0, "") } func TestHandshake(t 
*testing.T) { for _, tt := range []struct { name string conf *tls.Config }{ {"short cert chain", getTLSConfig()}, {"long cert chain", getTLSConfigWithLongCertChain()}, } { t.Run(tt.name, func(t *testing.T) { server, err := quic.Listen(newUDPConnLocalhost(t), tt.conf, getQuicConfig(nil)) require.NoError(t, err) defer server.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.Dial(ctx, newUDPConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) defer conn.CloseWithError(0, "") serverConn, err := server.Accept(ctx) require.NoError(t, err) defer serverConn.CloseWithError(0, "") }) } } func TestHandshakeServerMismatch(t *testing.T) { server, err := quic.Listen(newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer server.Close() conf := getTLSClientConfig() conf.ServerName = "foo.bar" ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() _, err = quic.Dial(ctx, newUDPConnLocalhost(t), server.Addr(), conf, getQuicConfig(nil)) require.Error(t, err) var transportErr *quic.TransportError require.True(t, errors.As(err, &transportErr)) require.True(t, transportErr.ErrorCode.IsCryptoError()) require.Contains(t, transportErr.Error(), "x509: certificate is valid for localhost, not foo.bar") var certErr *tls.CertificateVerificationError require.True(t, errors.As(transportErr, &certErr)) } func TestHandshakeCipherSuites(t *testing.T) { for _, suiteID := range []uint16{ tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384, tls.TLS_CHACHA20_POLY1305_SHA256, } { t.Run(tls.CipherSuiteName(suiteID), func(t *testing.T) { reset := qtls.SetCipherSuite(suiteID) defer reset() ln, err := quic.Listen(newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer ln.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.Dial(ctx, 
newUDPConnLocalhost(t), ln.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) defer conn.CloseWithError(0, "") serverConn, err := ln.Accept(context.Background()) require.NoError(t, err) defer serverConn.CloseWithError(0, "") serverStr, err := serverConn.OpenStream() require.NoError(t, err) errChan := make(chan error, 1) go func() { defer serverStr.Close() _, err = serverStr.Write(PRData) errChan <- err }() require.NoError(t, <-errChan) str, err := conn.AcceptStream(context.Background()) require.NoError(t, err) data, err := io.ReadAll(str) require.NoError(t, err) require.Equal(t, PRData, data) require.Equal(t, suiteID, conn.ConnectionState().TLS.CipherSuite) }) } } func TestTLSGetConfigForClientError(t *testing.T) { tr := &quic.Transport{Conn: newUDPConnLocalhost(t)} addTracer(tr) defer tr.Close() tlsConf := &tls.Config{ GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { return nil, errors.New("nope") }, } ln, err := tr.Listen(tlsConf, getQuicConfig(nil)) require.NoError(t, err) defer ln.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() _, err = quic.Dial(ctx, newUDPConnLocalhost(t), ln.Addr(), getTLSClientConfig(), getQuicConfig(nil)) var transportErr *quic.TransportError require.ErrorAs(t, err, &transportErr) require.True(t, transportErr.ErrorCode.IsCryptoError()) } // Since we're not operating on a net.Conn, we need to jump through some hoops to set the addresses on the tls.ClientHelloInfo. // Use a recursive setup to test that this works under all conditions. 
func TestTLSConfigGetConfigForClientAddresses(t *testing.T) { var local, remote net.Addr var local2, remote2 net.Addr done := make(chan struct{}) tlsConf := &tls.Config{ GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { local = info.Conn.LocalAddr() remote = info.Conn.RemoteAddr() conf := getTLSConfig() conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { defer close(done) local2 = info.Conn.LocalAddr() remote2 = info.Conn.RemoteAddr() return &(conf.Certificates[0]), nil } return conf, nil }, } server, err := quic.Listen(newUDPConnLocalhost(t), tlsConf, getQuicConfig(nil)) require.NoError(t, err) defer server.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.Dial(ctx, newUDPConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) defer conn.CloseWithError(0, "") select { case <-done: case <-time.After(time.Second): t.Fatal("timeout waiting for GetCertificate callback") } require.Equal(t, server.Addr(), local) require.Equal(t, conn.LocalAddr().(*net.UDPAddr).Port, remote.(*net.UDPAddr).Port) require.Equal(t, local, local2) require.Equal(t, remote, remote2) } func TestHandshakeFailsWithoutClientCert(t *testing.T) { tlsConf := getTLSConfig() tlsConf.ClientAuth = tls.RequireAndVerifyClientCert server, err := quic.Listen(newUDPConnLocalhost(t), tlsConf, getQuicConfig(nil)) require.NoError(t, err) defer server.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.Dial(ctx, newUDPConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) // Usually, the error will occur after the client already finished the handshake. // However, there's a race condition here. The server's CONNECTION_CLOSE might be // received before the connection is returned, so we might already get the error while dialing. 
if err == nil { errChan := make(chan error, 1) go func() { _, err := conn.AcceptStream(context.Background()) errChan <- err }() err = <-errChan } require.Error(t, err) var transportErr *quic.TransportError require.ErrorAs(t, err, &transportErr) require.True(t, transportErr.ErrorCode.IsCryptoError()) require.Condition(t, func() bool { errStr := transportErr.Error() return strings.Contains(errStr, "tls: certificate required") || strings.Contains(errStr, "tls: bad certificate") }) } func TestClosedConnectionsInAcceptQueue(t *testing.T) { dialer := &quic.Transport{Conn: newUDPConnLocalhost(t)} defer dialer.Close() server, err := quic.Listen(newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer server.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() // Create first connection conn1, err := dialer.Dial(ctx, server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) conn2, err := dialer.Dial(ctx, server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) defer conn2.CloseWithError(0, "") // close the first connection const appErrCode quic.ApplicationErrorCode = 12345 require.NoError(t, conn1.CloseWithError(appErrCode, "")) time.Sleep(scaleDuration(25 * time.Millisecond)) // wait for connections to be queued and closed // accept all connections, and find the closed one var closedConn *quic.Conn for i := 0; i < 2; i++ { conn, err := server.Accept(ctx) require.NoError(t, err) if conn.Context().Err() != nil { require.Nil(t, closedConn, "only expected a single closed connection") closedConn = conn } } require.NotNil(t, closedConn, "expected one closed connection") _, err = closedConn.AcceptStream(context.Background()) var appErr *quic.ApplicationError require.ErrorAs(t, err, &appErr) require.Equal(t, appErrCode, appErr.ErrorCode) } func TestServerAcceptQueueOverflow(t *testing.T) { server, err := quic.Listen(newUDPConnLocalhost(t), getTLSConfig(), 
getQuicConfig(nil)) require.NoError(t, err) defer server.Close() dialer := &quic.Transport{Conn: newUDPConnLocalhost(t)} defer dialer.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() // fill up the accept queue for i := 0; i < protocol.MaxAcceptQueueSize; i++ { conn, err := dialer.Dial(ctx, server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) defer conn.CloseWithError(0, "") } time.Sleep(scaleDuration(25 * time.Millisecond)) // wait for connections to be queued // next connection should be rejected conn, err := dialer.Dial(ctx, server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) _, err = conn.AcceptStream(ctx) var transportErr *quic.TransportError require.ErrorAs(t, err, &transportErr) require.Equal(t, quic.ConnectionRefused, transportErr.ErrorCode) // accept one connection to free up a spot _, err = server.Accept(ctx) require.NoError(t, err) // should be able to dial again conn2, err := dialer.Dial(ctx, server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) defer conn2.CloseWithError(0, "") time.Sleep(scaleDuration(25 * time.Millisecond)) // but next connection should be rejected again conn3, err := dialer.Dial(ctx, server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) _, err = conn3.AcceptStream(ctx) require.ErrorAs(t, err, &transportErr) require.Equal(t, quic.ConnectionRefused, transportErr.ErrorCode) } func TestHandshakeCloseListener(t *testing.T) { t.Run("using Transport.Listen", func(t *testing.T) { testHandshakeCloseListener(t, func(tlsConf *tls.Config) *quic.Listener { tr := &quic.Transport{Conn: newUDPConnLocalhost(t)} addTracer(tr) t.Cleanup(func() { tr.Close() }) ln, err := tr.Listen(tlsConf, getQuicConfig(nil)) require.NoError(t, err) return ln }) }) t.Run("using Listen", func(t *testing.T) { conn := newUDPConnLocalhost(t) testHandshakeCloseListener(t, func(tlsConf *tls.Config) *quic.Listener { ln, err 
:= quic.Listen(conn, tlsConf, getQuicConfig(nil)) require.NoError(t, err) return ln }) // make sure that the Transport didn't close the underlying connection conn2 := newUDPConnLocalhost(t) _, err := conn2.WriteTo([]byte("test"), conn2.LocalAddr()) require.NoError(t, err) conn2.SetReadDeadline(time.Now().Add(time.Second)) b := make([]byte, 1000) n, err := conn2.Read(b) require.NoError(t, err) require.Equal(t, "test", string(b[:n])) }) // This test is somewhat slow (600ms), since the connection entries are kept for 3 PTOs. t.Run("using ListenAddr", func(t *testing.T) { var lnAddr *net.UDPAddr testHandshakeCloseListener(t, func(tlsConf *tls.Config) *quic.Listener { ln, err := quic.ListenAddr("127.0.0.1:0", tlsConf, getQuicConfig(nil)) require.NoError(t, err) lnAddr = ln.Addr().(*net.UDPAddr) return ln }) // make sure that the Transport closed the underlying connection if runtime.GOOS != "windows" { // this check doesn't work on Windows require.Eventually(t, func() bool { conn, err := net.DialUDP("udp", nil, lnAddr) require.NoError(t, err) defer conn.Close() _, err = conn.Write([]byte("test")) require.NoError(t, err) conn.SetReadDeadline(time.Now().Add(scaleDuration(10 * time.Millisecond))) _, err = conn.Read(make([]byte, 1000)) require.Error(t, err) return strings.Contains(err.Error(), "read: connection refused") }, time.Second, 50*time.Millisecond) } }) } func testHandshakeCloseListener(t *testing.T, createListener func(*tls.Config) *quic.Listener) { connQueued := make(chan struct{}) var sawFirst atomic.Bool tlsConf := &tls.Config{ GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { isFirst := sawFirst.CompareAndSwap(false, true) if isFirst { } else { // Sleep for a bit. // This allows the server to close the connection before the handshake completes. 
close(connQueued) time.Sleep(scaleDuration(10 * time.Millisecond)) } return getTLSConfig(), nil }, } ln := createListener(tlsConf) // dial the first connection ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.Dial(ctx, newUDPConnLocalhost(t), ln.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) defer conn.CloseWithError(0, "") _, err = ln.Accept(ctx) require.NoError(t, err) errChan := make(chan error, 1) go func() { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _, err := quic.Dial(ctx, newUDPConnLocalhost(t), ln.Addr(), getTLSClientConfig(), getQuicConfig(nil)) errChan <- err }() select { case <-connQueued: case <-time.After(scaleDuration(10 * time.Millisecond)): t.Fatal("timeout waiting for connection queued") } require.NoError(t, ln.Close()) select { case err := <-errChan: var transportErr *quic.TransportError require.ErrorAs(t, err, &transportErr) require.Equal(t, quic.ConnectionRefused, transportErr.ErrorCode) case <-time.After(time.Second): t.Fatal("timeout waiting for handshaking connection to be rejected") } // the first connection should not be closed select { case <-conn.Context().Done(): t.Fatal("connection was closed") case <-time.After(scaleDuration(10 * time.Millisecond)): } } func TestALPN(t *testing.T) { ln, err := quic.Listen(newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer ln.Close() acceptChan := make(chan *quic.Conn, 2) go func() { for { conn, err := ln.Accept(context.Background()) if err != nil { return } acceptChan <- conn } }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.Dial(ctx, newUDPConnLocalhost(t), ln.Addr(), getTLSClientConfig(), nil) require.NoError(t, err) cs := conn.ConnectionState() require.Equal(t, alpn, cs.TLS.NegotiatedProtocol) select { case c := <-acceptChan: require.Equal(t, alpn, 
c.ConnectionState().TLS.NegotiatedProtocol) case <-time.After(time.Second): t.Fatal("timeout waiting for server connection") } require.NoError(t, conn.CloseWithError(0, "")) // now try with a different ALPN tlsConf := getTLSClientConfig() tlsConf.NextProtos = []string{"foobar"} ctx, cancel = context.WithTimeout(context.Background(), time.Second) defer cancel() _, err = quic.Dial(ctx, newUDPConnLocalhost(t), ln.Addr(), tlsConf, nil) require.Error(t, err) var transportErr *quic.TransportError require.ErrorAs(t, err, &transportErr) require.True(t, transportErr.ErrorCode.IsCryptoError()) require.Contains(t, transportErr.Error(), "no application protocol") } func TestTokensFromNewTokenFrames(t *testing.T) { t.Run("MaxTokenAge: 1 hour", func(t *testing.T) { testTokensFromNewTokenFrames(t, 0, true) }) // If unset, the default value is 24h. t.Run("MaxTokenAge: default", func(t *testing.T) { testTokensFromNewTokenFrames(t, 0, true) }) t.Run("MaxTokenAge: very short", func(t *testing.T) { testTokensFromNewTokenFrames(t, time.Microsecond, false) }) } func testTokensFromNewTokenFrames(t *testing.T, maxTokenAge time.Duration, expectTokenUsed bool) { addrVerifiedChan := make(chan bool, 2) quicConf := getQuicConfig(nil) quicConf.GetConfigForClient = func(info *quic.ClientInfo) (*quic.Config, error) { addrVerifiedChan <- info.AddrVerified return quicConf, nil } tr := &quic.Transport{Conn: newUDPConnLocalhost(t), MaxTokenAge: maxTokenAge} addTracer(tr) defer tr.Close() server, err := tr.Listen(getTLSConfig(), quicConf) require.NoError(t, err) defer server.Close() // dial the first connection and receive the token acceptChan := make(chan error, 2) go func() { _, err := server.Accept(context.Background()) acceptChan <- err _, err = server.Accept(context.Background()) acceptChan <- err }() gets := make(chan string, 2) puts := make(chan string, 2) ts := newTokenStore(gets, puts) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := 
quic.Dial(ctx, newUDPConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(&quic.Config{TokenStore: ts})) require.NoError(t, err) // verify token store was used select { case <-gets: case <-time.After(time.Second): t.Fatal("timeout waiting for token store get") } select { case <-puts: case <-time.After(time.Second): t.Fatal("timeout waiting for token store put") } select { case addrVerified := <-addrVerifiedChan: require.False(t, addrVerified) case <-time.After(time.Second): t.Fatal("timeout waiting for addr verified") } select { case <-acceptChan: case <-time.After(time.Second): t.Fatal("timeout waiting for accept") } // received a token. Close this connection. require.NoError(t, conn.CloseWithError(0, "")) time.Sleep(scaleDuration(5 * time.Millisecond)) conn, err = quic.Dial(ctx, newUDPConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(&quic.Config{TokenStore: ts})) require.NoError(t, err) defer conn.CloseWithError(0, "") select { case addrVerified := <-addrVerifiedChan: // this time, the address was verified using the token if expectTokenUsed { require.True(t, addrVerified) } else { require.False(t, addrVerified) } case <-time.After(time.Second): t.Fatal("timeout waiting for addr verified") } select { case <-gets: case <-time.After(time.Second): t.Fatal("timeout waiting for token store get") } select { case <-acceptChan: case <-time.After(time.Second): t.Fatal("timeout waiting for accept") } } func TestInvalidToken(t *testing.T) { const rtt = 10 * time.Millisecond // The validity period of the retry token is the handshake timeout, // which is twice the handshake idle timeout. // By setting the handshake timeout shorter than the RTT, the token will have // expired by the time it reaches the server. 
serverConfig := getQuicConfig(&quic.Config{HandshakeIdleTimeout: rtt / 5}) tr := &quic.Transport{ Conn: newUDPConnLocalhost(t), VerifySourceAddress: func(net.Addr) bool { return true }, } addTracer(tr) defer tr.Close() server, err := tr.Listen(getTLSConfig(), serverConfig) require.NoError(t, err) defer server.Close() proxy := quicproxy.Proxy{ Conn: newUDPConnLocalhost(t), ServerAddr: server.Addr().(*net.UDPAddr), DelayPacket: func(quicproxy.Direction, net.Addr, net.Addr, []byte) time.Duration { return rtt / 2 }, } require.NoError(t, proxy.Start()) defer proxy.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() _, err = quic.Dial(ctx, newUDPConnLocalhost(t), proxy.LocalAddr(), getTLSClientConfig(), nil) require.Error(t, err) var transportErr *quic.TransportError require.ErrorAs(t, err, &transportErr) require.Equal(t, quic.InvalidToken, transportErr.ErrorCode) } func TestGetConfigForClient(t *testing.T) { var calledFrom net.Addr serverConfig := getQuicConfig(&quic.Config{EnableDatagrams: true}) serverConfig.GetConfigForClient = func(info *quic.ClientInfo) (*quic.Config, error) { conf := serverConfig.Clone() conf.EnableDatagrams = true calledFrom = info.RemoteAddr return getQuicConfig(conf), nil } ln, err := quic.Listen(newUDPConnLocalhost(t), getTLSConfig(), serverConfig) require.NoError(t, err) acceptDone := make(chan struct{}) go func() { _, err := ln.Accept(context.Background()) require.NoError(t, err) close(acceptDone) }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.Dial(ctx, newUDPConnLocalhost(t), ln.Addr(), getTLSClientConfig(), getQuicConfig(&quic.Config{EnableDatagrams: true})) require.NoError(t, err) defer conn.CloseWithError(0, "") cs := conn.ConnectionState() require.True(t, cs.SupportsDatagrams.Remote, "server should advertise datagram support") require.True(t, cs.SupportsDatagrams.Local, "client should have datagram support enabled") select { case 
<-acceptDone: case <-time.After(time.Second): t.Fatal("timeout waiting for accept") } require.NoError(t, ln.Close()) require.Equal(t, conn.LocalAddr().(*net.UDPAddr).Port, calledFrom.(*net.UDPAddr).Port) } func TestGetConfigForClientErrorsConnectionRejection(t *testing.T) { ln, err := quic.Listen( newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(&quic.Config{ GetConfigForClient: func(info *quic.ClientInfo) (*quic.Config, error) { return nil, errors.New("rejected") }, }), ) require.NoError(t, err) acceptChan := make(chan bool, 1) go func() { _, err := ln.Accept(context.Background()) acceptChan <- err == nil }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() _, err = quic.Dial(ctx, newUDPConnLocalhost(t), ln.Addr(), getTLSClientConfig(), getQuicConfig(nil)) var transportErr *quic.TransportError require.ErrorAs(t, err, &transportErr) require.Equal(t, qerr.ConnectionRefused, transportErr.ErrorCode) // verify no connection was accepted ln.Close() require.False(t, <-acceptChan) } func TestNoPacketsSentWhenClientHelloFails(t *testing.T) { conn := newUDPConnLocalhost(t) packetChan := make(chan struct{}, 1) go func() { for { _, _, err := conn.ReadFromUDP(make([]byte, protocol.MaxPacketBufferSize)) if err != nil { return } select { case packetChan <- struct{}{}: default: } } }() tlsConf := getTLSClientConfig() tlsConf.NextProtos = []string{""} ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() _, err := quic.Dial(ctx, newUDPConnLocalhost(t), conn.LocalAddr(), tlsConf, getQuicConfig(nil)) var transportErr *quic.TransportError require.ErrorAs(t, err, &transportErr) require.True(t, transportErr.ErrorCode.IsCryptoError()) require.Contains(t, err.Error(), "tls: invalid NextProtos value") // verify no packets were sent select { case <-packetChan: t.Fatal("received unexpected packet") case <-time.After(50 * time.Millisecond): // no packets received, as expected } } func TestServerTransportClose(t 
*testing.T) { tlsServerConf := getTLSConfig() tr := &quic.Transport{Conn: newUDPConnLocalhost(t)} server, err := tr.Listen(tlsServerConf, getQuicConfig(nil)) require.NoError(t, err) defer server.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() // the first conn is accepted by the server... conn1, err := quic.Dial( ctx, newUDPConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIdleTimeout: scaleDuration(50 * time.Millisecond)}), ) require.NoError(t, err) sconn, err := server.Accept(ctx) require.NoError(t, err) require.Equal(t, conn1.LocalAddr(), sconn.RemoteAddr()) // ...the second conn isn't, it remains in the server's accept queue conn2, err := quic.Dial( ctx, newUDPConnLocalhost(t), server.Addr(), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIdleTimeout: scaleDuration(50 * time.Millisecond)}), ) require.NoError(t, err) time.Sleep(scaleDuration(10 * time.Millisecond)) // closing the Transport abruptly terminates connections require.NoError(t, tr.Close()) select { case <-sconn.Context().Done(): require.ErrorIs(t, context.Cause(sconn.Context()), quic.ErrTransportClosed) case <-time.After(time.Second): t.Fatal("timeout") } // no CONNECTION_CLOSE frame is sent to the peers select { case <-conn1.Context().Done(): require.ErrorIs(t, context.Cause(conn1.Context()), &quic.IdleTimeoutError{}) case <-time.After(time.Second): t.Fatal("timeout") } select { case <-conn2.Context().Done(): require.ErrorIs(t, context.Cause(conn1.Context()), &quic.IdleTimeoutError{}) case <-time.After(time.Second): t.Fatal("timeout") } // Accept should error after the transport was closed ctx, cancel = context.WithTimeout(context.Background(), time.Second) defer cancel() accepted, err := server.Accept(ctx) require.ErrorIs(t, err, quic.ErrTransportClosed) require.Nil(t, accepted) } 
quic-go-0.59.0/integrationtests/self/http_datagram_test.go000066400000000000000000000200261513066070600237210ustar00rootroot00000000000000package self_test

import (
	"bytes"
	"context"
	"encoding/binary"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"testing"
	"time"

	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3"

	"github.com/stretchr/testify/require"
)

// TestHTTPSettings checks the HTTP/3 SETTINGS seen by each side:
// the server's settings as observed by a client connection, and the client's
// settings (datagrams enabled, an additional setting 1337=42) as observed by
// a server handler.
func TestHTTPSettings(t *testing.T) {
	mux := http.NewServeMux()
	port := startHTTPServer(t, mux)

	t.Run("server settings", func(t *testing.T) {
		// Clone before setting NextProtos so the shared config is not modified.
		tlsConf := tlsClientConfigWithoutServerName.Clone()
		tlsConf.NextProtos = []string{http3.NextProtoH3}
		// Bound the handshake, like every other dial in this file.
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		defer cancel()
		conn, err := quic.DialAddr(
			ctx,
			fmt.Sprintf("localhost:%d", port),
			tlsConf,
			getQuicConfig(nil),
		)
		require.NoError(t, err)
		defer conn.CloseWithError(0, "")
		var tr http3.Transport
		cc := tr.NewClientConn(conn)
		select {
		case <-cc.ReceivedSettings():
		case <-time.After(time.Second):
			t.Fatal("didn't receive HTTP/3 settings")
		}
		settings := cc.Settings()
		require.True(t, settings.EnableExtendedConnect)
		require.False(t, settings.EnableDatagrams)
		require.Empty(t, settings.Other)
	})

	t.Run("client settings", func(t *testing.T) {
		connChan := make(chan http3.Settingser, 1)
		mux.HandleFunc("/settings", func(w http.ResponseWriter, r *http.Request) {
			connChan <- w.(http3.Settingser)
			w.WriteHeader(http.StatusOK)
		})

		tr := &http3.Transport{
			TLSClientConfig: getTLSClientConfigWithoutServerName(),
			QUICConfig: getQuicConfig(&quic.Config{
				MaxIdleTimeout:  10 * time.Second,
				EnableDatagrams: true,
			}),
			EnableDatagrams:    true,
			AdditionalSettings: map[uint64]uint64{1337: 42},
		}
		defer tr.Close()
		req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/settings", port), nil)
		require.NoError(t, err)
		rsp, err := tr.RoundTrip(req)
		require.NoError(t, err)
		// The response body must be closed to release the request stream.
		defer rsp.Body.Close()
		var conn http3.Settingser
		select {
		case conn = <-connChan:
		case <-time.After(time.Second):
			t.Fatal("didn't receive HTTP/3 connection")
		}
		select {
		case <-conn.ReceivedSettings():
		case <-time.After(time.Second):
			t.Fatal("didn't receive HTTP/3 settings")
		}
		settings := conn.Settings()
		require.NotNil(t, settings)
		require.True(t, settings.EnableDatagrams)
		require.False(t, settings.EnableExtendedConnect)
		require.Equal(t, uint64(42), settings.Other[1337])
	})
}

// dialAndOpenHTTPDatagramStream dials the host of addr with QUIC datagrams
// enabled and opens an HTTP/3 request stream. It sends a CONNECT request
// (Proto "datagrams") for addr and requires a 200 response. It then returns
// the stream. The QUIC connection, the transport and the client connection
// are closed via t.Cleanup.
func dialAndOpenHTTPDatagramStream(t *testing.T, addr string) *http3.RequestStream {
	t.Helper()
	u, err := url.Parse(addr)
	require.NoError(t, err)

	tlsConf := getTLSClientConfigWithoutServerName()
	tlsConf.NextProtos = []string{http3.NextProtoH3}
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	conn, err := quic.DialAddr(ctx, u.Host, tlsConf, getQuicConfig(&quic.Config{EnableDatagrams: true}))
	require.NoError(t, err)
	t.Cleanup(func() { conn.CloseWithError(0, "") })

	tr := http3.Transport{EnableDatagrams: true}
	t.Cleanup(func() { tr.Close() })
	cc := tr.NewClientConn(conn)
	t.Cleanup(func() { cc.CloseWithError(0, "") })
	str, err := cc.OpenRequestStream(ctx)
	require.NoError(t, err)
	req := &http.Request{
		Method: http.MethodConnect,
		Proto:  "datagrams",
		Host:   u.Host,
		URL:    u,
	}
	require.NoError(t, str.SendRequestHeader(req))

	rsp, err := str.ReadResponse()
	require.NoError(t, err)
	require.Equal(t, http.StatusOK, rsp.StatusCode)
	return str
}

// TestHTTPDatagrams sends num datagrams on a CONNECT stream and requires that
// at least 4/5 of them reach the server handler. It then cancels the client's
// write side with code 42. The server's ReceiveDatagram must then fail with a
// *quic.StreamError carrying that code.
func TestHTTPDatagrams(t *testing.T) {
	errChan := make(chan error, 1)
	const num = 5
	datagramChan := make(chan struct{}, num)
	mux := http.NewServeMux()
	mux.HandleFunc("/datagrams", func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodConnect {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}
		s := w.(http3.Settingser)
		select {
		case <-s.ReceivedSettings():
		case <-time.After(time.Second):
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		if !s.Settings().EnableDatagrams {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		w.WriteHeader(http.StatusOK)

		str := w.(http3.HTTPStreamer).HTTPStream()
		go str.Read([]byte{0}) // need to continue reading from stream to observe state transitions

		for {
			if _, err := str.ReceiveDatagram(context.Background()); err != nil {
				errChan <- err
				return
			}
			datagramChan <- struct{}{}
		}
	})
	port := startHTTPServer(t, mux, func(s *http3.Server) { s.EnableDatagrams = true })

	str := dialAndOpenHTTPDatagramStream(t, fmt.Sprintf("https://localhost:%d/datagrams", port))

	for i := range num {
		b := make([]byte, 8)
		binary.BigEndian.PutUint64(b, uint64(i))
		require.NoError(t, str.SendDatagram(bytes.Repeat(b, 100)))
	}
	// Datagrams are unreliable, so only require most of them to arrive.
	var count int
loop:
	for {
		select {
		case <-datagramChan:
			count++
			if count >= num*4/5 {
				break loop
			}
		case err := <-errChan:
			t.Fatalf("receiving datagrams failed: %s", err)
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
	}
	str.CancelWrite(42)

	select {
	case err := <-errChan:
		var serr *quic.StreamError
		require.ErrorAs(t, err, &serr)
		require.Equal(t, quic.StreamErrorCode(42), serr.ErrorCode)
	case <-time.After(time.Second):
		t.Fatal("didn't receive error")
	}
}

// TestHTTPDatagramClose checks what happens when the client closes the send
// side of a datagram-carrying request stream. The server's ReceiveDatagram
// must return io.EOF, and the client must no longer be able to send datagrams.
func TestHTTPDatagramClose(t *testing.T) {
	errChan := make(chan error, 1)
	datagramChan := make(chan []byte, 1)
	mux := http.NewServeMux()
	mux.HandleFunc("/datagrams", func(w http.ResponseWriter, r *http.Request) {
		s := w.(http3.Settingser)
		select {
		case <-s.ReceivedSettings():
		case <-time.After(time.Second):
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		if !s.Settings().EnableDatagrams {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		w.WriteHeader(http.StatusOK)

		str := w.(http3.HTTPStreamer).HTTPStream()
		go str.Read([]byte{0}) // need to continue reading from stream to observe state transitions

		for {
			data, err := str.ReceiveDatagram(context.Background())
			if err != nil {
				errChan <- err
				return
			}
			datagramChan <- data
		}
	})
	port := startHTTPServer(t, mux, func(s *http3.Server) { s.EnableDatagrams = true })

	str := dialAndOpenHTTPDatagramStream(t, fmt.Sprintf("https://localhost:%d/datagrams", port))
	go str.Read([]byte{0})

	require.NoError(t, str.SendDatagram([]byte("foo")))
	select {
	case data := <-datagramChan:
		require.Equal(t, []byte("foo"), data)
	case <-time.After(time.Second):
		t.Fatal("didn't receive datagram")
	}
	// signal that we're done sending
	str.Close()

	var resetErr error
	select {
	case resetErr = <-errChan:
	case <-time.After(time.Second):
		t.Fatal("didn't receive error")
	}
	require.Equal(t, io.EOF, resetErr)

	// make sure we can't send anymore
	require.Error(t, str.SendDatagram([]byte("foo")))
}

// TestHTTPDatagramStreamReset checks what happens when the server handler
// calls CancelRead(42) after receiving the first datagram:
//   - the server's next ReceiveDatagram fails with a local StreamError (code 42);
//   - the client's SendDatagram eventually fails with a remote StreamError
//     carrying the same code.
func TestHTTPDatagramStreamReset(t *testing.T) {
	errChan := make(chan error, 1)
	datagramChan := make(chan []byte, 1)
	mux := http.NewServeMux()
	mux.HandleFunc("/datagrams", func(w http.ResponseWriter, r *http.Request) {
		s := w.(http3.Settingser)
		select {
		case <-s.ReceivedSettings():
		case <-time.After(time.Second):
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		if !s.Settings().EnableDatagrams {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		w.WriteHeader(http.StatusOK)

		str := w.(http3.HTTPStreamer).HTTPStream()
		go str.Read([]byte{0}) // need to continue reading from stream to observe state transitions

		for {
			data, err := str.ReceiveDatagram(context.Background())
			if err != nil {
				errChan <- err
				return
			}
			str.CancelRead(42)
			datagramChan <- data
		}
	})
	port := startHTTPServer(t, mux, func(s *http3.Server) { s.EnableDatagrams = true })

	str := dialAndOpenHTTPDatagramStream(t, fmt.Sprintf("https://localhost:%d/datagrams", port))
	go str.Read([]byte{0})

	require.NoError(t, str.SendDatagram([]byte("foo")))
	select {
	case data := <-datagramChan:
		require.Equal(t, []byte("foo"), data)
	case <-time.After(time.Second):
		t.Fatal("didn't receive datagram")
	}

	var resetErr error
	select {
	case resetErr = <-errChan:
	case <-time.After(time.Second):
		t.Fatal("didn't receive error")
	}
	require.Equal(t, &quic.StreamError{ErrorCode: 42, Remote: false}, resetErr)

	// The client only learns about the reset once the peer's frame arrives,
	// so poll until sending fails.
	var err error
	require.Eventually(t, func() bool {
		err = str.SendDatagram([]byte("foo"))
		return err != nil
	}, time.Second, 10*time.Millisecond)
	// make sure we can't send anymore
	require.Equal(t, &quic.StreamError{ErrorCode: 42, Remote: true}, err)
}
quic-go-0.59.0/integrationtests/self/http_hotswap_test.go000066400000000000000000000060431513066070600236310ustar00rootroot00000000000000package self_test

import (
	"io"
	"net"
	"net/http"
	"strconv"
	"testing"
	"time"

	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3"

	"github.com/stretchr/testify/require"
)

// TestHTTP3ServerHotswap serves one QUIC listener with two http3.Servers in
// turn. server1 serves while server2 is started on the same listener; after
// server1 is closed, new client connections must be handled by server2.
func TestHTTP3ServerHotswap(t *testing.T) {
	mux1 := http.NewServeMux()
	mux1.HandleFunc("/hello1", func(w http.ResponseWriter, r *http.Request) {
		io.WriteString(w, "Hello, World 1!\n") // don't check the error here. Stream may be reset.
	})

	mux2 := http.NewServeMux()
	mux2.HandleFunc("/hello2", func(w http.ResponseWriter, r *http.Request) {
		io.WriteString(w, "Hello, World 2!\n") // don't check the error here. Stream may be reset.
	})

	server1 := &http3.Server{
		Handler:    mux1,
		QUICConfig: getQuicConfig(nil),
	}
	server2 := &http3.Server{
		Handler:    mux2,
		QUICConfig: getQuicConfig(nil),
	}

	tlsConf := http3.ConfigureTLSConfig(getTLSConfig())
	ln, err := quic.ListenEarly(newUDPConnLocalhost(t), tlsConf, getQuicConfig(nil))
	require.NoError(t, err)
	port := strconv.Itoa(ln.Addr().(*net.UDPAddr).Port)

	newClient := func() *http.Client {
		return &http.Client{
			Transport: &http3.Transport{
				TLSClientConfig:    getTLSClientConfig(),
				DisableCompression: true,
				QUICConfig:         getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
			},
		}
	}
	client := newClient()
	defer func() { require.NoError(t, ln.Close()) }()

	// open first server and make single request to it
	errChan1 := make(chan error, 1)
	go func() { errChan1 <- server1.ServeListener(ln) }()

	resp, err := client.Get("https://localhost:" + port + "/hello1")
	require.NoError(t, err)
	require.Equal(t, http.StatusOK, resp.StatusCode)
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "Hello, World 1!\n", string(body))
	resp.Body.Close()

	// open second server with same underlying listener
	errChan2 := make(chan error, 1)
	go func() { errChan2 <- server2.ServeListener(ln) }()
	time.Sleep(scaleDuration(20 * time.Millisecond))

	select {
	case err := <-errChan1:
		t.Fatalf("server1 stopped unexpectedly: %v", err)
	case err := <-errChan2:
		t.Fatalf("server2 stopped unexpectedly: %v", err)
	default:
	}

	// now close first server
	require.NoError(t, server1.Close())
	select {
	case err := <-errChan1:
		require.ErrorIs(t, err, http.ErrServerClosed)
	case <-time.After(5 * time.Second):
		t.Fatal("timed out waiting for server1 to stop")
	}
	require.NoError(t, client.Transport.(*http3.Transport).Close())

	// A fresh client is needed so the request dials a new connection.
	client = newClient()
	defer func() { require.NoError(t, client.Transport.(*http3.Transport).Close()) }()

	// verify that new connections are handled by the second server now
	resp, err = client.Get("https://localhost:" + port + "/hello2")
	require.NoError(t, err)
	require.Equal(t, http.StatusOK, resp.StatusCode)
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "Hello, World 2!\n", string(body))
	resp.Body.Close()

	// close the other server
	require.NoError(t, server2.Close())
	select {
	case err := <-errChan2:
		require.ErrorIs(t, err, http.ErrServerClosed)
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for server2 to stop")
	}
}
quic-go-0.59.0/integrationtests/self/http_qlog_test.go000066400000000000000000000044011513066070600231020ustar00rootroot00000000000000package self_test

import (
	"context"
	"fmt"
	"io"
	"net"
	"net/http"
	"testing"
	"time"

	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3"
	h3qlog "github.com/quic-go/quic-go/http3/qlog"
	"github.com/quic-go/quic-go/qlogwriter"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestHTTP3Qlog makes one HTTP/3 request with mock qlog traces installed on
// both the client and the server. It then checks three things:
//   - each side has two open recorders while the connection is up;
//   - all recorders are closed after shutdown;
//   - each side checked exactly the HTTP/3 qlog event schema.
func TestHTTP3Qlog(t *testing.T) {
	serverTrace := newMockTrace()
	clientTrace := newMockTrace()

	mux := http.NewServeMux()
	mux.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) {
		io.WriteString(w, "Hello, World!\n")
	})

	server := &http3.Server{
		Handler:   mux,
		TLSConfig: getTLSConfig(),
		QUICConfig: getQuicConfig(&quic.Config{
			Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return serverTrace },
		}),
	}

	conn := newUDPConnLocalhost(t)
	done := make(chan struct{})
	go func() {
		defer close(done)
		server.Serve(conn)
	}()
	port := conn.LocalAddr().(*net.UDPAddr).Port

	tr := &http3.Transport{
		TLSClientConfig: getTLSClientConfigWithoutServerName(),
		QUICConfig: getQuicConfig(&quic.Config{
			Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return clientTrace },
		}),
	}
	addDialCallback(t, tr)
	cl := &http.Client{Transport: tr}

	resp, err := cl.Get(fmt.Sprintf("https://localhost:%d/hello", port))
	require.NoError(t, err)
	require.Equal(t, http.StatusOK, resp.StatusCode)
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "Hello, World!\n", string(body))
	resp.Body.Close()

	assert.Equal(t, 2, clientTrace.OpenRecorders())
	assert.Equal(t, 2, serverTrace.OpenRecorders())

	tr.Close()
	server.Close()

	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("server didn't shut down")
	}

	// Recorders are closed in an AfterFunc, so we need to wait for them to be closed.
	assert.Eventually(t, func() bool { return clientTrace.OpenRecorders() == 0 }, time.Second, 10*time.Millisecond, "client recorders should be closed")
	assert.Eventually(t, func() bool { return serverTrace.OpenRecorders() == 0 }, time.Second, 10*time.Millisecond, "server recorders should be closed")

	assert.Equal(t, []string{h3qlog.EventSchema}, clientTrace.SchemasChecked)
	assert.Equal(t, []string{h3qlog.EventSchema}, serverTrace.SchemasChecked)
}
quic-go-0.59.0/integrationtests/self/http_raw_conn_test.go000066400000000000000000000110171513066070600237470ustar00rootroot00000000000000package self_test

import (
	"context"
	"fmt"
	"io"
	"net"
	"net/http"
	"sync"
	"testing"
	"time"

	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3"
	"github.com/quic-go/quic-go/internal/synctest"
	"github.com/quic-go/quic-go/quicvarint"

	"github.com/stretchr/testify/require"
)

// This test tests the HTTP/3 raw connection functionality,
// which is primarily used by WebTransport.
func TestHTTPRawConn(t *testing.T) { const magicValue = 0x123456 synctest.Test(t, func(t *testing.T) { const rtt = 10 * time.Millisecond clientPacketConn, serverPacketConn, closeFn := newSimnetLink(t, rtt) defer closeFn(t) ln, err := quic.ListenEarly( serverPacketConn, http3.ConfigureTLSConfig(getTLSConfig()), getQuicConfig(&quic.Config{EnableDatagrams: true}), ) require.NoError(t, err) defer ln.Close() start := time.Now() mux := http.NewServeMux() mux.HandleFunc("/data", func(w http.ResponseWriter, r *http.Request) { w.Write(PRData) }) server := &http3.Server{ Handler: mux, EnableDatagrams: true, } defer server.Close() // run the server in a separate Goroutine, so we can make sure that SETTINGS are sent in 0.5-RTT data errChan := make(chan error, 1) go func() { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() serverConn, err := ln.Accept(ctx) if err != nil { errChan <- err return } rawServerConn, err := server.NewRawServerConn(serverConn) if err != nil { errChan <- err return } var wg sync.WaitGroup wg.Add(2) // accept and handle unidirectional streams opened by the client go func() { defer wg.Done() for { str, err := serverConn.AcceptUniStream(context.Background()) if err != nil { return } go rawServerConn.HandleUnidirectionalStream(str) } }() // accept and handle bidirectional streams opened by the client go func() { defer wg.Done() for { str, err := serverConn.AcceptStream(context.Background()) if err != nil { return } v, _ := quicvarint.Peek(str) if v == magicValue { go func() { // read the previously peeked value quicvarint.Read(quicvarint.NewReader(str)) defer str.Close() io.Copy(str, str) }() } else { go rawServerConn.HandleRequestStream(str) } } }() wg.Wait() <-serverConn.Context().Done() errChan <- nil }() ctx, cancel := context.WithTimeout(context.Background(), time.Hour) defer cancel() clientConn, err := quic.Dial( ctx, clientPacketConn, serverPacketConn.LocalAddr(), http3.ConfigureTLSConfig(getTLSClientConfig()), 
getQuicConfig(&quic.Config{EnableDatagrams: true}), ) require.NoError(t, err) defer clientConn.CloseWithError(0, "") tr := &http3.Transport{ EnableDatagrams: true, } rawClientConn := tr.NewRawClientConn(clientConn) // accept and handle unidirectional streams opened by the server go func() { for { str, err := clientConn.AcceptUniStream(ctx) if err != nil { return } go rawClientConn.HandleUnidirectionalStream(str) } }() select { case <-rawClientConn.ReceivedSettings(): settings := rawClientConn.Settings() require.True(t, settings.EnableDatagrams) // the server sends SETTINGS in 0.5-RTT data, so they should be received after 1 RTT require.Equal(t, rtt, time.Since(start)) case <-time.After(time.Second): t.Fatal("timeout waiting for HTTP/3 settings") } req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://%s/data", serverPacketConn.LocalAddr().(*net.UDPAddr)), nil) require.NoError(t, err) reqStr, err := rawClientConn.OpenRequestStream(ctx) require.NoError(t, err) require.NoError(t, reqStr.SendRequestHeader(req)) resp, err := reqStr.ReadResponse() require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) data, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, PRData, data) require.NoError(t, resp.Body.Close()) str, err := clientConn.OpenStream() require.NoError(t, err) b := quicvarint.Append(nil, magicValue) b = append(b, []byte("lorem ipsum dolor sit amet")...) 
_, err = str.Write(b) require.NoError(t, err) require.NoError(t, str.Close()) data, err = io.ReadAll(str) require.NoError(t, err) require.Equal(t, []byte("lorem ipsum dolor sit amet"), data) clientConn.CloseWithError(0, "") select { case err := <-errChan: require.NoError(t, err) case <-time.After(time.Second): t.Fatal("timeout waiting for server to close") } }) } quic-go-0.59.0/integrationtests/self/http_shutdown_test.go000066400000000000000000000356641513066070600240320ustar00rootroot00000000000000package self_test import ( "context" "crypto/tls" "fmt" "io" "net" "net/http" "net/url" "testing" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestHTTPShutdown(t *testing.T) { mux := http.NewServeMux() var server *http3.Server port := startHTTPServer(t, mux, func(s *http3.Server) { server = s }) client := newHTTP3Client(t) mux.HandleFunc("/shutdown", func(w http.ResponseWriter, r *http.Request) { go func() { require.NoError(t, server.Close()) }() time.Sleep(scaleDuration(10 * time.Millisecond)) // make sure the server started shutting down }) _, err := client.Get(fmt.Sprintf("https://localhost:%d/shutdown", port)) require.Error(t, err) var appErr *http3.Error require.ErrorAs(t, err, &appErr) require.Equal(t, http3.ErrCodeNoError, appErr.ErrorCode) } func TestGracefulShutdownShortRequest(t *testing.T) { var server *http3.Server mux := http.NewServeMux() port := startHTTPServer(t, mux, func(s *http3.Server) { server = s }) errChan := make(chan error, 1) proceed := make(chan struct{}) mux.HandleFunc("/shutdown", func(w http.ResponseWriter, r *http.Request) { go func() { defer close(errChan) errChan <- server.Shutdown(context.Background()) }() w.WriteHeader(http.StatusOK) w.(http.Flusher).Flush() <-proceed w.Write([]byte("shutdown")) }) connChan := make(chan *quic.Conn, 1) tr := &http3.Transport{ 
TLSClientConfig: getTLSClientConfigWithoutServerName(), Dial: func(ctx context.Context, a string, tlsConf *tls.Config, conf *quic.Config) (*quic.Conn, error) { addr, err := net.ResolveUDPAddr("udp", a) if err != nil { return nil, err } conn, err := quic.DialEarly(ctx, newUDPConnLocalhost(t), addr, tlsConf, conf) connChan <- conn return conn, err }, } t.Cleanup(func() { tr.Close() }) client := &http.Client{Transport: tr} ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/shutdown", port), nil) require.NoError(t, err) resp, err := client.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) var conn *quic.Conn select { case conn = <-connChan: default: t.Fatal("expected a connection") } type result struct { body []byte err error } resultChan := make(chan result, 1) go func() { body, err := io.ReadAll(resp.Body) resultChan <- result{body: body, err: err} }() select { case <-resultChan: t.Fatal("request body shouldn't have been read yet") case <-time.After(scaleDuration(10 * time.Millisecond)): } select { case <-conn.Context().Done(): t.Fatal("connection shouldn't have been closed") default: } // allow the request to proceed close(proceed) select { case res := <-resultChan: require.NoError(t, res.err) require.Equal(t, []byte("shutdown"), res.body) case <-time.After(time.Second): t.Fatal("timeout") } // now that the stream count dropped to 0, the client should close the connection select { case <-conn.Context().Done(): var appErr *quic.ApplicationError require.ErrorAs(t, context.Cause(conn.Context()), &appErr) assert.False(t, appErr.Remote) assert.Equal(t, quic.ApplicationErrorCode(http3.ErrCodeNoError), appErr.ErrorCode) case <-time.After(time.Second): t.Fatal("timeout") } select { case err := <-errChan: require.NoError(t, err) case <-time.After(time.Second): t.Fatal("shutdown did not complete") } } func 
TestGracefulShutdownIdleConnection(t *testing.T) { var server *http3.Server port := startHTTPServer(t, http.NewServeMux(), func(s *http3.Server) { server = s }) connChan := make(chan *quic.Conn, 1) tr := &http3.Transport{ TLSClientConfig: getTLSClientConfigWithoutServerName(), Dial: func(ctx context.Context, a string, tlsConf *tls.Config, conf *quic.Config) (*quic.Conn, error) { addr, err := net.ResolveUDPAddr("udp", a) if err != nil { return nil, err } conn, err := quic.DialEarly(ctx, newUDPConnLocalhost(t), addr, tlsConf, conf) connChan <- conn return conn, err }, } t.Cleanup(func() { tr.Close() }) client := &http.Client{Transport: tr} ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/", port), nil) require.NoError(t, err) resp, err := client.Do(req) require.NoError(t, err) require.Equal(t, http.StatusNotFound, resp.StatusCode) require.NoError(t, resp.Body.Close()) var conn *quic.Conn select { case conn = <-connChan: default: t.Fatal("expected a connection") } // the connection should still be alive (and idle) select { case <-conn.Context().Done(): t.Fatal("connection shouldn't have been closed") default: } shutdownChan := make(chan error, 1) go func() { shutdownChan <- server.Shutdown(context.Background()) }() // since the connection is idle, the client should close it immediately select { case <-conn.Context().Done(): var appErr *quic.ApplicationError require.ErrorAs(t, context.Cause(conn.Context()), &appErr) assert.False(t, appErr.Remote) assert.Equal(t, quic.ApplicationErrorCode(http3.ErrCodeNoError), appErr.ErrorCode) case <-time.After(time.Second): t.Fatal("timeout") } } func TestGracefulShutdownLongLivedRequest(t *testing.T) { delay := scaleDuration(25 * time.Millisecond) errChan := make(chan error, 1) requestChan := make(chan time.Duration, 1) var server *http3.Server mux := http.NewServeMux() port := startHTTPServer(t, mux, 
func(s *http3.Server) { server = s }) mux.HandleFunc("/shutdown", func(w http.ResponseWriter, r *http.Request) { start := time.Now() w.WriteHeader(http.StatusOK) w.(http.Flusher).Flush() // The request simulated here takes longer than the server's graceful shutdown period. // We expect it to be terminated once the server shuts down. go func() { ctx, cancel := context.WithTimeout(context.Background(), delay) defer cancel() errChan <- server.Shutdown(ctx) }() // measure how long it takes until the request errors for t := range time.NewTicker(delay / 10).C { if _, err := w.Write([]byte(t.String())); err != nil { requestChan <- time.Since(start) return } } }) start := time.Now() resp, err := newHTTP3Client(t).Get(fmt.Sprintf("https://localhost:%d/shutdown", port)) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) _, err = io.Copy(io.Discard, resp.Body) require.Error(t, err) var h3Err *http3.Error require.ErrorAs(t, err, &h3Err) require.Equal(t, http3.ErrCodeNoError, h3Err.ErrorCode) took := time.Since(start) require.InDelta(t, delay.Seconds(), took.Seconds(), (delay / 2).Seconds()) // make sure that shutdown returned due to context deadline select { case err := <-errChan: require.ErrorIs(t, err, context.DeadlineExceeded) case <-time.After(time.Second): t.Fatal("shutdown did not return due to context deadline") } select { case requestDuration := <-requestChan: require.InDelta(t, delay.Seconds(), requestDuration.Seconds(), (delay / 2).Seconds()) case <-time.After(time.Second): t.Fatal("did not receive request duration") } } func TestGracefulShutdownPendingStreams(t *testing.T) { rtt := scaleDuration(25 * time.Millisecond) handlerChan := make(chan struct{}, 1) mux := http.NewServeMux() mux.HandleFunc("/helloworld", func(w http.ResponseWriter, r *http.Request) { handlerChan <- struct{}{} time.Sleep(rtt) w.Write([]byte("hello world")) }) var server *http3.Server port := startHTTPServer(t, mux, func(s *http3.Server) { server = s }) connChan := 
make(chan *quic.Conn, 1) tr := &http3.Transport{ TLSClientConfig: getTLSClientConfigWithoutServerName(), Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) { a, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err } conn, err := quic.DialEarly(ctx, newUDPConnLocalhost(t), a, tlsCfg, cfg) connChan <- conn return conn, err }, } cl := &http.Client{Transport: tr} proxy := quicproxy.Proxy{ Conn: newUDPConnLocalhost(t), ServerAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port}, DelayPacket: func(_ quicproxy.Direction, _, _ net.Addr, _ []byte) time.Duration { return rtt }, } require.NoError(t, proxy.Start()) defer proxy.Close() errChan := make(chan error, 1) req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://%s/helloworld", proxy.LocalAddr()), nil) require.NoError(t, err) go func() { resp, err := cl.Do(req) if err != nil { errChan <- err return } if resp.StatusCode != http.StatusOK { errChan <- fmt.Errorf("expected status code %d, got %d", http.StatusOK, resp.StatusCode) } }() select { case <-handlerChan: case <-time.After(time.Second): t.Fatal("did not receive request") } shutdownChan := make(chan error, 1) ctx, cancel := context.WithCancel(context.Background()) go func() { shutdownChan <- server.Shutdown(ctx) }() time.Sleep(rtt / 2) // wait for the server to start shutting down var conn *quic.Conn select { case conn = <-connChan: case <-time.After(time.Second): t.Fatal("connection was not opened") } // make sure that the server rejects further requests for range 3 { str, err := conn.OpenStreamSync(ctx) require.NoError(t, err) str.Write([]byte("foobar")) select { case <-str.Context().Done(): case <-time.After(time.Second): t.Fatal("stream was not rejected") } _, err = str.Read(make([]byte, 10)) var serr *quic.StreamError require.ErrorAs(t, err, &serr) require.Equal(t, quic.StreamErrorCode(http3.ErrCodeRequestRejected), serr.ErrorCode) } cancel() select { case err := <-shutdownChan: 
require.ErrorIs(t, err, context.Canceled) case <-time.After(time.Second): t.Fatal("shutdown did not complete") } } func TestHTTP3ListenerClosing(t *testing.T) { t.Run("application listener", func(t *testing.T) { testHTTP3ListenerClosing(t, false, true) }) t.Run("listener created by the http3.Server", func(t *testing.T) { testHTTP3ListenerClosing(t, false, false) }) } func TestHTTP3ListenerGracefulShutdown(t *testing.T) { t.Run("application listener", func(t *testing.T) { testHTTP3ListenerClosing(t, true, true) }) t.Run("listener created by the http3.Server", func(t *testing.T) { testHTTP3ListenerClosing(t, true, false) }) } func testHTTP3ListenerClosing(t *testing.T, graceful, useApplicationListener bool) { dial := func(t *testing.T, ctx context.Context, u *url.URL) error { t.Helper() tlsConf := getTLSClientConfig() tlsConf.NextProtos = []string{http3.NextProtoH3} tr := &http3.Transport{TLSClientConfig: tlsConf} defer tr.Close() addDialCallback(t, tr) cl := &http.Client{Transport: tr} req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) require.NoError(t, err) resp, err := cl.Do(req) if err != nil { return err } defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) return nil } mux := http.NewServeMux() mux.HandleFunc("/ok", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) handlerChan := make(chan struct{}) mux.HandleFunc("/long", func(w http.ResponseWriter, r *http.Request) { <-handlerChan w.WriteHeader(http.StatusOK) }) tlsConf := http3.ConfigureTLSConfig(getTLSConfig()) server := &http3.Server{ Handler: mux, // the following values will be ignored when using ServeListener TLSConfig: tlsConf, QUICConfig: getQuicConfig(nil), Addr: "127.0.0.1:0", } serveChan := make(chan error, 1) var host string var ln *quic.EarlyListener // only set when using application listener if useApplicationListener { var err error ln, err = quic.ListenEarly(newUDPConnLocalhost(t), tlsConf, getQuicConfig(nil)) 
require.NoError(t, err) defer ln.Close() host = ln.Addr().String() go func() { serveChan <- server.ServeListener(ln) }() } else { go func() { serveChan <- server.ListenAndServe() }() // The server is listening on a random port, and the only way to get the port // is to parse the Alt-Svc header. var port int require.Eventually(t, func() bool { hdr := make(http.Header) server.SetQUICHeaders(hdr) altSvc := hdr.Get("Alt-Svc") n, err := fmt.Sscanf(altSvc, `h3=":%d"`, &port) return err == nil && n == 1 }, time.Second, 10*time.Millisecond) host = fmt.Sprintf("127.0.0.1:%d", port) } u := &url.URL{Scheme: "https", Host: host, Path: "/ok"} ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() require.NoError(t, dial(t, ctx, u)) longReqChan := make(chan error, 1) shutdownChan := make(chan error, 1) if graceful { go func() { u := &url.URL{Scheme: "https", Host: host, Path: "/long"} ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() longReqChan <- dial(t, ctx, u) }() time.Sleep(scaleDuration(10 * time.Millisecond)) go func() { shutdownChan <- server.Shutdown(context.Background()) }() } else { require.NoError(t, server.Close()) } select { case err := <-serveChan: require.ErrorIs(t, err, http.ErrServerClosed) case <-time.After(time.Second): t.Fatal("server did not stop") } // If the listener was created by the http3.Server, it will now be closed. if !useApplicationListener { ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(10*time.Millisecond)) defer cancel() require.ErrorIs(t, dial(t, ctx, u), context.DeadlineExceeded) } else { // If the listener was created by the application, it will not be closed, // and it can be used to accept new connections. 
errChan := make(chan error, 1) go func() { for { conn, err := ln.Accept(context.Background()) if err != nil { errChan <- err return } select { case <-conn.HandshakeComplete(): conn.CloseWithError(1337, "") case <-time.After(time.Second): errChan <- fmt.Errorf("connection did not complete handshake") } errChan <- nil } }() for range 2 { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() require.ErrorIs(t, dial(t, ctx, u), &http3.Error{ErrorCode: 1337, Remote: true}) select { case err := <-errChan: require.NoError(t, err) case <-time.After(time.Second): t.Fatal("server did not accept connection") } } } // the long request should have been terminated if graceful { select { case err := <-longReqChan: t.Fatalf("request should not have terminated: %v", err) case err := <-shutdownChan: t.Fatalf("graceful shutdown should not have returned: %v", err) case <-time.After(scaleDuration(10 * time.Millisecond)): } close(handlerChan) select { case err := <-longReqChan: require.NoError(t, err) case <-time.After(time.Second): t.Fatal("long request did not terminate") } select { case err := <-shutdownChan: require.NoError(t, err) case <-time.After(time.Second): t.Fatal("shutdown did not complete") } } } quic-go-0.59.0/integrationtests/self/http_test.go000066400000000000000000001230221513066070600220610ustar00rootroot00000000000000package self_test import ( "bufio" "bytes" "compress/gzip" "context" "crypto/tls" "errors" "fmt" "io" mrand "math/rand/v2" "net" "net/http" "net/http/httptrace" "net/textproto" "os" "strconv" "strings" "sync/atomic" "testing" "time" "golang.org/x/sync/errgroup" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" "github.com/quic-go/quic-go/http3/qlog" quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/testutils/events" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type neverEnding byte func 
(b neverEnding) Read(p []byte) (n int, err error) { for i := range p { p[i] = byte(b) } return len(p), nil } func randomString(length int) string { const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" b := make([]byte, length) for i := range b { n := mrand.IntN(len(alphabet)) b[i] = alphabet[n] } return string(b) } func startHTTPServer(t *testing.T, mux *http.ServeMux, opts ...func(*http3.Server)) (port int) { t.Helper() server := &http3.Server{ Handler: mux, TLSConfig: getTLSConfig(), QUICConfig: getQuicConfig(&quic.Config{Allow0RTT: true, EnableDatagrams: true}), } for _, opt := range opts { opt(server) } conn := newUDPConnLocalhost(t) done := make(chan struct{}) go func() { defer close(done) server.Serve(conn) }() t.Cleanup(func() { conn.Close() select { case <-done: case <-time.After(time.Second): t.Fatal("server didn't shut down") } }) return conn.LocalAddr().(*net.UDPAddr).Port } func newHTTP3Client(t *testing.T, opts ...func(*http3.Transport)) *http.Client { tr := &http3.Transport{ TLSClientConfig: getTLSClientConfigWithoutServerName(), QUICConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}), DisableCompression: true, } for _, opt := range opts { opt(tr) } addDialCallback(t, tr) t.Cleanup(func() { tr.Close() }) return &http.Client{Transport: tr} } func TestHTTPGet(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, "Hello, World!\n") }) mux.HandleFunc("/long", func(w http.ResponseWriter, r *http.Request) { w.Write(PRDataLong) }) port := startHTTPServer(t, mux) cl := newHTTP3Client(t) t.Run("small", func(t *testing.T) { resp, err := cl.Get(fmt.Sprintf("https://localhost:%d/hello", port)) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) body, err := io.ReadAll(&readerWithTimeout{Reader: resp.Body, Timeout: 2 * time.Second}) require.NoError(t, err) require.Equal(t, "Hello, World!\n", string(body)) }) t.Run("big", 
func(t *testing.T) { resp, err := cl.Get(fmt.Sprintf("https://localhost:%d/long", port)) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) body, err := io.ReadAll(&readerWithTimeout{Reader: resp.Body, Timeout: 10 * time.Second}) require.NoError(t, err) require.Equal(t, PRDataLong, body) }) } func TestHTTPPost(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) { io.Copy(w, r.Body) }) port := startHTTPServer(t, mux) cl := newHTTP3Client(t) t.Run("small", func(t *testing.T) { resp, err := cl.Post( fmt.Sprintf("https://localhost:%d/echo", port), "text/plain", bytes.NewReader([]byte("Hello, world!")), ) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) body, err := io.ReadAll(&readerWithTimeout{Reader: resp.Body, Timeout: 2 * time.Second}) require.NoError(t, err) require.Equal(t, []byte("Hello, world!"), body) }) t.Run("big", func(t *testing.T) { resp, err := cl.Post( fmt.Sprintf("https://localhost:%d/echo", port), "text/plain", bytes.NewReader(PRData), ) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) body, err := io.ReadAll(&readerWithTimeout{Reader: resp.Body, Timeout: 10 * time.Second}) require.NoError(t, err) require.Equal(t, PRData, body) }) } func TestHTTPMultipleRequests(t *testing.T) { mux := http.NewServeMux() port := startHTTPServer(t, mux) t.Run("reading the response", func(t *testing.T) { mux.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, "Hello, World!\n") }) cl := newHTTP3Client(t) var eg errgroup.Group for i := 0; i < 200; i++ { eg.Go(func() error { resp, err := cl.Get(fmt.Sprintf("https://localhost:%d/hello", port)) if err != nil { return err } if resp.StatusCode != http.StatusOK { return fmt.Errorf("unexpected status code: %d", resp.StatusCode) } body, err := io.ReadAll(&readerWithTimeout{Reader: resp.Body, Timeout: 3 * time.Second}) if err != nil { return err } if string(body) != 
"Hello, World!\n" { return fmt.Errorf("unexpected body: %q", body) } return nil }) } require.NoError(t, eg.Wait()) }) t.Run("not reading the response", func(t *testing.T) { mux.HandleFunc("/prdata", func(w http.ResponseWriter, r *http.Request) { w.Write(PRData) }) cl := newHTTP3Client(t) const num = 150 for i := 0; i < num; i++ { resp, err := cl.Get(fmt.Sprintf("https://localhost:%d/prdata", port)) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) require.NoError(t, resp.Body.Close()) } }) } func TestContentLengthForSmallResponse(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/small", func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, "foo") io.WriteString(w, "bar") }) port := startHTTPServer(t, mux) resp, err := newHTTP3Client(t).Get(fmt.Sprintf("https://localhost:%d/small", port)) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, "6", resp.Header.Get("Content-Length")) } func TestHTTPHeaders(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/headers/response", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("foo", "bar") w.Header().Set("lorem", "ipsum") w.Header().Set("echo", r.Header.Get("echo")) }) port := startHTTPServer(t, mux) req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/headers/response", port), nil) require.NoError(t, err) echoHdr := randomString(128) req.Header.Set("echo", echoHdr) resp, err := newHTTP3Client(t).Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, "bar", resp.Header.Get("foo")) require.Equal(t, "ipsum", resp.Header.Get("lorem")) require.Equal(t, echoHdr, resp.Header.Get("echo")) } func TestHTTPHeaderSizeLimitServer(t *testing.T) { t.Run("large HEADERS frame", func(t *testing.T) { const limit = 1024 hdr := make(http.Header) for range 20 { hdr.Add(randomString(50), randomString(50)) } headersFrameSize := testHTTPHeaderSizeLimitServer(t, hdr, limit) 
require.Greater(t, headersFrameSize, limit) }) t.Run("large decompressed HEADERS frame", func(t *testing.T) { const limit = 1024 hdr := make(http.Header) for range 200 { // This is a QPACK static table entry, so it will be compressed. hdr.Add("content-type", "text/plain;charset=utf-8") } headersFrameSize := testHTTPHeaderSizeLimitServer(t, hdr, limit) require.Less(t, headersFrameSize, limit) }) } func testHTTPHeaderSizeLimitServer(t *testing.T, hdr http.Header, limit int) (headersFrameSize int) { mux := http.NewServeMux() var handlerCalled bool mux.HandleFunc("/headers", func(w http.ResponseWriter, r *http.Request) { handlerCalled = true }) port := startHTTPServer(t, mux, func(s *http3.Server) { s.MaxHeaderBytes = limit }) var eventRecorder events.Recorder cl := newHTTP3Client(t, func(tr *http3.Transport) { tr.QUICConfig = getQuicConfig(&quic.Config{ MaxIdleTimeout: 10 * time.Second, Tracer: newTracer(&eventRecorder), }) }) req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/headers", port), nil) require.NoError(t, err) req.Header = hdr resp, err := cl.Do(req) require.NoError(t, err) require.Equal(t, http.StatusRequestHeaderFieldsTooLarge, resp.StatusCode) require.False(t, handlerCalled) for _, ev := range eventRecorder.Events(qlog.FrameCreated{}) { fc := ev.(qlog.FrameCreated) if _, ok := fc.Frame.Frame.(qlog.HeadersFrame); ok { headersFrameSize = fc.Raw.Length break } } return headersFrameSize } func TestHTTPHeaderSizeLimitClient(t *testing.T) { t.Run("large HEADERS frame", func(t *testing.T) { const limit = 1024 hdr := make(http.Header) for range 20 { hdr.Add(randomString(50), randomString(50)) } headersFrameSize, requestErr := testHTTPHeaderSizeLimitClient(t, hdr, limit) require.ErrorContains(t, requestErr, "http3: HEADERS frame too large") require.Greater(t, headersFrameSize, limit) }) t.Run("large decompressed HEADERS frame", func(t *testing.T) { const limit = 1024 hdr := make(http.Header) for range 200 { // This is a QPACK static 
table entry, so it will be compressed. hdr.Add("content-type", "text/plain;charset=utf-8") } headersFrameSize, requestErr := testHTTPHeaderSizeLimitClient(t, hdr, limit) require.ErrorContains(t, requestErr, "http3: headers too large") require.Less(t, headersFrameSize, limit) }) } func testHTTPHeaderSizeLimitClient(t *testing.T, hdr http.Header, limit int) (headersFrameSize int, requestErr error) { mux := http.NewServeMux() var handlerCalled atomic.Bool mux.HandleFunc("/headers", func(w http.ResponseWriter, r *http.Request) { handlerCalled.Store(true) for k, v := range hdr { for _, val := range v { w.Header().Add(k, val) } } }) port := startHTTPServer(t, mux) var eventRecorder events.Recorder cl := newHTTP3Client(t, func(tr *http3.Transport) { tr.MaxResponseHeaderBytes = limit tr.QUICConfig = getQuicConfig(&quic.Config{ MaxIdleTimeout: 10 * time.Second, Tracer: newTracer(&eventRecorder), }) }, ) req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/headers", port), nil) require.NoError(t, err) _, requestErr = cl.Do(req) require.Error(t, requestErr) require.True(t, handlerCalled.Load()) var found bool for _, ev := range eventRecorder.Events(qlog.FrameParsed{}) { fp := ev.(qlog.FrameParsed) if _, ok := fp.Frame.Frame.(qlog.HeadersFrame); ok { headersFrameSize = fp.Raw.PayloadLength found = true break } } require.True(t, found) return headersFrameSize, requestErr } func TestHTTPResponseTrailers(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/trailers", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Trailer", "AtEnd1, AtEnd2") w.Header().Add("Trailer", "Never") w.Header().Add("Trailer", "LAST") w.Header().Set("Content-Type", "text/plain; charset=utf-8") // normal header w.WriteHeader(http.StatusOK) w.Header().Set("AtEnd1", "value 1") io.WriteString(w, "This HTTP response has both headers before this text and trailers at the end.\n") w.(http.Flusher).Flush() w.Header().Set("AtEnd2", "value 2") io.WriteString(w, "More 
text\n") w.(http.Flusher).Flush() w.Header().Set("LAST", "value 3") w.Header().Set(http.TrailerPrefix+"Unannounced", "Surprise!") w.Header().Set("Late-Header", "No surprise!") }) port := startHTTPServer(t, mux) resp, err := newHTTP3Client(t).Get(fmt.Sprintf("https://localhost:%d/trailers", port)) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) require.Empty(t, resp.Header.Get("Trailer")) require.NotContains(t, resp.Header, "Atend1") require.NotContains(t, resp.Header, "Atend2") require.NotContains(t, resp.Header, "Never") require.NotContains(t, resp.Header, "Last") require.NotContains(t, resp.Header, "Late-Header") require.Equal(t, http.Header(map[string][]string{ "Atend1": nil, "Atend2": nil, "Never": nil, "Last": nil, }), resp.Trailer) body, err := io.ReadAll(&readerWithTimeout{Reader: resp.Body, Timeout: 3 * time.Second}) require.NoError(t, err) require.Equal(t, "This HTTP response has both headers before this text and trailers at the end.\nMore text\n", string(body)) for k := range resp.Header { require.NotContains(t, k, http.TrailerPrefix) } require.Equal(t, http.Header(map[string][]string{ "Atend1": {"value 1"}, "Atend2": {"value 2"}, "Last": {"value 3"}, "Unannounced": {"Surprise!"}, }), resp.Trailer) } func TestHTTPRequestTrailers(t *testing.T) { trailerChan := make(chan http.Header, 2) bodyChan := make(chan string, 1) mux := http.NewServeMux() mux.HandleFunc("/client-trailers", func(w http.ResponseWriter, r *http.Request) { trailerBeforeBody := make(http.Header) for k, v := range r.Trailer { trailerBeforeBody[k] = v } trailerChan <- trailerBeforeBody body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } bodyChan <- string(body) trailer := make(http.Header) for k, v := range r.Trailer { trailer[k] = v } trailerChan <- trailer w.WriteHeader(http.StatusOK) }) port := startHTTPServer(t, mux) pr, pw := io.Pipe() req, err := http.NewRequest(http.MethodPost, 
fmt.Sprintf("https://localhost:%d/client-trailers", port), pr) require.NoError(t, err) req.Trailer = http.Header{ "Trailer1": nil, "Trailer2": {"to-be-updated"}, } go func() { // send the first half of the body pw.Write(PRData[:len(PRData)/2]) // then update the trailer values req.Trailer.Set("Trailer1", "foo") req.Trailer.Set("Trailer2", "bar") req.Trailer.Set("Trailer3", "baz") // send the rest of the body pw.Write(PRData[len(PRData)/2:]) pw.Close() }() resp, err := newHTTP3Client(t).Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) select { case trailersBefore := <-trailerChan: // trailers before body should have announced keys with nil values require.Equal(t, http.Header(map[string][]string{"Trailer1": nil, "Trailer2": nil}), trailersBefore) case <-time.After(time.Second): t.Fatal("timeout waiting for trailer announcement") } select { case body := <-bodyChan: require.Equal(t, string(PRData), body) case <-time.After(time.Second): t.Fatal("timeout waiting for body") } select { case trailers := <-trailerChan: require.Equal(t, http.Header(map[string][]string{ "Trailer1": {"foo"}, "Trailer2": {"bar"}, "Trailer3": {"baz"}, }), trailers) case <-time.After(time.Second): t.Fatal("timeout waiting for trailers") } } func TestHTTPErrAbortHandler(t *testing.T) { respChan := make(chan struct{}) mux := http.NewServeMux() mux.HandleFunc("/abort", func(w http.ResponseWriter, r *http.Request) { // no recover here as it will interfere with the handler io.WriteString(w, "foobar") w.(http.Flusher).Flush() // wait for the client to receive the response <-respChan panic(http.ErrAbortHandler) }) port := startHTTPServer(t, mux) resp, err := newHTTP3Client(t).Get(fmt.Sprintf("https://localhost:%d/abort", port)) close(respChan) require.NoError(t, err) body, err := io.ReadAll(resp.Body) require.Error(t, err) var h3Err *http3.Error require.True(t, errors.As(err, &h3Err)) require.Equal(t, http3.ErrCodeInternalError, h3Err.ErrorCode) // the body will be a 
prefix of what's written require.True(t, bytes.HasPrefix([]byte("foobar"), body)) } func TestHTTPGzip(t *testing.T) { mux := http.NewServeMux() var acceptEncoding string mux.HandleFunc("/hellogz", func(w http.ResponseWriter, r *http.Request) { acceptEncoding = r.Header.Get("Accept-Encoding") w.Header().Set("Content-Encoding", "gzip") w.Header().Set("foo", "bar") gw := gzip.NewWriter(w) defer gw.Close() _, err := gw.Write([]byte("Hello, World!\n")) require.NoError(t, err) }) port := startHTTPServer(t, mux) cl := newHTTP3Client(t) cl.Transport.(*http3.Transport).DisableCompression = false resp, err := cl.Get(fmt.Sprintf("https://localhost:%d/hellogz", port)) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) require.True(t, resp.Uncompressed) body, err := io.ReadAll(&readerWithTimeout{Reader: resp.Body, Timeout: 3 * time.Second}) require.NoError(t, err) require.Equal(t, "Hello, World!\n", string(body)) // make sure the server received the Accept-Encoding header require.Equal(t, "gzip", acceptEncoding) } func TestHTTPDifferentOrigins(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/remote-addr", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-RemoteAddr", r.RemoteAddr) w.WriteHeader(http.StatusOK) }) port := startHTTPServer(t, mux) tr := &http3.Transport{ TLSClientConfig: getTLSClientConfigWithoutServerName(), QUICConfig: getQuicConfig(nil), } t.Cleanup(func() { tr.Close() }) cl := &http.Client{Transport: tr} resp, err := cl.Get(fmt.Sprintf("https://localhost:%d/remote-addr", port)) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) addr1 := resp.Header.Get("X-RemoteAddr") require.NotEmpty(t, addr1) resp, err = cl.Get(fmt.Sprintf("https://127.0.0.1:%d/remote-addr", port)) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) addr2 := resp.Header.Get("X-RemoteAddr") require.NotEmpty(t, addr2) require.Equal(t, addr1, addr2) } func TestHTTPServerIdleTimeout(t *testing.T) { mux := 
http.NewServeMux() mux.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, "Hello, World!\n") }) idleTimeout := scaleDuration(10 * time.Millisecond) port := startHTTPServer(t, mux, func(s *http3.Server) { s.IdleTimeout = idleTimeout }) connChan := make(chan *quic.Conn, 1) tr := &http3.Transport{ TLSClientConfig: getTLSClientConfigWithoutServerName(), Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) { conn, err := quic.DialAddrEarly(ctx, addr, tlsCfg, cfg) connChan <- conn return conn, err }, } t.Cleanup(func() { tr.Close() }) cl := &http.Client{Transport: tr} _, err := cl.Get(fmt.Sprintf("https://localhost:%d/hello", port)) require.NoError(t, err) var conn *quic.Conn select { case conn = <-connChan: case <-time.After(time.Second): t.Fatal("connection was not opened") } select { case <-time.After(3 * idleTimeout): t.Fatal("connection was not closed") case <-conn.Context().Done(): } } func TestHTTPReestablishConnectionAfterDialError(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, "Hello, World!\n") }) port := startHTTPServer(t, mux) var dialCounter int cl := http.Client{ Transport: &http3.Transport{ TLSClientConfig: getTLSClientConfig(), Dial: func(ctx context.Context, addr string, tlsConf *tls.Config, conf *quic.Config) (*quic.Conn, error) { dialCounter++ if dialCounter == 1 { // make the first dial fail return nil, assert.AnError } return quic.DialAddrEarly(ctx, addr, tlsConf, conf) }, }, } defer cl.Transport.(io.Closer).Close() _, err := cl.Get(fmt.Sprintf("https://localhost:%d/hello", port)) require.ErrorIs(t, err, assert.AnError) resp, err := cl.Get(fmt.Sprintf("https://localhost:%d/hello", port)) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) } func TestHTTPClientRequestContextCancellation(t *testing.T) { mux := http.NewServeMux() port := startHTTPServer(t, mux) cl 
:= newHTTP3Client(t) t.Run("before response", func(t *testing.T) { mux.HandleFunc("/cancel-before", func(w http.ResponseWriter, r *http.Request) { <-r.Context().Done() }) ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/cancel-before", port), nil) require.NoError(t, err) _, err = cl.Do(req) require.Error(t, err) require.ErrorIs(t, err, context.DeadlineExceeded) }) t.Run("after response", func(t *testing.T) { errChan := make(chan error, 1) mux.HandleFunc("/cancel-after", func(w http.ResponseWriter, r *http.Request) { // TODO(#4508): check for request context cancellations for { if _, err := io.WriteString(w, "foobar"); err != nil { errChan <- err return } } }) ctx, cancel := context.WithCancel(context.Background()) req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/cancel-after", port), nil) require.NoError(t, err) resp, err := cl.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) cancel() select { case err := <-errChan: require.Error(t, err) var http3Err *http3.Error require.True(t, errors.As(err, &http3Err)) require.Equal(t, http3.ErrCodeRequestCanceled, http3Err.ErrorCode) require.True(t, http3Err.Remote) case <-time.After(time.Second): t.Fatal("handler was not called") } _, err = resp.Body.Read([]byte{0}) var http3Err *http3.Error require.True(t, errors.As(err, &http3Err)) require.Equal(t, http3.ErrCodeRequestCanceled, http3Err.ErrorCode) require.False(t, http3Err.Remote) }) } func TestHTTPDeadlines(t *testing.T) { const deadlineDelay = 50 * time.Millisecond mux := http.NewServeMux() port := startHTTPServer(t, mux) cl := newHTTP3Client(t) t.Run("read deadline", func(t *testing.T) { type result struct { body []byte err error } resultChan := make(chan result, 1) mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) { rc := 
http.NewResponseController(w) require.NoError(t, rc.SetReadDeadline(time.Now().Add(deadlineDelay))) body, err := io.ReadAll(r.Body) resultChan <- result{body: body, err: err} io.WriteString(w, "ok") }) expectedEnd := time.Now().Add(deadlineDelay) resp, err := cl.Post( fmt.Sprintf("https://localhost:%d/read-deadline", port), "text/plain", neverEnding('a'), ) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) body, err := io.ReadAll(&readerWithTimeout{Reader: resp.Body, Timeout: 2 * deadlineDelay}) require.NoError(t, err) require.True(t, time.Now().After(expectedEnd)) require.Equal(t, "ok", string(body)) select { case result := <-resultChan: require.ErrorIs(t, result.err, os.ErrDeadlineExceeded) require.Contains(t, string(result.body), "aa") default: t.Fatal("handler was not called") } }) t.Run("write deadline", func(t *testing.T) { errChan := make(chan error, 1) mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) { rc := http.NewResponseController(w) require.NoError(t, rc.SetWriteDeadline(time.Now().Add(deadlineDelay))) _, err := io.Copy(w, neverEnding('a')) errChan <- err }) expectedEnd := time.Now().Add(deadlineDelay) resp, err := cl.Get(fmt.Sprintf("https://localhost:%d/write-deadline", port)) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) body, err := io.ReadAll(&readerWithTimeout{Reader: resp.Body, Timeout: 2 * deadlineDelay}) require.NoError(t, err) require.True(t, time.Now().After(expectedEnd)) require.Contains(t, string(body), "aa") select { case err := <-errChan: require.ErrorIs(t, err, os.ErrDeadlineExceeded) case <-time.After(2 * deadlineDelay): t.Fatal("handler was not called") } }) } func TestHTTPServeQUICConn(t *testing.T) { tlsConf := getTLSConfig() tlsConf.NextProtos = []string{http3.NextProtoH3} ln, err := quic.Listen(newUDPConnLocalhost(t), tlsConf, getQuicConfig(nil)) require.NoError(t, err) defer ln.Close() mux := http.NewServeMux() mux.HandleFunc("/hello", func(w 
http.ResponseWriter, r *http.Request) { io.WriteString(w, "Hello, World!\n") }) server := &http3.Server{ TLSConfig: tlsConf, QUICConfig: getQuicConfig(nil), Handler: mux, } errChan := make(chan error, 1) go func() { conn, err := ln.Accept(context.Background()) if err != nil { errChan <- fmt.Errorf("failed to accept QUIC connection: %w", err) return } errChan <- server.ServeQUICConn(conn) // returns once the client closes }() cl := newHTTP3Client(t) resp, err := cl.Get(fmt.Sprintf("https://localhost:%d/hello", ln.Addr().(*net.UDPAddr).Port)) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) require.NoError(t, cl.Transport.(io.Closer).Close()) select { case err := <-errChan: require.Error(t, err) require.ErrorContains(t, err, "accepting stream failed") case <-time.After(time.Second): t.Fatal("server didn't shut down") } } func TestHTTPContextFromQUIC(t *testing.T) { conn := newUDPConnLocalhost(t) tr := &quic.Transport{ Conn: conn, ConnContext: func(ctx context.Context, _ *quic.ClientInfo) (context.Context, error) { return context.WithValue(ctx, "foo", "bar"), nil }, } defer tr.Close() tlsConf := getTLSConfig() tlsConf.NextProtos = []string{http3.NextProtoH3} ln, err := tr.Listen(tlsConf, getQuicConfig(nil)) require.NoError(t, err) defer ln.Close() mux := http.NewServeMux() ctxChan := make(chan context.Context, 1) mux.HandleFunc("/quic-conn-context", func(w http.ResponseWriter, r *http.Request) { ctxChan <- r.Context() }) server := &http3.Server{Handler: mux} go func() { c, err := ln.Accept(context.Background()) require.NoError(t, err) server.ServeQUICConn(c) }() cl := newHTTP3Client(t) resp, err := cl.Get(fmt.Sprintf("https://localhost:%d/quic-conn-context", conn.LocalAddr().(*net.UDPAddr).Port)) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) select { case ctx := <-ctxChan: v, ok := ctx.Value("foo").(string) require.True(t, ok) require.Equal(t, "bar", v) default: t.Fatal("context not set") } } func 
TestHTTPConnContext(t *testing.T) { mux := http.NewServeMux() requestCtxChan := make(chan context.Context, 1) mux.HandleFunc("/context", func(w http.ResponseWriter, r *http.Request) { requestCtxChan <- r.Context() }) var server *http3.Server connCtxChan := make(chan context.Context, 1) port := startHTTPServer(t, mux, func(s *http3.Server) { server = s }, func(s *http3.Server) { s.ConnContext = func(ctx context.Context, c *quic.Conn) context.Context { connCtxChan <- ctx ctx = context.WithValue(ctx, "foo", "bar") return ctx } }, ) resp, err := newHTTP3Client(t).Get(fmt.Sprintf("https://localhost:%d/context", port)) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) select { case ctx := <-connCtxChan: serv, ok := ctx.Value(http3.ServerContextKey).(*http3.Server) require.True(t, ok) require.Equal(t, server, serv) default: t.Fatal("handler was not called") } select { case ctx := <-requestCtxChan: v, ok := ctx.Value("foo").(string) require.True(t, ok) require.Equal(t, "bar", v) serv, ok := ctx.Value(http3.ServerContextKey).(*http3.Server) require.True(t, ok) require.Equal(t, server, serv) default: t.Fatal("handler was not called") } } func TestHTTPRemoteAddrContextKey(t *testing.T) { ctxChan := make(chan context.Context, 1) mux := http.NewServeMux() mux.HandleFunc("/remote-addr", func(w http.ResponseWriter, r *http.Request) { ctxChan <- r.Context() }) port := startHTTPServer(t, mux) resp, err := newHTTP3Client(t).Get(fmt.Sprintf("https://localhost:%d/remote-addr", port)) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) select { case ctx := <-ctxChan: _, ok := ctx.Value(http3.RemoteAddrContextKey).(net.Addr) require.True(t, ok) require.Equal(t, "127.0.0.1", ctx.Value(http3.RemoteAddrContextKey).(*net.UDPAddr).IP.String()) default: t.Fatal("handler was not called") } } func TestHTTPStreamedRequests(t *testing.T) { errChan := make(chan error, 1) mux := http.NewServeMux() mux.HandleFunc("/echoline", func(w http.ResponseWriter, r 
*http.Request) { defer close(errChan) w.WriteHeader(200) w.(http.Flusher).Flush() reader := bufio.NewReader(r.Body) for { msg, err := reader.ReadString('\n') if err != nil { return } if _, err := io.WriteString(w, msg); err != nil { errChan <- err return } w.(http.Flusher).Flush() } }) port := startHTTPServer(t, mux) r, w := io.Pipe() req, err := http.NewRequest(http.MethodPut, fmt.Sprintf("https://localhost:%d/echoline", port), r) require.NoError(t, err) client := newHTTP3Client(t) rsp, err := client.Do(req) require.NoError(t, err) require.Equal(t, 200, rsp.StatusCode) reader := bufio.NewReader(rsp.Body) for i := 0; i < 5; i++ { msg := fmt.Sprintf("Hello world, %d!\n", i) fmt.Fprint(w, msg) msgRcvd, err := reader.ReadString('\n') require.NoError(t, err) require.Equal(t, msg, msgRcvd) } require.NoError(t, req.Body.Close()) select { case err := <-errChan: require.NoError(t, err) case <-time.After(time.Second): t.Fatal("handler did not complete") } } func TestHTTP1xxResponse(t *testing.T) { header1 := "; rel=preload; as=style" header2 := "; rel=preload; as=script" data := "1xx-test-data" mux := http.NewServeMux() mux.HandleFunc("/103-early-data", func(w http.ResponseWriter, r *http.Request) { w.Header().Add("Link", header1) w.Header().Add("Link", header2) w.WriteHeader(http.StatusEarlyHints) io.WriteString(w, data) w.WriteHeader(http.StatusOK) }) port := startHTTPServer(t, mux) var ( cnt int status int hdr textproto.MIMEHeader ) ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ Got1xxResponse: func(code int, header textproto.MIMEHeader) error { hdr = header status = code cnt++ return nil }, }) req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/103-early-data", port), nil) require.NoError(t, err) resp, err := newHTTP3Client(t).Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, data, 
string(body)) require.Equal(t, http.StatusEarlyHints, status) require.Equal(t, []string{header1, header2}, hdr.Values("Link")) require.Equal(t, 1, cnt) require.Equal(t, []string{header1, header2}, resp.Header.Values("Link")) require.NoError(t, resp.Body.Close()) } func TestHTTP1xxTerminalResponse(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/101-switch-protocols", func(w http.ResponseWriter, r *http.Request) { w.Header().Add("foo", "bar") w.WriteHeader(http.StatusSwitchingProtocols) }) port := startHTTPServer(t, mux) var ( cnt int status int ) ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ Got1xxResponse: func(code int, header textproto.MIMEHeader) error { status = code cnt++ return nil }, }) req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/101-switch-protocols", port), nil) require.NoError(t, err) resp, err := newHTTP3Client(t).Do(req) require.NoError(t, err) require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) require.Equal(t, "bar", resp.Header.Get("Foo")) require.Zero(t, status) require.Zero(t, cnt) require.NoError(t, resp.Body.Close()) } func TestHTTP0RTT(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/0rtt", func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, strconv.FormatBool(!r.TLS.HandshakeComplete)) }) port := startHTTPServer(t, mux) var num0RTTPackets atomic.Uint32 proxy := quicproxy.Proxy{ Conn: newUDPConnLocalhost(t), ServerAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port}, DelayPacket: func(_ quicproxy.Direction, _, _ net.Addr, data []byte) time.Duration { if containsPacketType(data, protocol.PacketType0RTT) { num0RTTPackets.Add(1) } return scaleDuration(25 * time.Millisecond) }, } require.NoError(t, proxy.Start()) defer proxy.Close() tlsConf := getTLSClientConfigWithoutServerName() puts := make(chan string, 10) tlsConf.ClientSessionCache = newClientSessionCache(tls.NewLRUClientSessionCache(10), nil, puts) tr := 
&http3.Transport{ TLSClientConfig: tlsConf, QUICConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}), DisableCompression: true, } defer tr.Close() addDialCallback(t, tr) proxyPort := proxy.LocalAddr().(*net.UDPAddr).Port req, err := http.NewRequest(http3.MethodGet0RTT, fmt.Sprintf("https://localhost:%d/0rtt", proxyPort), nil) require.NoError(t, err) rsp, err := tr.RoundTrip(req) require.NoError(t, err) require.Equal(t, 200, rsp.StatusCode) data, err := io.ReadAll(rsp.Body) require.NoError(t, err) require.Equal(t, "false", string(data)) require.Zero(t, num0RTTPackets.Load()) select { case <-puts: case <-time.After(time.Second): t.Fatal("did not receive session ticket") } tr2 := &http3.Transport{ TLSClientConfig: tr.TLSClientConfig, QUICConfig: tr.QUICConfig, DisableCompression: true, } defer tr2.Close() addDialCallback(t, tr2) rsp, err = tr2.RoundTrip(req) require.NoError(t, err) require.Equal(t, 200, rsp.StatusCode) data, err = io.ReadAll(rsp.Body) require.NoError(t, err) require.Equal(t, "true", string(data)) require.NotZero(t, num0RTTPackets.Load()) } func TestHTTPStreamer(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/httpstreamer", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) str := w.(http3.HTTPStreamer).HTTPStream() str.Write([]byte("foobar")) // Do this in a Go routine, so that the handler returns early. // This way, we can also check that the HTTP/3 doesn't close the stream. 
go func() { defer str.Close() _, _ = io.Copy(str, str) }() }) port := startHTTPServer(t, mux) req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/httpstreamer", port), nil) require.NoError(t, err) tlsConf := getTLSClientConfigWithoutServerName() tlsConf.NextProtos = []string{http3.NextProtoH3} ctx, cancel := context.WithCancel(context.Background()) defer cancel() conn, err := quic.Dial(ctx, newUDPConnLocalhost(t), &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port}, tlsConf, getQuicConfig(nil)) require.NoError(t, err) defer conn.CloseWithError(0, "") tr := http3.Transport{} addDialCallback(t, &tr) cc := tr.NewClientConn(conn) str, err := cc.OpenRequestStream(ctx) require.NoError(t, err) require.NoError(t, str.SendRequestHeader(req)) rsp, err := str.ReadResponse() require.NoError(t, err) require.Equal(t, 200, rsp.StatusCode) b := make([]byte, 6) _, err = io.ReadFull(str, b) require.NoError(t, err) require.Equal(t, []byte("foobar"), b) _, err = str.Write(PRData) require.NoError(t, err) require.NoError(t, str.Close()) repl, err := io.ReadAll(str) require.NoError(t, err) require.Equal(t, PRData, repl) } type blackHoleConn struct { net.PacketConn block atomic.Bool close chan struct{} } func (c *blackHoleConn) WriteTo(b []byte, addr net.Addr) (int, error) { return c.PacketConn.WriteTo(b, addr) } func (c *blackHoleConn) ReadFrom(b []byte) (int, net.Addr, error) { if c.block.Load() { <-c.close return 0, nil, errors.New("blocked") } n, _, err := c.PacketConn.ReadFrom(b) if c.block.Load() { <-c.close return 0, nil, errors.New("blocked") } return n, nil, err } func (c *blackHoleConn) Close() error { close(c.close) return c.PacketConn.Close() } func (c *blackHoleConn) StartBlocking() { c.block.Store(true) } func TestHTTPRequestRetryAfterIdleTimeout(t *testing.T) { t.Run("only cached conn", func(t *testing.T) { testHTTPRequestRetryAfterIdleTimeout(t, true) }) t.Run("allow re-dialing", func(t *testing.T) { testHTTPRequestRetryAfterIdleTimeout(t, false) 
}) } func testHTTPRequestRetryAfterIdleTimeout(t *testing.T, onlyCachedConn bool) { t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true") mux := http.NewServeMux() mux.HandleFunc("/remote-addr", func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, r.RemoteAddr) }) port := startHTTPServer(t, mux, func(s *http3.Server) {}) firstConn := &blackHoleConn{PacketConn: newUDPConnLocalhost(t), close: make(chan struct{})} secondConn := newUDPConnLocalhost(t) conns := []net.PacketConn{firstConn, secondConn} require.NotEqual(t, firstConn.LocalAddr().String(), secondConn.LocalAddr().String()) idleTimeout := scaleDuration(10 * time.Millisecond) connChan := make(chan *quic.Conn, 2) tr := &http3.Transport{ TLSClientConfig: getTLSClientConfigWithoutServerName(), QUICConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout}), Dial: func(ctx context.Context, a string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) { conn := conns[0] conns = conns[1:] addr, err := net.ResolveUDPAddr("udp", a) if err != nil { return nil, err } c, err := quic.DialEarly(ctx, conn, addr, tlsCfg, cfg) if err != nil { return nil, err } connChan <- c return c, nil }, DisableCompression: true, } t.Cleanup(func() { tr.Close() }) var headersCount int req, err := http.NewRequestWithContext( httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ WroteHeaders: func() { headersCount++ }, }), http.MethodGet, fmt.Sprintf("https://127.0.0.1:%d/remote-addr", port), // Add a body (wrappped so that http.NewRequest doesn't set the GetBody callback), // to make it impossible to retry this request. // This tests that the detection logic works properly: // If the request fails before the stream can be opened, it is always safe to retry. 
io.LimitReader(strings.NewReader("foobar"), 1000), ) require.NoError(t, err) resp, err := tr.RoundTripOpt(req, http3.RoundTripOpt{}) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, firstConn.LocalAddr().String(), string(body)) firstConn.StartBlocking() // wait for the connection to time out select { case c := <-connChan: select { case <-c.Context().Done(): case <-time.After(time.Second): t.Fatal("connection did not time out") } case <-time.After(time.Second): t.Fatal("no connection was created") } // second request should succeed after re-dialing resp, err = tr.RoundTripOpt(req, http3.RoundTripOpt{OnlyCachedConn: onlyCachedConn}) if onlyCachedConn { require.EqualError(t, err, "http3: no cached connection was available") require.Len(t, conns, 1) // no second dial attempt require.Equal(t, 1, headersCount) return } require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) body, err = io.ReadAll(&readerWithTimeout{Reader: resp.Body, Timeout: 2 * time.Second}) require.NoError(t, err) require.Equal(t, secondConn.LocalAddr().String(), string(body)) require.Equal(t, 2, headersCount) require.Empty(t, conns) // make sure we dialed 2 connections } func TestHTTPRequestAfterGracefulShutdown(t *testing.T) { t.Run("Request.GetBody set", func(t *testing.T) { testHTTPRequestAfterGracefulShutdown(t, true) }) t.Run("Request.GetBody not set", func(t *testing.T) { testHTTPRequestAfterGracefulShutdown(t, false) }) } func testHTTPRequestAfterGracefulShutdown(t *testing.T, setGetBody bool) { t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true") ln, err := quic.ListenEarly( newUDPConnLocalhost(t), http3.ConfigureTLSConfig(getTLSConfig()), getQuicConfig(nil), ) require.NoError(t, err) var inShutdown atomic.Bool proxy := quicproxy.Proxy{ Conn: newUDPConnLocalhost(t), ServerAddr: ln.Addr().(*net.UDPAddr), DelayPacket: func(_ quicproxy.Direction, _, _ net.Addr, data []byte) 
time.Duration { if inShutdown.Load() { return scaleDuration(10 * time.Millisecond) } return scaleDuration(2 * time.Millisecond) }, } require.NoError(t, proxy.Start()) defer proxy.Close() mux2 := http.NewServeMux() mux2.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) { data, _ := io.ReadAll(r.Body) w.Write(data) }) server2 := &http3.Server{Handler: mux2} done := make(chan struct{}) defer close(done) server1 := &http3.Server{Handler: http.NewServeMux()} go server1.ServeListener(ln) tlsConf := getTLSClientConfigWithoutServerName() tlsConf.NextProtos = []string{http3.NextProtoH3} var dialCount int tr := &http3.Transport{ TLSClientConfig: tlsConf, Dial: func(ctx context.Context, a string, tlsConf *tls.Config, conf *quic.Config) (*quic.Conn, error) { addr, err := net.ResolveUDPAddr("udp", a) if err != nil { return nil, err } dialCount++ return quic.DialEarly(ctx, newUDPConnLocalhost(t), addr, tlsConf, conf) }, } t.Cleanup(func() { tr.Close() }) cl := &http.Client{Transport: tr} // first request to establish the connection resp, err := cl.Get(fmt.Sprintf("https://%s/", proxy.LocalAddr())) require.NoError(t, err) require.Equal(t, http.StatusNotFound, resp.StatusCode) // If the body is a strings.Reader, http.NewRequest automatically sets the GetBody callback. // This can be prevented by using a different kind of reader, e.g. the io.LimitReader. var headersCount int req, err := http.NewRequestWithContext( httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ WroteHeaders: func() { headersCount++ }, }), http.MethodGet, fmt.Sprintf("https://%s/echo", proxy.LocalAddr()), io.LimitReader(strings.NewReader("foobar"), 1000), ) require.NoError(t, err) if setGetBody { req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(strings.NewReader("foobaz")), nil } } else { require.Nil(t, req.GetBody) } // By increasing the RTT, we make sure that the request is sent before the client receives the GOAWAY frame. 
inShutdown.Store(true) go server1.Shutdown(context.Background()) go server2.ServeListener(ln) defer server2.Close() resp, err = cl.Do(req) if !setGetBody { require.ErrorContains(t, err, "after Request.Body was written; define Request.GetBody to avoid this error") require.Equal(t, 1, dialCount) require.Equal(t, 1, headersCount) return } require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "foobaz", string(body)) require.Equal(t, 2, dialCount) require.Equal(t, 2, headersCount) } quic-go-0.59.0/integrationtests/self/http_trace_test.go000066400000000000000000000125311513066070600232410ustar00rootroot00000000000000package self_test import ( "context" "crypto/tls" "fmt" "net" "net/http" "net/http/httptrace" "net/textproto" "testing" "time" "github.com/quic-go/quic-go/http3" "github.com/stretchr/testify/require" ) func TestHTTPClientTrace(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/client-trace", func(w http.ResponseWriter, r *http.Request) { time.Sleep(100 * time.Millisecond) w.WriteHeader(http.StatusContinue) }) port := startHTTPServer(t, mux) buf := make([]byte, 1) type event struct { Key string Args any } eventQueue := make(chan event, 100) wait100Continue := false trace := httptrace.ClientTrace{ GetConn: func(hostPort string) { eventQueue <- event{Key: "GetConn", Args: hostPort} }, GotConn: func(info httptrace.GotConnInfo) { eventQueue <- event{Key: "GotConn", Args: info} }, GotFirstResponseByte: func() { eventQueue <- event{Key: "GotFirstResponseByte"} }, Got100Continue: func() { eventQueue <- event{Key: "Got100Continue"} }, Got1xxResponse: func(code int, header textproto.MIMEHeader) error { eventQueue <- event{Key: "Got1xxResponse", Args: code} return nil }, DNSStart: func(di httptrace.DNSStartInfo) { eventQueue <- event{Key: "DNSStart", Args: di} }, DNSDone: func(di httptrace.DNSDoneInfo) { eventQueue <- event{Key: "DNSDone", Args: di} }, ConnectStart: 
func(network, addr string) { eventQueue <- event{Key: "ConnectStart", Args: map[string]string{"network": network, "addr": addr}} }, ConnectDone: func(network, addr string, err error) { eventQueue <- event{Key: "ConnectDone", Args: map[string]any{"network": network, "addr": addr, "err": err}} }, TLSHandshakeStart: func() { eventQueue <- event{Key: "TLSHandshakeStart"} }, TLSHandshakeDone: func(state tls.ConnectionState, err error) { eventQueue <- event{Key: "TLSHandshakeDone", Args: map[string]any{"state": state, "err": err}} }, WroteHeaderField: func(key string, value []string) { if key != ":authority" { return } eventQueue <- event{Key: "WroteHeaderField", Args: value[0]} }, WroteHeaders: func() { eventQueue <- event{Key: "WroteHeaders"} }, Wait100Continue: func() { wait100Continue = true }, WroteRequest: func(i httptrace.WroteRequestInfo) { eventQueue <- event{Key: "WroteRequest", Args: i} }, } ctx := httptrace.WithClientTrace(context.Background(), &trace) tr := &http3.Transport{ TLSClientConfig: getTLSClientConfigWithoutServerName(), QUICConfig: getQuicConfig(nil), } t.Cleanup(func() { tr.Close() }) cl := &http.Client{Transport: tr} req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/client-trace", port), nil) require.NoError(t, err) resp, err := cl.Do(req) close(eventQueue) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) events := make([]string, 0, len(eventQueue)) for e := range eventQueue { events = append(events, e.Key) switch e.Key { case "GetConn": require.Equal(t, fmt.Sprintf("localhost:%d", port), e.Args.(string)) case "GotConn": info := e.Args.(httptrace.GotConnInfo) require.Equal(t, fmt.Sprintf("127.0.0.1:%d", port), info.Conn.RemoteAddr().String()) host, _, err := net.SplitHostPort(info.Conn.LocalAddr().String()) require.NoError(t, err) require.Contains(t, []string{"::", "0.0.0.0"}, host) require.Panics(t, func() { info.Conn.Close() }) require.Panics(t, func() { info.Conn.Read(buf) }) 
require.Panics(t, func() { info.Conn.Write(buf) }) require.Panics(t, func() { info.Conn.SetDeadline(time.Now()) }) require.Panics(t, func() { info.Conn.SetReadDeadline(time.Now()) }) require.Panics(t, func() { info.Conn.SetWriteDeadline(time.Now()) }) case "Got1xxResponse": require.Equal(t, 100, e.Args.(int)) case "DNSStart": require.Equal(t, "localhost", e.Args.(httptrace.DNSStartInfo).Host) case "DNSDone": require.Condition(t, func() bool { localhost := net.IPv4(127, 0, 0, 1) localhostTo16 := localhost.To16() for _, addr := range e.Args.(httptrace.DNSDoneInfo).Addrs { if addr.IP.Equal(localhost) || addr.IP.Equal(localhostTo16) { return true } } return false }) case "ConnectStart": require.Equal(t, "udp", e.Args.(map[string]string)["network"]) require.Equal(t, fmt.Sprintf("127.0.0.1:%d", port), e.Args.(map[string]string)["addr"]) case "ConnectDone": require.Equal(t, "udp", e.Args.(map[string]any)["network"]) require.Equal(t, fmt.Sprintf("127.0.0.1:%d", port), e.Args.(map[string]any)["addr"]) require.Nil(t, e.Args.(map[string]any)["err"]) case "TLSHandshakeDone": require.Nil(t, e.Args.(map[string]any)["err"]) state := e.Args.(map[string]any)["state"].(tls.ConnectionState) require.Equal(t, 1, len(state.PeerCertificates)) require.Equal(t, "localhost", state.PeerCertificates[0].DNSNames[0]) case "WroteHeaderField": require.Equal(t, fmt.Sprintf("localhost:%d", port), e.Args.(string)) case "WroteRequest": require.NoError(t, e.Args.(httptrace.WroteRequestInfo).Err) } } require.Equal(t, []string{ "GetConn", "DNSStart", "DNSDone", "ConnectStart", "TLSHandshakeStart", "TLSHandshakeDone", "ConnectDone", "GotConn", "WroteHeaderField", "WroteHeaders", "WroteRequest", "GotFirstResponseByte", "Got1xxResponse", "Got100Continue", }, events) require.Falsef(t, wait100Continue, "wait 100 continue") // Note: not supported Expect: 100-continue } quic-go-0.59.0/integrationtests/self/key_update_test.go000066400000000000000000000046441513066070600232440ustar00rootroot00000000000000package 
self_test

import (
	"context"
	"io"
	"testing"
	"time"

	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/internal/handshake"
	"github.com/quic-go/quic-go/internal/protocol"
	"github.com/quic-go/quic-go/qlog"
	"github.com/quic-go/quic-go/qlogwriter"
	"github.com/quic-go/quic-go/testutils/events"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestKeyUpdates sets the key update interval to 1 (so keys are updated as often
// as possible), transfers PRDataLong over a unidirectional stream, and then uses the
// recorded qlog events to check that many key phase changes happened on received
// packets, and that sent and received key phase changes roughly match.
func TestKeyUpdates(t *testing.T) {
	reset := handshake.SetKeyUpdateInterval(1) // update keys as frequently as possible
	t.Cleanup(reset)

	// countKeyPhases counts how often the key phase bit changes between consecutive
	// sent packets and between consecutive received packets.
	countKeyPhases := func(events []qlogwriter.Event) (sent, received int) {
		lastKeyPhaseSend := protocol.KeyPhaseOne
		lastKeyPhaseReceive := protocol.KeyPhaseOne
		for _, ev := range events {
			switch ev := ev.(type) {
			case qlog.PacketSent:
				if ev.Header.KeyPhaseBit != lastKeyPhaseSend {
					sent++
					lastKeyPhaseSend = ev.Header.KeyPhaseBit
				}
			case qlog.PacketReceived:
				if ev.Header.KeyPhaseBit != lastKeyPhaseReceive {
					received++
					lastKeyPhaseReceive = ev.Header.KeyPhaseBit
				}
			}
		}
		return
	}

	server, err := quic.Listen(newUDPConnLocalhost(t), getTLSConfig(), nil)
	require.NoError(t, err)
	defer server.Close()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	var eventRecorder events.Recorder
	conn, err := quic.Dial(
		ctx,
		newUDPConnLocalhost(t),
		server.Addr(),
		getTLSClientConfig(),
		getQuicConfig(&quic.Config{Tracer: newTracer(&eventRecorder)}),
	)
	require.NoError(t, err)
	defer conn.CloseWithError(0, "")

	serverConn, err := server.Accept(ctx)
	require.NoError(t, err)
	defer serverConn.CloseWithError(0, "")

	// The server sends PRDataLong on a unidirectional stream. The channel is closed
	// on success, so receiving from it yields nil.
	serverErrChan := make(chan error, 1)
	go func() {
		str, err := serverConn.OpenUniStream()
		if err != nil {
			serverErrChan <- err
			return
		}
		defer str.Close()
		if _, err := str.Write(PRDataLong); err != nil {
			serverErrChan <- err
			return
		}
		close(serverErrChan)
	}()

	str, err := conn.AcceptUniStream(ctx)
	require.NoError(t, err)
	data, err := io.ReadAll(str)
	require.NoError(t, err)
	require.Equal(t, PRDataLong, data)
	require.NoError(t, conn.CloseWithError(0, ""))
	require.NoError(t, <-serverErrChan)

	keyPhasesSent, keyPhasesReceived := countKeyPhases(eventRecorder.Events())
	t.Logf("Used %d key phases on outgoing and %d key phases on incoming packets.", keyPhasesSent, keyPhasesReceived)
	assert.Greater(t, keyPhasesReceived, 10)
	assert.InDelta(t, keyPhasesSent, keyPhasesReceived, 2)
}

quic-go-0.59.0/integrationtests/self/mitm_test.go000066400000000000000000000320141513066070600220500ustar00rootroot00000000000000package self_test

import (
	"context"
	"crypto/rand"
	"errors"
	"io"
	"math"
	mrand "math/rand/v2"
	"net"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/quic-go/quic-go"
	quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
	"github.com/quic-go/quic-go/internal/protocol"
	"github.com/quic-go/quic-go/internal/wire"
	"github.com/quic-go/quic-go/testutils"

	"github.com/stretchr/testify/require"
)

// mitmTestConnIDLen is the connection ID length used by both transports.
// It is fixed so that the MITM callbacks can parse short header packets
// (wire.ParseConnectionID / wire.ParseShortHeader need to know it).
const mitmTestConnIDLen = 6

// getTransportsForMITMTest creates a server and a client transport, each on its own
// localhost UDP socket and using mitmTestConnIDLen-byte connection IDs, with a tracer
// attached. Both transports are closed via t.Cleanup.
func getTransportsForMITMTest(t *testing.T) (serverTransport, clientTransport *quic.Transport) {
	serverTransport = &quic.Transport{
		Conn:               newUDPConnLocalhost(t),
		ConnectionIDLength: mitmTestConnIDLen,
	}
	addTracer(serverTransport)
	t.Cleanup(func() { serverTransport.Close() })

	clientTransport = &quic.Transport{
		Conn:               newUDPConnLocalhost(t),
		ConnectionIDLength: mitmTestConnIDLen,
	}
	addTracer(clientTransport)
	t.Cleanup(func() { clientTransport.Close() })

	return serverTransport, clientTransport
}

// TestMITMInjectRandomPackets runs testMITMInjectRandomPackets in both directions.
func TestMITMInjectRandomPackets(t *testing.T) {
	t.Run("towards the server", func(t *testing.T) {
		testMITMInjectRandomPackets(t, quicproxy.DirectionIncoming)
	})
	t.Run("towards the client", func(t *testing.T) {
		testMITMInjectRandomPackets(t, quicproxy.DirectionOutgoing)
	})
}

// TestMITMDuplicatePackets runs testMITMDuplicatePackets in both directions.
func TestMITMDuplicatePackets(t *testing.T) {
	t.Run("towards the server", func(t *testing.T) {
		testMITMDuplicatePackets(t, quicproxy.DirectionIncoming)
	})
	t.Run("towards the client", func(t *testing.T) {
		testMITMDuplicatePackets(t, quicproxy.DirectionOutgoing)
	})
}

// TestMITCorruptPackets runs testMITMCorruptPackets in both directions.
func TestMITCorruptPackets(t *testing.T) {
	t.Run("towards the server", func(t *testing.T) {
		testMITMCorruptPackets(t, quicproxy.DirectionIncoming)
	})
	t.Run("towards the client", func(t *testing.T) {
		testMITMCorruptPackets(t, quicproxy.DirectionOutgoing)
	})
}

// testMITMInjectRandomPackets forwards every packet unchanged. For every packet
// travelling in direction, it also injects 10 packets (one every rtt/10) that have
// the same header type and connection IDs but random packet numbers and payloads.
// The injected packets are sent directly from the opposite transport, bypassing the
// proxy. The echo transfer in runMITMTest must still succeed.
func testMITMInjectRandomPackets(t *testing.T, direction quicproxy.Direction) {
	// createRandomPacketOfSameType returns a packet mimicking b's header type, or nil
	// if b's header can't be parsed.
	createRandomPacketOfSameType := func(b []byte) []byte {
		if wire.IsLongHeaderPacket(b[0]) {
			hdr, _, _, err := wire.ParsePacket(b)
			if err != nil {
				return nil
			}
			replyHdr := &wire.ExtendedHeader{
				Header: wire.Header{
					DestConnectionID: hdr.DestConnectionID,
					SrcConnectionID:  hdr.SrcConnectionID,
					Type:             hdr.Type,
					Version:          hdr.Version,
				},
				PacketNumber:    protocol.PacketNumber(mrand.Int32N(math.MaxInt32 / 4)),
				PacketNumberLen: protocol.PacketNumberLen(mrand.IntN(4) + 1),
			}
			payloadLen := mrand.IntN(100)
			replyHdr.Length = protocol.ByteCount(mrand.IntN(payloadLen + 1))
			data, err := replyHdr.Append(nil, hdr.Version)
			if err != nil {
				panic("failed to append header: " + err.Error())
			}
			r := make([]byte, payloadLen)
			rand.Read(r)
			return append(data, r...)
		}
		// short header packet
		connID, err := wire.ParseConnectionID(b, mitmTestConnIDLen)
		if err != nil {
			return nil
		}
		_, pn, pnLen, _, err := wire.ParseShortHeader(b, mitmTestConnIDLen)
		if err != nil && !errors.Is(err, wire.ErrInvalidReservedBits) {
			// normally, ParseShortHeader is called after decrypting the header
			panic("failed to parse short header: " + err.Error())
		}
		data, err := wire.AppendShortHeader(nil, connID, pn, pnLen, protocol.KeyPhaseBit(mrand.IntN(2)))
		if err != nil {
			return nil
		}
		payloadLen := mrand.IntN(100)
		r := make([]byte, payloadLen)
		rand.Read(r)
		return append(data, r...)
	}

	rtt := scaleDuration(10 * time.Millisecond)
	serverTransport, clientTransport := getTransportsForMITMTest(t)

	dropCallback := func(dir quicproxy.Direction, _, _ net.Addr, b []byte) bool {
		if dir != direction {
			return false
		}
		go func() {
			ticker := time.NewTicker(rtt / 10)
			defer ticker.Stop()
			for range 10 {
				switch direction {
				case quicproxy.DirectionIncoming:
					clientTransport.WriteTo(createRandomPacketOfSameType(b), serverTransport.Conn.LocalAddr())
				case quicproxy.DirectionOutgoing:
					serverTransport.WriteTo(createRandomPacketOfSameType(b), clientTransport.Conn.LocalAddr())
				}
				<-ticker.C
			}
		}()
		return false
	}
	runMITMTest(t, serverTransport, clientTransport, rtt, dropCallback)
}

// testMITMDuplicatePackets forwards every packet. For every packet travelling in
// direction, it also sends an identical copy directly from the opposite transport.
func testMITMDuplicatePackets(t *testing.T, direction quicproxy.Direction) {
	serverTransport, clientTransport := getTransportsForMITMTest(t)
	rtt := scaleDuration(10 * time.Millisecond)

	dropCallback := func(dir quicproxy.Direction, _, _ net.Addr, b []byte) bool {
		if dir != direction {
			return false
		}
		switch direction {
		case quicproxy.DirectionIncoming:
			clientTransport.WriteTo(b, serverTransport.Conn.LocalAddr())
		case quicproxy.DirectionOutgoing:
			serverTransport.WriteTo(b, clientTransport.Conn.LocalAddr())
		}
		return false
	}
	runMITMTest(t, serverTransport, clientTransport, rtt, dropCallback)
}

// testMITMCorruptPackets handles a random subset of the packets travelling in
// direction: it changes one byte of each, sends the corrupted packet directly from
// the opposite transport, and drops the original. It checks that the echo transfer
// still succeeds and that at least one packet was corrupted.
func testMITMCorruptPackets(t *testing.T, direction quicproxy.Direction) {
	serverTransport, clientTransport := getTransportsForMITMTest(t)
	rtt := scaleDuration(5 * time.Millisecond)

	var numCorrupted atomic.Int32
	dropCallback := func(dir quicproxy.Direction, _, _ net.Addr, b []byte) bool {
		if dir != direction {
			return false
		}
		isLongHeaderPacket := wire.IsLongHeaderPacket(b[0])
		// corrupt 25% (1 in 4) of long header packets and 5% (1 in 20) of short header packets
		if isLongHeaderPacket && mrand.IntN(4) != 0 {
			return false
		}
		if !isLongHeaderPacket && mrand.IntN(20) != 0 {
			return false
		}
		numCorrupted.Add(1)
		pos := mrand.IntN(len(b))
		// XOR with a non-zero value so that the byte is guaranteed to change;
		// otherwise a packet could be counted as corrupted while being unmodified.
		b[pos] ^= byte(1 + mrand.IntN(255))
		switch direction {
		case quicproxy.DirectionIncoming:
			clientTransport.WriteTo(b, serverTransport.Conn.LocalAddr())
		case quicproxy.DirectionOutgoing:
			serverTransport.WriteTo(b, clientTransport.Conn.LocalAddr())
		}
		return true
	}
	runMITMTest(t, serverTransport, clientTransport, rtt, dropCallback)
	t.Logf("corrupted %d packets", numCorrupted.Load())
	require.NotZero(t, int(numCorrupted.Load()))
}

// runMITMTest starts a listener on serverTr and routes the client's connection
// through a proxy. The proxy delays packets by rtt/2 in each direction and applies
// dropCb. The client then sends PRData on a bidirectional stream, and the server
// echoes it back; the client must receive the same bytes.
func runMITMTest(t *testing.T, serverTr, clientTr *quic.Transport, rtt time.Duration, dropCb quicproxy.DropCallback) {
	ln, err := serverTr.Listen(getTLSConfig(), getQuicConfig(nil))
	require.NoError(t, err)
	defer ln.Close()

	proxy := quicproxy.Proxy{
		Conn:        newUDPConnLocalhost(t),
		ServerAddr:  ln.Addr().(*net.UDPAddr),
		DelayPacket: func(quicproxy.Direction, net.Addr, net.Addr, []byte) time.Duration { return rtt / 2 },
		DropPacket:  dropCb,
	}
	require.NoError(t, proxy.Start())
	defer proxy.Close()

	ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(time.Second))
	defer cancel()
	conn, err := clientTr.Dial(ctx, proxy.LocalAddr(), getTLSClientConfig(), getQuicConfig(nil))
	require.NoError(t, err)
	defer conn.CloseWithError(0, "")

	serverConn, err := ln.Accept(ctx)
	require.NoError(t, err)
	defer serverConn.CloseWithError(0, "")
	str, err := conn.OpenStreamSync(ctx)
	require.NoError(t, err)

	clientErrChan := make(chan error, 1)
	go func() {
		_, err := str.Write(PRData)
		clientErrChan <- err
		str.Close()
	}()

	serverStr, err := serverConn.AcceptStream(ctx)
	require.NoError(t, err)
	// The server echoes everything it receives. serverErrChan is closed on success.
	serverErrChan := make(chan error, 1)
	go func() {
		defer close(serverErrChan)
		if _, err := io.Copy(serverStr, serverStr); err != nil {
			serverErrChan <- err
			return
		}
		serverStr.Close()
	}()

	// Read the echo while the server is still sending, instead of first waiting for
	// the server to finish. Waiting first would only work if the whole echo fits into
	// the client's flow control window, and it would block without the timeout that
	// the select below applies.
	data, err := io.ReadAll(str)
	require.NoError(t, err)
	require.Equal(t, PRData, data)

	select {
	case err := <-clientErrChan:
		require.NoError(t, err)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	select {
	case err := <-serverErrChan:
		require.NoError(t, err)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}

// TestMITMForgedVersionNegotiationPacket injects a forged Version Negotiation packet,
// which only lists version 42, in response to the first packet the client sends
// towards the server. It expects the dial to fail with a VersionNegotiationError.
func TestMITMForgedVersionNegotiationPacket(t *testing.T) {
	serverTransport, clientTransport := getTransportsForMITMTest(t)
	rtt := scaleDuration(10 * time.Millisecond)

	const supportedVersion protocol.Version = 42

	var once sync.Once
	delayCb := func(dir quicproxy.Direction, _, _ net.Addr, raw []byte) time.Duration {
		if dir != quicproxy.DirectionIncoming {
			return rtt / 2
		}
		once.Do(func() {
			hdr, _, _, err := wire.ParsePacket(raw)
			if err != nil {
				panic("failed to parse packet: " + err.Error())
			}
			// create fake version negotiation packet with a fake supported version
			packet := wire.ComposeVersionNegotiation(
				protocol.ArbitraryLenConnectionID(hdr.SrcConnectionID.Bytes()),
				protocol.ArbitraryLenConnectionID(hdr.DestConnectionID.Bytes()),
				[]protocol.Version{supportedVersion},
			)
			if _, err := serverTransport.WriteTo(packet, clientTransport.Conn.LocalAddr()); err != nil {
				panic("failed to write packet: " + err.Error())
			}
		})
		return rtt / 2
	}

	err := runMITMTestSuccessful(t, serverTransport, clientTransport, delayCb)
	var vnErr *quic.VersionNegotiationError
	require.ErrorAs(t, err, &vnErr)
	require.Contains(t, vnErr.Theirs, supportedVersion) // might contain greased versions
}

// This test times out, because the client doesn't accept subsequent real Retry packets
// from the server, as it has already accepted a (forged) Retry.
// TODO: determine behavior when server does not send Retry packets func TestMITMForgedRetryPacket(t *testing.T) { serverTransport, clientTransport := getTransportsForMITMTest(t) serverTransport.VerifySourceAddress = func(net.Addr) bool { return true } rtt := scaleDuration(10 * time.Millisecond) var once sync.Once delayCb := func(dir quicproxy.Direction, _, _ net.Addr, raw []byte) time.Duration { hdr, _, _, err := wire.ParsePacket(raw) if err != nil { panic("failed to parse packet: " + err.Error()) } if dir == quicproxy.DirectionIncoming && hdr.Type == protocol.PacketTypeInitial { once.Do(func() { fakeSrcConnID := protocol.ParseConnectionID([]byte{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12}) retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version) if _, err := serverTransport.WriteTo(retryPacket, clientTransport.Conn.LocalAddr()); err != nil { panic("failed to write packet: " + err.Error()) } }) } return rtt / 2 } err := runMITMTestSuccessful(t, serverTransport, clientTransport, delayCb) var nerr net.Error require.ErrorAs(t, err, &nerr) require.True(t, nerr.Timeout()) } func TestMITMForgedInitialPacket(t *testing.T) { serverTransport, clientTransport := getTransportsForMITMTest(t) rtt := scaleDuration(10 * time.Millisecond) var once sync.Once delayCb := func(dir quicproxy.Direction, _, _ net.Addr, raw []byte) time.Duration { if dir == quicproxy.DirectionIncoming { hdr, _, _, err := wire.ParsePacket(raw) if err != nil { panic("failed to parse packet: " + err.Error()) } if hdr.Type != protocol.PacketTypeInitial { return 0 } once.Do(func() { initialPacket := testutils.ComposeInitialPacket( hdr.DestConnectionID, hdr.SrcConnectionID, hdr.DestConnectionID, nil, nil, protocol.PerspectiveServer, hdr.Version, ) if _, err := serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr()); err != nil { panic("failed to write packet: " + err.Error()) } }) } return rtt / 2 } err := 
runMITMTestSuccessful(t, serverTransport, clientTransport, delayCb) var nerr net.Error require.ErrorAs(t, err, &nerr) require.True(t, nerr.Timeout()) } func TestMITMForgedInitialPacketWithAck(t *testing.T) { serverTransport, clientTransport := getTransportsForMITMTest(t) rtt := scaleDuration(10 * time.Millisecond) var once sync.Once delayCb := func(dir quicproxy.Direction, _, _ net.Addr, raw []byte) time.Duration { if dir == quicproxy.DirectionIncoming { hdr, _, _, err := wire.ParsePacket(raw) if err != nil { panic("failed to parse packet: " + err.Error()) } if hdr.Type != protocol.PacketTypeInitial { return 0 } once.Do(func() { // Fake Initial with ACK for packet 2 (unsent) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} initialPacket := testutils.ComposeInitialPacket( hdr.DestConnectionID, hdr.SrcConnectionID, hdr.DestConnectionID, nil, []wire.Frame{ack}, protocol.PerspectiveServer, hdr.Version, ) if _, err := serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr()); err != nil { panic("failed to write packet: " + err.Error()) } }) } return rtt / 2 } err := runMITMTestSuccessful(t, serverTransport, clientTransport, delayCb) var transportErr *quic.TransportError require.ErrorAs(t, err, &transportErr) require.Equal(t, quic.ProtocolViolation, transportErr.ErrorCode) require.Contains(t, transportErr.ErrorMessage, "received ACK for an unsent packet") } func runMITMTestSuccessful(t *testing.T, serverTransport, clientTransport *quic.Transport, delayCb quicproxy.DelayCallback) error { t.Helper() ln, err := serverTransport.Listen(getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer ln.Close() proxy := quicproxy.Proxy{ Conn: newUDPConnLocalhost(t), ServerAddr: ln.Addr().(*net.UDPAddr), DelayPacket: delayCb, } require.NoError(t, proxy.Start()) defer proxy.Close() ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond)) defer cancel() _, err = clientTransport.Dial(ctx, 
proxy.LocalAddr(), getTLSClientConfig(), getQuicConfig(nil)) require.Error(t, err) return err } quic-go-0.59.0/integrationtests/self/mtu_test.go000066400000000000000000000122451513066070600217130ustar00rootroot00000000000000package self_test import ( "bytes" "context" "fmt" "io" "net" "sync" "testing" "time" "github.com/quic-go/quic-go" quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/testutils/events" "github.com/stretchr/testify/require" ) func TestInitialPacketSize(t *testing.T) { server := newUDPConnLocalhost(t) client := newUDPConnLocalhost(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() done := make(chan struct{}) go func() { defer close(done) quic.Dial(ctx, client, server.LocalAddr(), getTLSClientConfig(), getQuicConfig(&quic.Config{ InitialPacketSize: 1337, })) }() buf := make([]byte, 2000) n, _, err := server.ReadFrom(buf) require.NoError(t, err) require.Equal(t, 1337, n) cancel() <-done } func TestPathMTUDiscovery(t *testing.T) { rtt := scaleDuration(5 * time.Millisecond) const mtu = 1400 ln, err := quic.Listen( newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(&quic.Config{ InitialPacketSize: 1234, DisablePathMTUDiscovery: true, EnableDatagrams: true, }), ) require.NoError(t, err) defer ln.Close() serverErrChan := make(chan error, 1) go func() { conn, err := ln.Accept(context.Background()) if err != nil { serverErrChan <- err return } str, err := conn.AcceptStream(context.Background()) if err != nil { serverErrChan <- err return } defer str.Close() if _, err := io.Copy(str, str); err != nil { serverErrChan <- err return } }() var mx sync.Mutex var maxPacketSizeServer int var clientPacketSizes []int proxy := &quicproxy.Proxy{ Conn: newUDPConnLocalhost(t), ServerAddr: ln.Addr().(*net.UDPAddr), DelayPacket: func(quicproxy.Direction, net.Addr, net.Addr, []byte) time.Duration { return rtt / 2 }, 
DropPacket: func(dir quicproxy.Direction, _, _ net.Addr, packet []byte) bool { if len(packet) > mtu { return true } mx.Lock() defer mx.Unlock() switch dir { case quicproxy.DirectionIncoming: clientPacketSizes = append(clientPacketSizes, len(packet)) case quicproxy.DirectionOutgoing: if len(packet) > maxPacketSizeServer { maxPacketSizeServer = len(packet) } } return false }, } require.NoError(t, proxy.Start()) defer proxy.Close() // Make sure to use v4-only socket here. // We can't reliably set the DF bit on dual-stack sockets on older versions of macOS (before Sequoia). tr := &quic.Transport{Conn: newUDPConnLocalhost(t)} defer tr.Close() var eventRecorder events.Recorder conn, err := tr.Dial( context.Background(), proxy.LocalAddr(), getTLSClientConfig(), getQuicConfig(&quic.Config{ InitialPacketSize: protocol.MinInitialPacketSize, EnableDatagrams: true, Tracer: newTracer(&eventRecorder), }), ) require.NoError(t, err) defer conn.CloseWithError(0, "") err = conn.SendDatagram(make([]byte, 2000)) require.Error(t, err) var datagramErr *quic.DatagramTooLargeError require.ErrorAs(t, err, &datagramErr) initialMaxDatagramSize := datagramErr.MaxDatagramPayloadSize str, err := conn.OpenStream() require.NoError(t, err) clientErrChan := make(chan error, 1) go func() { data, err := io.ReadAll(str) if err != nil { clientErrChan <- err return } if !bytes.Equal(data, PRDataLong) { clientErrChan <- fmt.Errorf("echoed data doesn't match: %x", data) return } clientErrChan <- nil }() _, err = str.Write(PRDataLong) require.NoError(t, err) str.Close() select { case err := <-clientErrChan: require.NoError(t, err) case err := <-serverErrChan: t.Fatalf("server error: %v", err) case <-time.After(20 * time.Second): t.Fatal("timeout") } err = conn.SendDatagram(make([]byte, 2000)) require.Error(t, err) require.ErrorAs(t, err, &datagramErr) finalMaxDatagramSize := datagramErr.MaxDatagramPayloadSize mx.Lock() defer mx.Unlock() require.NotEmpty(t, eventRecorder.Events(qlog.MTUUpdated{})) var mtus 
[]int for _, ev := range eventRecorder.Events(qlog.MTUUpdated{}) { mtus = append(mtus, ev.(qlog.MTUUpdated).Value) } maxPacketSizeClient := mtus[len(mtus)-1] t.Logf("max client packet size: %d, MTU: %d", maxPacketSizeClient, mtu) t.Logf("max datagram size: initial: %d, final: %d", initialMaxDatagramSize, finalMaxDatagramSize) t.Logf("max server packet size: %d, MTU: %d", maxPacketSizeServer, mtu) require.GreaterOrEqual(t, maxPacketSizeClient, mtu-25) const maxDiff = 40 // this includes the 21 bytes for the short header, 16 bytes for the encryption tag, and framing overhead require.GreaterOrEqual(t, int(initialMaxDatagramSize), protocol.MinInitialPacketSize-maxDiff) require.GreaterOrEqual(t, int(finalMaxDatagramSize), maxPacketSizeClient-maxDiff) // MTU discovery was disabled on the server side require.Equal(t, 1234, maxPacketSizeServer) var numPacketsLargerThanDiscoveredMTU int for _, s := range clientPacketSizes { if s > maxPacketSizeClient { numPacketsLargerThanDiscoveredMTU++ } } // The client shouldn't have sent any packets larger than the MTU it discovered, // except for at most one MTU probe packet. 
require.LessOrEqual(t, numPacketsLargerThanDiscoveredMTU, 1) } quic-go-0.59.0/integrationtests/self/multiplex_test.go000066400000000000000000000221411513066070600231250ustar00rootroot00000000000000package self_test import ( "bytes" "context" "crypto/rand" "errors" "fmt" "io" mrand "math/rand/v2" "net" "runtime" "testing" "time" "github.com/quic-go/quic-go" "github.com/stretchr/testify/require" ) func runMultiplexTestServer(t *testing.T, ln *quic.Listener) { t.Helper() for { conn, err := ln.Accept(context.Background()) if err != nil { return } str, err := conn.OpenUniStream() require.NoError(t, err) go func() { defer str.Close() _, err = str.Write(PRData) require.NoError(t, err) }() t.Cleanup(func() { conn.CloseWithError(0, "") }) } } func dialAndReceiveData(tr *quic.Transport, addr net.Addr) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := tr.Dial(ctx, addr, getTLSClientConfig(), getQuicConfig(nil)) if err != nil { return fmt.Errorf("error dialing: %w", err) } str, err := conn.AcceptUniStream(ctx) if err != nil { return fmt.Errorf("error accepting stream: %w", err) } data, err := io.ReadAll(str) if err != nil { return fmt.Errorf("error reading data: %w", err) } if !bytes.Equal(data, PRData) { return fmt.Errorf("data mismatch: got %q, expected %q", data, PRData) } return nil } func TestMultiplexesConnectionsToSameServer(t *testing.T) { server, err := quic.Listen(newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer server.Close() go runMultiplexTestServer(t, server) tr := &quic.Transport{Conn: newUDPConnLocalhost(t)} addTracer(tr) defer tr.Close() errChan1 := make(chan error, 1) go func() { errChan1 <- dialAndReceiveData(tr, server.Addr()) }() errChan2 := make(chan error, 1) go func() { errChan2 <- dialAndReceiveData(tr, server.Addr()) }() select { case err := <-errChan1: require.NoError(t, err, "error dialing server 1") case <-time.After(5 * time.Second): t.Error("timeout 
waiting for done1 to close") } select { case err := <-errChan2: require.NoError(t, err) case <-time.After(5 * time.Second): t.Error("timeout waiting for done2 to close") } } func TestMultiplexingToDifferentServers(t *testing.T) { server1, err := quic.Listen(newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer server1.Close() go runMultiplexTestServer(t, server1) server2, err := quic.Listen(newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer server2.Close() go runMultiplexTestServer(t, server2) tr := &quic.Transport{Conn: newUDPConnLocalhost(t)} addTracer(tr) defer tr.Close() errChan1 := make(chan error, 1) go func() { errChan1 <- dialAndReceiveData(tr, server1.Addr()) }() errChan2 := make(chan error, 1) go func() { errChan2 <- dialAndReceiveData(tr, server2.Addr()) }() select { case err := <-errChan1: require.NoError(t, err, "error dialing server 1") case <-time.After(5 * time.Second): t.Error("timeout waiting for done1 to close") } select { case err := <-errChan2: require.NoError(t, err, "error dialing server 2") case <-time.After(5 * time.Second): t.Error("timeout waiting for done2 to close") } } func TestMultiplexingConnectToSelf(t *testing.T) { tr := &quic.Transport{Conn: newUDPConnLocalhost(t)} addTracer(tr) defer tr.Close() server, err := tr.Listen(getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer server.Close() go runMultiplexTestServer(t, server) errChan := make(chan error, 1) go func() { errChan <- dialAndReceiveData(tr, server.Addr()) }() select { case err := <-errChan: require.NoError(t, err, "error dialing server") case <-time.After(5 * time.Second): t.Error("timeout waiting for connection to close") } } func TestMultiplexingServerAndClientOnSameConn(t *testing.T) { if runtime.GOOS == "linux" { t.Skip("This test requires setting of iptables rules on Linux, see https://stackoverflow.com/questions/23859164/linux-udp-socket-sendto-operation-not-permitted.") } tr1 := 
&quic.Transport{Conn: newUDPConnLocalhost(t)} addTracer(tr1) defer tr1.Close() server1, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer server1.Close() tr2 := &quic.Transport{Conn: newUDPConnLocalhost(t)} addTracer(tr2) defer tr2.Close() server2, err := tr2.Listen(getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer server2.Close() done1 := make(chan struct{}) go func() { defer close(done1) dialAndReceiveData(tr2, server1.Addr()) }() done2 := make(chan struct{}) go func() { defer close(done2) dialAndReceiveData(tr1, server2.Addr()) }() select { case <-done1: case <-time.After(5 * time.Second): t.Error("timeout waiting for done1 to close") } select { case <-done2: case <-time.After(time.Second): t.Error("timeout waiting for done2 to close") } } func TestMultiplexingNonQUICPackets(t *testing.T) { const numPackets = 100 tr1 := &quic.Transport{Conn: newUDPConnLocalhost(t)} defer tr1.Close() addTracer(tr1) server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer server.Close() tr2 := &quic.Transport{Conn: newUDPConnLocalhost(t)} defer tr2.Close() addTracer(tr2) type nonQUICPacket struct { b []byte addr net.Addr err error } rcvdPackets := make(chan nonQUICPacket, numPackets) receiveCtx, receiveCancel := context.WithCancel(context.Background()) defer receiveCancel() // start receiving non-QUIC packets go func() { for { b := make([]byte, 1024) n, addr, err := tr2.ReadNonQUICPacket(receiveCtx, b) if errors.Is(err, context.Canceled) { return } rcvdPackets <- nonQUICPacket{b: b[:n], addr: addr, err: err} } }() ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second) defer cancel2() conn, err := tr2.Dial(ctx2, server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) defer conn.CloseWithError(0, "") serverConn, err := server.Accept(ctx2) require.NoError(t, err) serverStr, err := serverConn.OpenUniStream() require.NoError(t, err) // send a non-QUIC packet 
every 100µs const packetLen = 128 errChanNonQUIC := make(chan error, 1) sendNonQUICPacket := make(chan struct{}, 1) go func() { var seed [32]byte rand.Read(seed[:]) random := mrand.NewChaCha8(seed) defer close(errChanNonQUIC) var sentPackets int for range sendNonQUICPacket { b := make([]byte, packetLen) random.Read(b[1:]) // keep the first byte set to 0, so it's not classified as a QUIC packet _, err := tr1.WriteTo(b, tr2.Conn.LocalAddr()) // The first sendmsg call on a new UDP socket sometimes errors on Linux. // It's not clear why this happens. // See https://github.com/golang/go/issues/63322. if err != nil && sentPackets == 0 && runtime.GOOS == "linux" && isPermissionError(err) { _, err = tr1.WriteTo(b, tr2.Conn.LocalAddr()) } if err != nil { errChanNonQUIC <- err return } sentPackets++ } }() sendQUICPacket := make(chan struct{}, 1) errChanQUIC := make(chan error, 1) var dataSent []byte go func() { defer close(errChanQUIC) defer serverStr.Close() var seed [32]byte rand.Read(seed[:]) random := mrand.NewChaCha8(seed) for range sendQUICPacket { b := make([]byte, 1024) random.Read(b) if _, err := serverStr.Write(b); err != nil { errChanQUIC <- err return } dataSent = append(dataSent, b...) 
} }() dataChan := make(chan []byte, 1) readErr := make(chan error, 1) go func() { str, err := conn.AcceptUniStream(ctx2) if err != nil { readErr <- err return } data, err := io.ReadAll(str) if err != nil { readErr <- err return } dataChan <- data }() ticker := time.NewTicker(scaleDuration(200 * time.Microsecond)) defer ticker.Stop() for range numPackets { sendNonQUICPacket <- struct{}{} sendQUICPacket <- struct{}{} <-ticker.C } close(sendNonQUICPacket) close(sendQUICPacket) select { case err := <-errChanNonQUIC: require.NoError(t, err, "error sending non-QUIC packets") case <-time.After(time.Second): t.Fatalf("timeout waiting for non-QUIC packets to be sent") } select { case err := <-errChanQUIC: require.NoError(t, err, "error sending QUIC packets") case <-time.After(time.Second): t.Fatalf("timeout waiting for QUIC packets to be sent") } select { case err := <-readErr: require.NoError(t, err, "error reading stream data") case dataRcvd := <-dataChan: require.Equal(t, dataSent, dataRcvd, "stream data mismatch") case <-time.After(time.Second): t.Fatalf("timeout waiting for stream data to be read") } // make sure we don't overflow the capacity of the channel require.LessOrEqual(t, numPackets, cap(rcvdPackets), "too many non-QUIC packets sent: %d > %d", numPackets, cap(rcvdPackets)) // now receive these packets minExpected := numPackets * 4 / 5 timeout := time.After(time.Second) var counter int for counter < minExpected { select { case p := <-rcvdPackets: require.Equal(t, tr1.Conn.LocalAddr(), p.addr, "non-QUIC packet received from wrong address") require.Equal(t, packetLen, len(p.b), "non-QUIC packet incorrect length") require.NoError(t, p.err, "error receiving non-QUIC packet") counter++ case <-timeout: t.Fatalf("didn't receive enough non-QUIC packets: %d < %d", counter, minExpected) } } } quic-go-0.59.0/integrationtests/self/nat_rebinding_test.go000066400000000000000000000061731513066070600237140ustar00rootroot00000000000000package self_test import ( "context" "fmt" 
"io" "net" "os" "sync" "testing" "time" "github.com/quic-go/quic-go" quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/stretchr/testify/require" ) func TestNATRebinding(t *testing.T) { tr, tracer := newPacketTracer() tlsConf := getTLSConfig() f, err := os.Create("keylog.txt") require.NoError(t, err) defer f.Close() tlsConf.KeyLogWriter = f server, err := quic.Listen( newUDPConnLocalhost(t), tlsConf, getQuicConfig(&quic.Config{ Tracer: func(ctx context.Context, isClient bool, connID quic.ConnectionID) qlogwriter.Trace { return tracer }, }), ) require.NoError(t, err) defer server.Close() newPath := newUDPConnLocalhost(t) clientUDPConn := newUDPConnLocalhost(t) oldPathRTT := scaleDuration(10 * time.Millisecond) newPathRTT := scaleDuration(20 * time.Millisecond) proxy := quicproxy.Proxy{ ServerAddr: server.Addr().(*net.UDPAddr), Conn: newUDPConnLocalhost(t), } var mx sync.Mutex var switchedPath bool var dataTransferred int proxy.DelayPacket = func(dir quicproxy.Direction, _, _ net.Addr, b []byte) time.Duration { mx.Lock() defer mx.Unlock() if dir == quicproxy.DirectionOutgoing { dataTransferred += len(b) if dataTransferred > len(PRData)/3 { if !switchedPath { if err := proxy.SwitchConn(clientUDPConn.LocalAddr().(*net.UDPAddr), newPath); err != nil { panic(fmt.Sprintf("failed to switch connection: %s", err)) } switchedPath = true } } } if switchedPath { return newPathRTT } return oldPathRTT } require.NoError(t, proxy.Start()) defer proxy.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.Dial(ctx, clientUDPConn, proxy.LocalAddr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) serverConn, err := server.Accept(ctx) require.NoError(t, err) defer serverConn.CloseWithError(0, "") go func() { str, err := serverConn.OpenUniStream() require.NoError(t, err) go func() { defer str.Close() _, err = 
str.Write(PRData) require.NoError(t, err) }() }() str, err := conn.AcceptUniStream(ctx) require.NoError(t, err) str.SetReadDeadline(time.Now().Add(5 * time.Second)) data, err := io.ReadAll(str) require.NoError(t, err) require.Equal(t, PRData, data) conn.CloseWithError(0, "") // check that a PATH_CHALLENGE was sent var pathChallenge [8]byte var foundPathChallenge bool for _, p := range tr.getSentShortHeaderPackets() { for _, f := range p.frames { switch fr := f.Frame.(type) { case *qlog.PathChallengeFrame: pathChallenge = fr.Data foundPathChallenge = true } } } require.True(t, foundPathChallenge) // check that a PATH_RESPONSE with the correct data was received var foundPathResponse bool for _, p := range tr.getRcvdShortHeaderPackets() { for _, f := range p.frames { switch fr := f.Frame.(type) { case *qlog.PathResponseFrame: require.Equal(t, pathChallenge, fr.Data) foundPathResponse = true } } } require.True(t, foundPathResponse) } quic-go-0.59.0/integrationtests/self/packetization_test.go000066400000000000000000000166401513066070600237560ustar00rootroot00000000000000package self_test import ( "context" "fmt" "io" "net" "os" "testing" "time" "github.com/quic-go/quic-go" quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/quic-go/quicvarint" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestACKBundling(t *testing.T) { const numMsg = 100 serverCounter, serverTracer := newPacketTracer() server, err := quic.Listen( newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(&quic.Config{ DisablePathMTUDiscovery: true, Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return serverTracer }, }), ) require.NoError(t, err) defer server.Close() proxy := quicproxy.Proxy{ Conn: newUDPConnLocalhost(t), ServerAddr: server.Addr().(*net.UDPAddr), DelayPacket: 
func(quicproxy.Direction, net.Addr, net.Addr, []byte) time.Duration { return 5 * time.Millisecond }, } require.NoError(t, proxy.Start()) defer proxy.Close() clientCounter, clientTracer := newPacketTracer() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.Dial( ctx, newUDPConnLocalhost(t), proxy.LocalAddr(), getTLSClientConfig(), getQuicConfig(&quic.Config{ DisablePathMTUDiscovery: true, Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return clientTracer }, }), ) require.NoError(t, err) defer conn.CloseWithError(0, "") serverErrChan := make(chan error, 1) go func() { defer close(serverErrChan) conn, err := server.Accept(context.Background()) if err != nil { serverErrChan <- fmt.Errorf("accept failed: %w", err) return } str, err := conn.AcceptStream(context.Background()) if err != nil { serverErrChan <- fmt.Errorf("accept stream failed: %w", err) return } b := make([]byte, 1) // Echo every byte received from the client. for { if _, err := str.Read(b); err != nil { break } _, err = str.Write(b) if err != nil { serverErrChan <- fmt.Errorf("write failed: %w", err) return } } }() str, err := conn.OpenStreamSync(context.Background()) require.NoError(t, err) b := make([]byte, 1) // Send numMsg 1-byte messages. 
for i := 0; i < numMsg; i++ { _, err = str.Write([]byte{uint8(i)}) require.NoError(t, err) _, err = str.Read(b) require.NoError(t, err) require.Equal(t, uint8(i), b[0]) } require.NoError(t, conn.CloseWithError(0, "")) require.NoError(t, <-serverErrChan) countBundledPackets := func(packets []packet) (numBundled int) { for _, p := range packets { var hasAck, hasStreamFrame bool for _, f := range p.frames { switch f.Frame.(type) { case *qlog.AckFrame: hasAck = true case *qlog.StreamFrame: hasStreamFrame = true } } if hasAck && hasStreamFrame { numBundled++ } } return } numBundledIncoming := countBundledPackets(clientCounter.getRcvdShortHeaderPackets()) numBundledOutgoing := countBundledPackets(serverCounter.getRcvdShortHeaderPackets()) t.Logf("bundled incoming packets: %d / %d", numBundledIncoming, numMsg) t.Logf("bundled outgoing packets: %d / %d", numBundledOutgoing, numMsg) require.LessOrEqual(t, numBundledIncoming, numMsg) require.Greater(t, numBundledIncoming, numMsg*9/10) require.LessOrEqual(t, numBundledOutgoing, numMsg) require.Greater(t, numBundledOutgoing, numMsg*9/10) } func TestStreamDataBlocked(t *testing.T) { testConnAndStreamDataBlocked(t, true, false) } func TestConnDataBlocked(t *testing.T) { testConnAndStreamDataBlocked(t, false, true) } func testConnAndStreamDataBlocked(t *testing.T, limitStream, limitConn bool) { const window = 100 const numBatches = 3 initialStreamWindow := uint64(quicvarint.Max) initialConnWindow := uint64(quicvarint.Max) if limitStream { initialStreamWindow = window } if limitConn { initialConnWindow = window } rtt := scaleDuration(5 * time.Millisecond) ln, err := quic.Listen( newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(&quic.Config{ InitialStreamReceiveWindow: initialStreamWindow, InitialConnectionReceiveWindow: initialConnWindow, }), ) require.NoError(t, err) defer ln.Close() proxy := quicproxy.Proxy{ Conn: newUDPConnLocalhost(t), ServerAddr: ln.Addr().(*net.UDPAddr), DelayPacket: func(quicproxy.Direction, net.Addr, 
net.Addr, []byte) time.Duration { return rtt / 2 }, } require.NoError(t, proxy.Start()) defer proxy.Close() counter, tracer := newPacketTracer() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.Dial( ctx, newUDPConnLocalhost(t), proxy.LocalAddr(), getTLSClientConfig(), getQuicConfig(&quic.Config{ Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return tracer }, }), ) require.NoError(t, err) serverConn, err := ln.Accept(ctx) require.NoError(t, err) str, err := conn.OpenUniStreamSync(ctx) require.NoError(t, err) // Stream data is consumed (almost) immediately, so flow-control window auto-tuning kicks in. // The window size is doubled for every batch. var windowSizes []protocol.ByteCount for i := range numBatches { windowSizes = append(windowSizes, window< highestSeen { highestSeen = pn } } t.Logf("Smoothed RTT: %s", conn.ConnectionStats().SmoothedRTT) assert.GreaterOrEqual(t, conn.ConnectionStats().SmoothedRTT, rtt*9/10) assert.LessOrEqual(t, conn.ConnectionStats().SmoothedRTT, rtt*11/10) t.Logf("received %d short header packets, detected %d reorderings", len(packetNumbers), reorderings) assert.GreaterOrEqual(t, reorderings, 20, "expected at least 20 reorderings") }) }) } } quic-go-0.59.0/integrationtests/self/self_go124_test.go000066400000000000000000000002051513066070600227440ustar00rootroot00000000000000//go:build !go1.25 package self_test import "crypto/tls" func getCurveID(connState tls.ConnectionState) tls.CurveID { return 0 } quic-go-0.59.0/integrationtests/self/self_go125_test.go000066400000000000000000000002241513066070600227460ustar00rootroot00000000000000//go:build go1.25 package self_test import "crypto/tls" func getCurveID(connState tls.ConnectionState) tls.CurveID { return connState.CurveID } quic-go-0.59.0/integrationtests/self/self_suite_linux_test.go000066400000000000000000000006521513066070600244660ustar00rootroot00000000000000//go:build linux package self_test import ( 
"errors" "os" "golang.org/x/sys/unix" ) // The first sendmsg call on a new UDP socket sometimes errors on Linux. // It's not clear why this happens. // See https://github.com/golang/go/issues/63322. func isPermissionError(err error) bool { var serr *os.SyscallError if errors.As(err, &serr) { return serr.Syscall == "sendmsg" && serr.Err == unix.EPERM } return false } quic-go-0.59.0/integrationtests/self/self_suite_others_test.go000066400000000000000000000001371513066070600246310ustar00rootroot00000000000000//go:build !linux package self_test func isPermissionError(err error) bool { return false } quic-go-0.59.0/integrationtests/self/self_test.go000066400000000000000000000213671513066070600220440ustar00rootroot00000000000000package self_test import ( "context" "crypto/tls" "crypto/x509" "flag" "fmt" "io" "math/rand/v2" "net" "os" "runtime" "strconv" "testing" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" "github.com/quic-go/quic-go/integrationtests/tools" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/quic-go/testutils/events" "github.com/stretchr/testify/require" ) const alpn = tools.ALPN const ( dataLen = 500 * 1024 // 500 KB dataLenLong = 50 * 1024 * 1024 // 50 MB ) var ( // PRData contains dataLen bytes of pseudo-random data. PRData = GeneratePRData(dataLen) // PRDataLong contains dataLenLong bytes of pseudo-random data. 
PRDataLong = GeneratePRData(dataLenLong) ) // See https://en.wikipedia.org/wiki/Lehmer_random_number_generator func GeneratePRData(l int) []byte { res := make([]byte, l) seed := uint64(1) for i := 0; i < l; i++ { seed = seed * 48271 % 2147483647 res[i] = byte(seed) } return res } var ( version quic.Version enableQlog bool tlsConfig *tls.Config tlsConfigLongChain *tls.Config tlsClientConfig *tls.Config tlsClientConfigWithoutServerName *tls.Config ) func init() { ca, caPrivateKey, err := tools.GenerateCA() if err != nil { panic(err) } leafCert, leafPrivateKey, err := tools.GenerateLeafCert(ca, caPrivateKey) if err != nil { panic(err) } tlsConfig = &tls.Config{ Certificates: []tls.Certificate{{ Certificate: [][]byte{leafCert.Raw}, PrivateKey: leafPrivateKey, }}, NextProtos: []string{alpn}, } tlsConfLongChain, err := tools.GenerateTLSConfigWithLongCertChain(ca, caPrivateKey) if err != nil { panic(err) } tlsConfigLongChain = tlsConfLongChain root := x509.NewCertPool() root.AddCert(ca) tlsClientConfig = &tls.Config{ ServerName: "localhost", RootCAs: root, NextProtos: []string{alpn}, } tlsClientConfigWithoutServerName = &tls.Config{ RootCAs: root, NextProtos: []string{alpn}, } } func getTLSConfig() *tls.Config { return tlsConfig.Clone() } func getTLSConfigWithLongCertChain() *tls.Config { return tlsConfigLongChain.Clone() } func getTLSClientConfig() *tls.Config { return tlsClientConfig.Clone() } func getTLSClientConfigWithoutServerName() *tls.Config { return tlsClientConfigWithoutServerName.Clone() } type multiplexedRecorder struct { Recorders []qlogwriter.Recorder } var _ qlogwriter.Recorder = &multiplexedRecorder{} func (r *multiplexedRecorder) Close() error { for _, recorder := range r.Recorders { recorder.Close() } return nil } func (r *multiplexedRecorder) RecordEvent(ev qlogwriter.Event) { for _, recorder := range r.Recorders { recorder.RecordEvent(ev) } } type multiplexedTrace struct { Traces []qlogwriter.Trace } var _ qlogwriter.Trace = &multiplexedTrace{} func (t 
*multiplexedTrace) AddProducer() qlogwriter.Recorder { recorders := make([]qlogwriter.Recorder, 0, len(t.Traces)) for _, tr := range t.Traces { recorders = append(recorders, tr.AddProducer()) } return &multiplexedRecorder{Recorders: recorders} } func (t *multiplexedTrace) SupportsSchemas(schema string) bool { return true } func getQuicConfig(conf *quic.Config) *quic.Config { if conf == nil { conf = &quic.Config{} } else { conf = conf.Clone() } if !enableQlog { return conf } if conf.Tracer == nil { conf.Tracer = func(ctx context.Context, isClient bool, connID quic.ConnectionID) qlogwriter.Trace { return tools.NewQlogConnectionTracer(os.Stdout)(ctx, isClient, connID) } return conf } origTracer := conf.Tracer conf.Tracer = func(ctx context.Context, isClient bool, connID quic.ConnectionID) qlogwriter.Trace { tr := origTracer(ctx, isClient, connID) qlogger := tools.NewQlogConnectionTracer(os.Stdout)(ctx, isClient, connID) if tr == nil { return qlogger } return &multiplexedTrace{Traces: []qlogwriter.Trace{tr, qlogger}} } return conf } func addTracer(tr *quic.Transport) { if !enableQlog { return } if tr.Tracer == nil { tr.Tracer = tools.QlogTracer(os.Stdout).AddProducer() return } origTracer := tr.Tracer tr.Tracer = &multiplexedRecorder{ Recorders: []qlogwriter.Recorder{origTracer, tools.QlogTracer(os.Stdout).AddProducer()}, } } func newUDPConnLocalhost(t testing.TB) *net.UDPConn { t.Helper() conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) require.NoError(t, err) t.Cleanup(func() { conn.Close() }) return conn } func TestMain(m *testing.M) { var versionParam string flag.StringVar(&versionParam, "version", "1", "QUIC version") flag.BoolVar(&enableQlog, "qlog", false, "enable qlog") flag.Parse() switch versionParam { case "1": version = quic.Version1 case "2": version = quic.Version2 default: fmt.Printf("unknown QUIC version: %s\n", versionParam) os.Exit(1) } fmt.Printf("using QUIC version: %s\n", version) os.Exit(m.Run()) } func 
scaleDuration(d time.Duration) time.Duration { scaleFactor := 1 if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set scaleFactor = f } if scaleFactor == 0 { panic("TIMESCALE_FACTOR is 0") } return time.Duration(scaleFactor) * d } func newTracer(tracer qlogwriter.Recorder) func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return &events.Trace{Recorder: tracer} } } type packet struct { time time.Time hdr qlog.PacketHeader frames []qlog.Frame } type packetCounter struct { recorder *events.Recorder } func (t *packetCounter) getSentShortHeaderPackets() []packet { var sentShortHdr []packet for _, ev := range t.recorder.EventsWithTime(qlog.PacketSent{}) { e := ev.Event.(qlog.PacketSent) if e.Header.PacketType != qlog.PacketType1RTT { continue } sentShortHdr = append(sentShortHdr, packet{time: ev.Time, hdr: e.Header, frames: e.Frames}) } return sentShortHdr } func (t *packetCounter) getRcvdLongHeaderPackets() []packet { var rcvdLongHdr []packet for _, ev := range t.recorder.EventsWithTime(qlog.PacketReceived{}) { e := ev.Event.(qlog.PacketReceived) if e.Header.PacketType == qlog.PacketType1RTT { continue } rcvdLongHdr = append(rcvdLongHdr, packet{time: ev.Time, hdr: e.Header, frames: e.Frames}) } return rcvdLongHdr } func (t *packetCounter) getRcvd0RTTPacketNumbers() []protocol.PacketNumber { var zeroRTTPackets []protocol.PacketNumber for _, p := range t.getRcvdLongHeaderPackets() { if p.hdr.PacketType == qlog.PacketType0RTT { zeroRTTPackets = append(zeroRTTPackets, p.hdr.PacketNumber) } } return zeroRTTPackets } func (t *packetCounter) getRcvdShortHeaderPackets() []packet { var rcvdShortHdr []packet for _, ev := range t.recorder.EventsWithTime(qlog.PacketReceived{}) { e := ev.Event.(qlog.PacketReceived) if e.Header.PacketType != qlog.PacketType1RTT { continue } rcvdShortHdr = append(rcvdShortHdr, 
packet{time: ev.Time, hdr: e.Header, frames: e.Frames}) } return rcvdShortHdr } func newPacketTracer() (*packetCounter, qlogwriter.Trace) { c := &packetCounter{recorder: &events.Recorder{}} return c, &events.Trace{Recorder: c.recorder} } type readerWithTimeout struct { io.Reader Timeout time.Duration } func (r *readerWithTimeout) Read(p []byte) (n int, err error) { done := make(chan struct{}) go func() { defer close(done) n, err = r.Reader.Read(p) }() select { case <-done: return n, err case <-time.After(r.Timeout): return 0, fmt.Errorf("read timeout after %s", r.Timeout) } } func randomDuration(min, max time.Duration) time.Duration { return min + time.Duration(rand.IntN(int(max-min))) } // containsPacketType checks if a packet contains a long header packet of the specified type. // It correctly handles coalesced packets. func containsPacketType(data []byte, packetType protocol.PacketType) bool { for len(data) > 0 { if !wire.IsLongHeaderPacket(data[0]) { return false } hdr, _, rest, err := wire.ParsePacket(data) if err != nil { return false } if hdr.Type == packetType { return true } data = rest } return false } // addDialCallback explicitly adds the http3.Transport's Dial callback. // This is needed since dialing on dual-stack sockets is flaky on macOS, // see https://github.com/golang/go/issues/67226. 
func addDialCallback(t *testing.T, tr *http3.Transport) { t.Helper() if runtime.GOOS != "darwin" { return } require.Nil(t, tr.Dial) tr.Dial = func(ctx context.Context, addr string, tlsConf *tls.Config, conf *quic.Config) (*quic.Conn, error) { a, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err } return quic.DialEarly(ctx, newUDPConnLocalhost(t), a, tlsConf, conf) } } quic-go-0.59.0/integrationtests/self/simnet_helper_test.go000066400000000000000000000043671513066070600237520ustar00rootroot00000000000000package self_test import ( "net" "testing" "time" "github.com/quic-go/quic-go/testutils/simnet" "github.com/stretchr/testify/require" ) func newSimnetLink(t *testing.T, rtt time.Duration) (client, server *simnet.SimConn, close func(t *testing.T)) { t.Helper() return newSimnetLinkWithRouter(t, rtt, &simnet.PerfectRouter{}) } func newSimnetLinkWithRouter(t *testing.T, rtt time.Duration, router simnet.Router) (client, server *simnet.SimConn, close func(t *testing.T)) { t.Helper() n := &simnet.Simnet{Router: router} settings := simnet.NodeBiDiLinkSettings{Latency: rtt / 2} clientPacketConn := n.NewEndpoint(&net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001}, settings) serverPacketConn := n.NewEndpoint(&net.UDPAddr{IP: net.ParseIP("1.0.0.2"), Port: 9002}, settings) require.NoError(t, n.Start()) return clientPacketConn, serverPacketConn, func(t *testing.T) { require.NoError(t, clientPacketConn.Close()) require.NoError(t, serverPacketConn.Close()) require.NoError(t, n.Close()) } } type droppingRouter struct { simnet.PerfectRouter Drop func(simnet.Packet) bool } func (d *droppingRouter) SendPacket(p simnet.Packet) error { if d.Drop(p) { return nil } return d.PerfectRouter.SendPacket(p) } type callbackRouter struct { simnet.Router OnSendPacket func(simnet.Packet) } func (c *callbackRouter) SendPacket(p simnet.Packet) error { c.OnSendPacket(p) return c.Router.SendPacket(p) } type direction uint8 const ( directionUnknown = iota directionToClient 
directionToServer directionBoth ) func (d direction) String() string { switch d { case directionToClient: return "to client" case directionToServer: return "to server" case directionBoth: return "both" } return "unknown" } var _ simnet.Router = &droppingRouter{} type directionAwareDroppingRouter struct { simnet.PerfectRouter ClientAddr, ServerAddr *net.UDPAddr Drop func(direction direction, p simnet.Packet) bool } func (d *directionAwareDroppingRouter) SendPacket(p simnet.Packet) error { var dir direction switch p.To.String() { case d.ClientAddr.String(): dir = directionToClient case d.ServerAddr.String(): dir = directionToServer default: dir = directionUnknown } if d.Drop(dir, p) { return nil } return d.PerfectRouter.SendPacket(p) } quic-go-0.59.0/integrationtests/self/stateless_reset_test.go000066400000000000000000000063651513066070600243250ustar00rootroot00000000000000package self_test import ( "context" "crypto/rand" "sync/atomic" "testing" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/testutils/simnet" "github.com/stretchr/testify/require" ) func TestStatelessResets(t *testing.T) { t.Run("zero-length connection IDs", func(t *testing.T) { testStatelessReset(t, 0) }) t.Run("10 byte connection IDs", func(t *testing.T) { testStatelessReset(t, 10) }) } func testStatelessReset(t *testing.T, connIDLen int) { synctest.Test(t, func(t *testing.T) { var drop atomic.Bool clientPacketConn, serverPacketConn, closeFn := newSimnetLinkWithRouter(t, time.Millisecond, &droppingRouter{Drop: func(p simnet.Packet) bool { return drop.Load() }}, ) defer closeFn(t) var statelessResetKey quic.StatelessResetKey rand.Read(statelessResetKey[:]) tr := &quic.Transport{ Conn: serverPacketConn, StatelessResetKey: &statelessResetKey, } defer tr.Close() ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) serverErr := make(chan error, 1) go func() { conn, err := ln.Accept(context.Background()) if err != 
nil { serverErr <- err return } str, err := conn.OpenStream() if err != nil { serverErr <- err return } _, err = str.Write([]byte("foobar")) if err != nil { serverErr <- err return } close(serverErr) }() var conn *quic.Conn if connIDLen > 0 { cl := &quic.Transport{ Conn: clientPacketConn, ConnectionIDLength: connIDLen, } defer cl.Close() var err error conn, err = cl.Dial( context.Background(), serverPacketConn.LocalAddr(), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}), ) require.NoError(t, err) } else { conn, err = quic.Dial( context.Background(), clientPacketConn, serverPacketConn.LocalAddr(), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}), ) require.NoError(t, err) } str, err := conn.AcceptStream(context.Background()) require.NoError(t, err) data := make([]byte, 6) _, err = str.Read(data) require.NoError(t, err) require.Equal(t, []byte("foobar"), data) // make sure that the CONNECTION_CLOSE is dropped drop.Store(true) require.NoError(t, ln.Close()) require.NoError(t, tr.Close()) require.NoError(t, <-serverErr) time.Sleep(100 * time.Millisecond) // We need to create a new Transport here, since the old one is still sending out // CONNECTION_CLOSE packets for (recently) closed connections). tr2 := &quic.Transport{ Conn: serverPacketConn, StatelessResetKey: &statelessResetKey, } defer tr2.Close() ln2, err := tr2.Listen(getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) drop.Store(false) // Trigger something (not too small) to be sent, so that we receive the stateless reset. // If the client already sent another packet, it might already have received a packet. 
_, serr := str.Write([]byte("Lorem ipsum dolor sit amet.")) if serr == nil { _, serr = str.Read([]byte{0}) } require.Error(t, serr) require.IsType(t, &quic.StatelessResetError{}, serr) require.NoError(t, ln2.Close()) }) } quic-go-0.59.0/integrationtests/self/stream_test.go000066400000000000000000000166561513066070600224130ustar00rootroot00000000000000package self_test import ( "bytes" "context" "fmt" "io" "testing" "time" "github.com/quic-go/quic-go" "golang.org/x/sync/errgroup" "github.com/stretchr/testify/require" ) func TestBidirectionalStreamMultiplexing(t *testing.T) { const numStreams = 75 runSendingPeer := func(conn *quic.Conn) error { g := new(errgroup.Group) for i := 0; i < numStreams; i++ { str, err := conn.OpenStreamSync(context.Background()) if err != nil { return err } data := GeneratePRData(50 * i) g.Go(func() error { if _, err := str.Write(data); err != nil { return err } return str.Close() }) g.Go(func() error { dataRead, err := io.ReadAll(str) if err != nil { return err } if !bytes.Equal(dataRead, data) { return fmt.Errorf("data mismatch: %q != %q", dataRead, data) } return nil }) } return g.Wait() } runReceivingPeer := func(conn *quic.Conn) error { g := new(errgroup.Group) for i := 0; i < numStreams; i++ { str, err := conn.AcceptStream(context.Background()) if err != nil { return err } g.Go(func() error { // shouldn't use io.Copy here // we should read from the stream as early as possible, to free flow control credit data, err := io.ReadAll(str) if err != nil { return err } if _, err := str.Write(data); err != nil { return err } return str.Close() }) } return g.Wait() } t.Run("client -> server", func(t *testing.T) { ln, err := quic.Listen( newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(&quic.Config{ MaxIncomingStreams: 10, InitialStreamReceiveWindow: 10000, InitialConnectionReceiveWindow: 5000, }), ) require.NoError(t, err) defer ln.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() client, err := 
quic.Dial( ctx, newUDPConnLocalhost(t), ln.Addr(), getTLSClientConfig(), getQuicConfig(&quic.Config{InitialConnectionReceiveWindow: 2000}), ) require.NoError(t, err) conn, err := ln.Accept(ctx) require.NoError(t, err) errChan := make(chan error, 1) go func() { errChan <- runReceivingPeer(conn) }() require.NoError(t, runSendingPeer(client)) client.CloseWithError(0, "") select { case err := <-errChan: require.NoError(t, err) case <-time.After(time.Second): require.Fail(t, "timeout") } select { case <-conn.Context().Done(): case <-time.After(time.Second): require.Fail(t, "timeout") } }) t.Run("client <-> server", func(t *testing.T) { ln, err := quic.Listen( newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(&quic.Config{ MaxIncomingStreams: 30, InitialStreamReceiveWindow: 25000, InitialConnectionReceiveWindow: 50000, }), ) require.NoError(t, err) defer ln.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() client, err := quic.Dial( ctx, newUDPConnLocalhost(t), ln.Addr(), getTLSClientConfig(), getQuicConfig(&quic.Config{InitialConnectionReceiveWindow: 2000}), ) require.NoError(t, err) conn, err := ln.Accept(ctx) require.NoError(t, err) errChan1 := make(chan error, 1) errChan2 := make(chan error, 1) errChan3 := make(chan error, 1) errChan4 := make(chan error, 1) go func() { errChan1 <- runReceivingPeer(conn) }() go func() { errChan2 <- runSendingPeer(conn) }() go func() { errChan3 <- runReceivingPeer(client) }() go func() { errChan4 <- runSendingPeer(client) }() for _, ch := range []chan error{errChan1, errChan2, errChan3, errChan4} { select { case err := <-ch: require.NoError(t, err) case <-time.After(time.Second): require.Fail(t, "timeout") } } client.CloseWithError(0, "") select { case <-conn.Context().Done(): case <-time.After(time.Second): require.Fail(t, "timeout") } }) } func TestUnidirectionalStreams(t *testing.T) { const numStreams = 500 dataForStream := func(id uint64) []byte { return GeneratePRData(10 * int(id)) } 
runSendingPeer := func(conn *quic.Conn) error { g := new(errgroup.Group) for i := 0; i < numStreams; i++ { str, err := conn.OpenUniStreamSync(context.Background()) if err != nil { return err } g.Go(func() error { if _, err := str.Write(dataForStream(uint64(str.StreamID()))); err != nil { return err } return str.Close() }) } return g.Wait() } runReceivingPeer := func(conn *quic.Conn) error { g := new(errgroup.Group) for i := 0; i < numStreams; i++ { str, err := conn.AcceptUniStream(context.Background()) if err != nil { return err } g.Go(func() error { data, err := io.ReadAll(str) if err != nil { return err } if !bytes.Equal(data, dataForStream(uint64(str.StreamID()))) { return fmt.Errorf("data mismatch") } return nil }) } return g.Wait() } t.Run("client -> server", func(t *testing.T) { ln, err := quic.Listen( newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(nil), ) require.NoError(t, err) defer ln.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() client, err := quic.Dial(ctx, newUDPConnLocalhost(t), ln.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) serverConn, err := ln.Accept(ctx) require.NoError(t, err) errChan := make(chan error, 1) go func() { errChan <- runSendingPeer(client) }() require.NoError(t, runReceivingPeer(serverConn)) serverConn.CloseWithError(0, "") select { case err := <-errChan: require.NoError(t, err) case <-time.After(time.Second): require.Fail(t, "timeout") } }) t.Run("server -> client", func(t *testing.T) { ln, err := quic.Listen( newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(nil), ) require.NoError(t, err) defer ln.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() client, err := quic.Dial(ctx, newUDPConnLocalhost(t), ln.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) serverConn, err := ln.Accept(ctx) require.NoError(t, err) errChan := make(chan error, 1) go func() { errChan <- 
runSendingPeer(serverConn) }() require.NoError(t, runReceivingPeer(client)) client.CloseWithError(0, "") select { case err := <-errChan: require.NoError(t, err) case <-time.After(time.Second): require.Fail(t, "timeout") } }) t.Run("client <-> server", func(t *testing.T) { ln, err := quic.Listen( newUDPConnLocalhost(t), getTLSConfig(), getQuicConfig(nil), ) require.NoError(t, err) defer ln.Close() errChan1 := make(chan error, 1) errChan2 := make(chan error, 1) go func() { conn, err := ln.Accept(context.Background()) if err != nil { errChan1 <- err errChan2 <- err return } errChan1 <- runReceivingPeer(conn) errChan2 <- runSendingPeer(conn) }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() client, err := quic.Dial(ctx, newUDPConnLocalhost(t), ln.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) errChan3 := make(chan error, 1) go func() { errChan3 <- runSendingPeer(client) }() require.NoError(t, runReceivingPeer(client)) for _, ch := range []chan error{errChan1, errChan2, errChan3} { select { case err := <-ch: require.NoError(t, err) case <-time.After(time.Second): require.Fail(t, "timeout") } } client.CloseWithError(0, "") }) } quic-go-0.59.0/integrationtests/self/timeout_test.go000066400000000000000000000326361513066070600226020ustar00rootroot00000000000000package self_test import ( "bytes" "context" "crypto/tls" "errors" "fmt" "io" mrand "math/rand/v2" "net" "sync/atomic" "testing" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/quic-go/testutils/simnet" "github.com/stretchr/testify/require" ) func requireIdleTimeoutError(t *testing.T, err error) { t.Helper() require.Error(t, err) var idleTimeoutErr *quic.IdleTimeoutError require.ErrorAs(t, err, &idleTimeoutErr) require.True(t, idleTimeoutErr.Timeout()) var nerr net.Error 
require.True(t, errors.As(err, &nerr)) require.True(t, nerr.Timeout()) } func TestHandshakeIdleTimeout(t *testing.T) { t.Run("Dial", func(t *testing.T) { testHandshakeIdleTimeout(t, quic.Dial) }) t.Run("DialEarly", func(t *testing.T) { testHandshakeIdleTimeout(t, quic.DialEarly) }) } func testHandshakeIdleTimeout(t *testing.T, dialFn func(context.Context, net.PacketConn, net.Addr, *tls.Config, *quic.Config) (*quic.Conn, error)) { synctest.Test(t, func(t *testing.T) { const handshakeIdleTimeout = 3 * time.Second clientPacketConn, serverPacketConn, closeFn := newSimnetLink(t, time.Millisecond) defer closeFn(t) errChan := make(chan error, 1) start := time.Now() go func() { _, err := dialFn( context.Background(), clientPacketConn, serverPacketConn.LocalAddr(), getTLSClientConfig(), getQuicConfig(&quic.Config{HandshakeIdleTimeout: handshakeIdleTimeout}), ) errChan <- err }() select { case err := <-errChan: requireIdleTimeoutError(t, err) require.Equal(t, handshakeIdleTimeout, time.Since(start)) case <-time.After(5 * time.Second): t.Fatal("timeout waiting for dial error") } }) } func TestIdleTimeout(t *testing.T) { synctest.Test(t, func(t *testing.T) { const idleTimeout = 20 * time.Second var drop atomic.Bool clientPacketConn, serverPacketConn, closeFn := newSimnetLinkWithRouter(t, time.Millisecond, &droppingRouter{Drop: func(p simnet.Packet) bool { return drop.Load() }}, ) defer closeFn(t) server, err := quic.Listen( serverPacketConn, getTLSConfig(), getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}), ) require.NoError(t, err) defer server.Close() conn, err := quic.Dial( context.Background(), clientPacketConn, serverPacketConn.LocalAddr(), getTLSClientConfig(), getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, MaxIdleTimeout: idleTimeout}), ) require.NoError(t, err) serverConn, err := server.Accept(context.Background()) require.NoError(t, err) str, err := serverConn.OpenStream() require.NoError(t, err) _, err = str.Write([]byte("foobar")) 
require.NoError(t, err) serverStart := time.Now() strIn, err := conn.AcceptStream(context.Background()) require.NoError(t, err) strOut, err := conn.OpenStream() require.NoError(t, err) _, err = strIn.Read(make([]byte, 6)) require.NoError(t, err) clientStart := time.Now() drop.Store(true) select { case <-serverConn.Context().Done(): took := time.Since(serverStart) require.GreaterOrEqual(t, took, idleTimeout) t.Logf("server connection timed out after %s (idle timeout: %s)", took, idleTimeout) case <-time.After(2 * idleTimeout): t.Fatal("timeout waiting for idle timeout") } select { case <-conn.Context().Done(): took := time.Since(clientStart) require.GreaterOrEqual(t, took, idleTimeout) t.Logf("client connection timed out after %s (idle timeout: %s)", took, idleTimeout) case <-time.After(2 * idleTimeout): t.Fatal("timeout waiting for idle timeout") } _, err = strIn.Write([]byte("test")) requireIdleTimeoutError(t, err) _, err = strIn.Read([]byte{0}) requireIdleTimeoutError(t, err) _, err = strOut.Write([]byte("test")) requireIdleTimeoutError(t, err) _, err = strOut.Read([]byte{0}) requireIdleTimeoutError(t, err) _, err = conn.OpenStream() requireIdleTimeoutError(t, err) _, err = conn.OpenUniStream() requireIdleTimeoutError(t, err) _, err = conn.AcceptStream(context.Background()) requireIdleTimeoutError(t, err) _, err = conn.AcceptUniStream(context.Background()) requireIdleTimeoutError(t, err) }) } func TestKeepAlive(t *testing.T) { synctest.Test(t, func(t *testing.T) { const idleTimeout = 4 * time.Second var drop atomic.Bool clientPacketConn, serverPacketConn, closeFn := newSimnetLinkWithRouter(t, time.Millisecond, &droppingRouter{Drop: func(p simnet.Packet) bool { return drop.Load() }}, ) defer closeFn(t) server, err := quic.Listen( serverPacketConn, getTLSConfig(), getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}), ) require.NoError(t, err) defer server.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := 
quic.Dial( ctx, clientPacketConn, serverPacketConn.LocalAddr(), getTLSClientConfig(), getQuicConfig(&quic.Config{ MaxIdleTimeout: idleTimeout, KeepAlivePeriod: idleTimeout / 2, DisablePathMTUDiscovery: true, }), ) require.NoError(t, err) serverConn, err := server.Accept(ctx) require.NoError(t, err) // wait longer than the idle timeout time.Sleep(3 * idleTimeout) str, err := conn.OpenUniStream() require.NoError(t, err) _, err = str.Write([]byte("foobar")) require.NoError(t, err) // verify connection is still alive select { case <-serverConn.Context().Done(): t.Fatal("server connection closed unexpectedly") default: } // idle timeout will still kick in if PINGs are dropped drop.Store(true) time.Sleep(2 * idleTimeout) _, err = str.Write([]byte("foobar")) requireIdleTimeoutError(t, err) // can't rely on the server connection closing, since we impose a minimum idle timeout of 5s, // see https://github.com/quic-go/quic-go/issues/4751 serverConn.CloseWithError(0, "") }) } func TestTimeoutAfterInactivity(t *testing.T) { synctest.Test(t, func(t *testing.T) { const idleTimeout = 15 * time.Second clientPacketConn, serverPacketConn, closeFn := newSimnetLink(t, time.Millisecond) defer closeFn(t) server, err := quic.Listen( serverPacketConn, getTLSConfig(), getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}), ) require.NoError(t, err) defer server.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() counter, tr := newPacketTracer() conn, err := quic.Dial( ctx, clientPacketConn, server.Addr(), getTLSClientConfig(), getQuicConfig(&quic.Config{ MaxIdleTimeout: idleTimeout, Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return tr }, DisablePathMTUDiscovery: true, }), ) require.NoError(t, err) serverConn, err := server.Accept(ctx) require.NoError(t, err) defer serverConn.CloseWithError(0, "") ctx, cancel = context.WithTimeout(context.Background(), 2*idleTimeout) defer cancel() _, err = conn.AcceptStream(ctx) 
requireIdleTimeoutError(t, err) var lastAckElicitingPacketSentAt time.Time for _, p := range counter.getSentShortHeaderPackets() { var hasAckElicitingFrame bool for _, f := range p.frames { if _, ok := f.Frame.(qlog.AckFrame); ok { continue } hasAckElicitingFrame = true break } if hasAckElicitingFrame { lastAckElicitingPacketSentAt = p.time } } rcvdPackets := counter.getRcvdShortHeaderPackets() lastPacketRcvdAt := rcvdPackets[len(rcvdPackets)-1].time // We're ignoring here that only the first ack-eliciting packet sent resets the idle timeout. // This is ok since we're dealing with a lossless connection here, // and we'd expect to receive an ACK for additional other ack-eliciting packet sent. timeSinceLastAckEliciting := time.Since(lastAckElicitingPacketSentAt) timeSinceLastRcvd := time.Since(lastPacketRcvdAt) require.Equal(t, idleTimeout, max(timeSinceLastAckEliciting, timeSinceLastRcvd)) select { case <-serverConn.Context().Done(): t.Fatal("server connection closed unexpectedly") default: } }) } func TestTimeoutAfterSendingPacket(t *testing.T) { synctest.Test(t, func(t *testing.T) { const idleTimeout = 15 * time.Second var drop atomic.Bool clientPacketConn, serverPacketConn, closeFn := newSimnetLinkWithRouter(t, time.Millisecond, &droppingRouter{Drop: func(p simnet.Packet) bool { return drop.Load() }}, ) defer closeFn(t) server, err := quic.Listen( serverPacketConn, getTLSConfig(), getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}), ) require.NoError(t, err) defer server.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.Dial( ctx, clientPacketConn, serverPacketConn.LocalAddr(), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout, DisablePathMTUDiscovery: true}), ) require.NoError(t, err) serverConn, err := server.Accept(ctx) require.NoError(t, err) serverStart := time.Now() // wait half the idle timeout, then send a packet time.Sleep(idleTimeout / 2) drop.Store(true) 
clientStart := time.Now() str, err := conn.OpenUniStream() require.NoError(t, err) _, err = str.Write([]byte("foobar")) require.NoError(t, err) select { case <-serverConn.Context().Done(): took := time.Since(serverStart) require.GreaterOrEqual(t, took, idleTimeout) require.Less(t, took, idleTimeout+time.Second) case <-time.After(2 * idleTimeout): t.Fatal("timeout waiting for idle timeout") } select { case <-conn.Context().Done(): took := time.Since(clientStart) require.Equal(t, took, idleTimeout) case <-time.After(2 * idleTimeout): t.Fatal("timeout waiting for idle timeout") } }) } type faultyConn struct { net.PacketConn MaxPackets int counter atomic.Int32 } func (c *faultyConn) ReadFrom(p []byte) (int, net.Addr, error) { n, addr, err := c.PacketConn.ReadFrom(p) counter := c.counter.Add(1) if counter <= int32(c.MaxPackets) { return n, addr, err } return 0, nil, io.ErrClosedPipe } func (c *faultyConn) WriteTo(p []byte, addr net.Addr) (int, error) { counter := c.counter.Add(1) if counter <= int32(c.MaxPackets) { return c.PacketConn.WriteTo(p, addr) } return 0, io.ErrClosedPipe } func TestFaultyPacketConn(t *testing.T) { t.Run("client", func(t *testing.T) { testFaultyPacketConn(t, protocol.PerspectiveClient) }) t.Run("server", func(t *testing.T) { testFaultyPacketConn(t, protocol.PerspectiveServer) }) } func testFaultyPacketConn(t *testing.T, pers protocol.Perspective) { t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true") synctest.Test(t, func(t *testing.T) { runServer := func(ln *quic.Listener) error { conn, err := ln.Accept(context.Background()) if err != nil { return err } str, err := conn.OpenUniStream() if err != nil { return err } defer str.Close() _, err = str.Write(PRData) return err } runClient := func(conn *quic.Conn) error { str, err := conn.AcceptUniStream(context.Background()) if err != nil { return err } data, err := io.ReadAll(str) if err != nil { return err } if !bytes.Equal(data, PRData) { return fmt.Errorf("wrong data: %q vs %q", data, PRData) 
} return conn.CloseWithError(0, "done") } clientPacketConn, serverPacketConn, closeFn := newSimnetLink(t, 100*time.Millisecond) defer closeFn(t) var cconn, sconn net.PacketConn = clientPacketConn, serverPacketConn maxPackets := mrand.IntN(25) // sanity check: sending PRData should generate at least 25 packets require.Greater(t, len(PRData)/1500, 25) t.Logf("blocking %s's connection after %d packets", pers, maxPackets) switch pers { case protocol.PerspectiveClient: cconn = &faultyConn{PacketConn: cconn, MaxPackets: maxPackets} case protocol.PerspectiveServer: sconn = &faultyConn{PacketConn: sconn, MaxPackets: maxPackets} } ln, err := quic.Listen( sconn, getTLSConfig(), getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}), ) require.NoError(t, err) defer ln.Close() serverErrChan := make(chan error, 1) go func() { serverErrChan <- runServer(ln) }() clientErrChan := make(chan error, 1) go func() { conn, err := quic.Dial( context.Background(), cconn, ln.Addr(), getTLSClientConfig(), getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}), ) if err != nil { clientErrChan <- err return } clientErrChan <- runClient(conn) }() var clientErr error select { case clientErr = <-clientErrChan: case <-time.After(time.Hour): t.Fatal("timeout waiting for client error") } require.Error(t, clientErr) if pers == protocol.PerspectiveClient { require.Contains(t, clientErr.Error(), io.ErrClosedPipe.Error()) } else { var nerr net.Error require.True(t, errors.As(clientErr, &nerr)) require.True(t, nerr.Timeout()) } select { case serverErr := <-serverErrChan: // The handshake completed on the server side. 
require.Error(t, serverErr) if pers == protocol.PerspectiveServer { require.Contains(t, serverErr.Error(), io.ErrClosedPipe.Error()) } else { var nerr net.Error require.True(t, errors.As(serverErr, &nerr)) require.True(t, nerr.Timeout()) } default: // The handshake didn't complete require.NoError(t, ln.Close()) select { case <-serverErrChan: case <-time.After(time.Hour): t.Fatal("timeout waiting for server to close") } } }) } quic-go-0.59.0/integrationtests/self/zero_rtt_test.go000066400000000000000000001023331513066070600227540ustar00rootroot00000000000000package self_test import ( "bytes" "context" "crypto/tls" "fmt" "io" "net" "os" "sync" "sync/atomic" "testing" "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/quic-go/testutils/simnet" "github.com/stretchr/testify/require" ) type zeroRTTCountingRouter struct { simnet.Router counter atomic.Uint32 } var _ simnet.Router = &zeroRTTCountingRouter{} func (r *zeroRTTCountingRouter) SendPacket(p simnet.Packet) error { if containsPacketType(p.Data, protocol.PacketType0RTT) { r.counter.Add(1) } return r.Router.SendPacket(p) } func (r *zeroRTTCountingRouter) Num0RTTPackets() int { return int(r.counter.Load()) } func dialAndReceiveTicket(t *testing.T, ln *quic.EarlyListener, clientConn net.PacketConn, sessionCache tls.ClientSessionCache) (clientTLSConf *tls.Config) { t.Helper() clientTLSConf = getTLSClientConfig() puts := make(chan string, 100) cache := sessionCache if cache == nil { cache = tls.NewLRUClientSessionCache(100) } clientTLSConf.ClientSessionCache = newClientSessionCache(cache, nil, puts) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() tr := &quic.Transport{Conn: clientConn} defer tr.Close() conn, err := tr.Dial(ctx, ln.Addr(), clientTLSConf, getQuicConfig(nil)) 
require.NoError(t, err) require.False(t, conn.ConnectionState().Used0RTT) select { case <-puts: case <-time.After(time.Second): t.Fatal("timeout waiting for session ticket") } require.NoError(t, conn.CloseWithError(0, "")) serverConn, err := ln.Accept(ctx) require.NoError(t, err) select { case <-serverConn.Context().Done(): case <-time.After(time.Second): t.Fatal("timeout waiting for connection to close") } return clientTLSConf } func transfer0RTTData( t *testing.T, ln *quic.EarlyListener, clientPacketConn net.PacketConn, clientTLSConf *tls.Config, clientConf *quic.Config, testdata []byte, // data to transfer ) { t.Helper() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() tr := &quic.Transport{Conn: clientPacketConn} defer tr.Close() conn, err := tr.DialEarly(ctx, ln.Addr(), clientTLSConf, clientConf) require.NoError(t, err) errChan := make(chan error, 1) serverConnChan := make(chan *quic.Conn, 1) go func() { defer close(errChan) conn, err := ln.Accept(context.Background()) if err != nil { errChan <- err return } serverConnChan <- conn str, err := conn.AcceptStream(ctx) if err != nil { errChan <- err return } defer str.Close() if _, err := io.Copy(str, str); err != nil { errChan <- err return } }() str, err := conn.OpenStream() require.NoError(t, err) clientErrChan := make(chan error, 1) go func() { defer close(clientErrChan) // wait for the EOF from the server to arrive before closing the conn data, err := io.ReadAll(str) if err != nil { t.Error(err) clientErrChan <- err return } if !bytes.Equal(testdata, data) { clientErrChan <- fmt.Errorf("data mismatch") } }() _, err = str.Write(testdata) require.NoError(t, err) require.NoError(t, str.Close()) select { case <-conn.HandshakeComplete(): case <-time.After(time.Second): t.Fatal("timeout waiting for handshake to complete") } select { case err := <-clientErrChan: require.NoError(t, err) case <-time.After(time.Hour): t.Fatal("timeout waiting for client to read data") } select { case 
err := <-errChan: require.NoError(t, err) case <-time.After(time.Second): t.Fatal("timeout waiting for server to process data") } var serverConn *quic.Conn select { case serverConn = <-serverConnChan: case <-time.After(time.Second): t.Fatal("timeout waiting for server to process data") } require.True(t, conn.ConnectionState().Used0RTT) require.True(t, serverConn.ConnectionState().Used0RTT) conn.CloseWithError(0, "") select { case <-serverConn.Context().Done(): case <-time.After(time.Second): t.Fatal("timeout waiting for connection to close") } } func Test0RTTTransfer(t *testing.T) { synctest.Test(t, func(t *testing.T) { const rtt = 50 * time.Millisecond router := &zeroRTTCountingRouter{Router: &simnet.PerfectRouter{}} clientConn, serverConn, closeFn := newSimnetLinkWithRouter(t, rtt, router) defer closeFn(t) tlsConf := getTLSConfig() tr := &quic.Transport{Conn: serverConn} counter, tracer := newPacketTracer() defer tr.Close() ln, err := tr.ListenEarly( tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: true, Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return tracer }, }), ) require.NoError(t, err) defer ln.Close() clientTLSConf := dialAndReceiveTicket(t, ln, clientConn, nil) time.Sleep(time.Hour) synctest.Wait() transfer0RTTData(t, ln, clientConn, clientTLSConf, getQuicConfig(nil), PRData) num0RTT := router.Num0RTTPackets() t.Logf("sent %d 0-RTT packets", num0RTT) zeroRTTPackets := counter.getRcvd0RTTPacketNumbers() t.Logf("received %d 0-RTT packets", len(zeroRTTPackets)) require.Greater(t, num0RTT, 20) require.Contains(t, zeroRTTPackets, protocol.PacketNumber(0)) }) } func Test0RTTDisabledOnDial(t *testing.T) { synctest.Test(t, func(t *testing.T) { const rtt = 25 * time.Millisecond router := &zeroRTTCountingRouter{Router: &simnet.PerfectRouter{}} clientConn, serverConn, closeFn := newSimnetLinkWithRouter(t, rtt, router) defer closeFn(t) tlsConf := getTLSConfig() tr := &quic.Transport{Conn: serverConn} defer tr.Close() ln, err := 
tr.ListenEarly(tlsConf, getQuicConfig(&quic.Config{Allow0RTT: true})) require.NoError(t, err) defer ln.Close() clientTLSConf := dialAndReceiveTicket(t, ln, clientConn, nil) time.Sleep(time.Hour) synctest.Wait() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.Dial(ctx, clientConn, serverConn.LocalAddr(), clientTLSConf, getQuicConfig(nil)) require.NoError(t, err) // session Resumption is enabled at the TLS layer, but not 0-RTT at the QUIC layer require.True(t, conn.ConnectionState().TLS.DidResume) require.False(t, conn.ConnectionState().Used0RTT) conn.CloseWithError(0, "") require.Zero(t, router.Num0RTTPackets()) }) } func Test0RTTWaitForHandshakeCompletion(t *testing.T) { synctest.Test(t, func(t *testing.T) { const rtt = 50 * time.Millisecond router := &zeroRTTCountingRouter{Router: &simnet.PerfectRouter{}} clientConn, serverConn, closeFn := newSimnetLinkWithRouter(t, rtt, router) defer closeFn(t) tlsConf := getTLSConfig() tr := &quic.Transport{Conn: serverConn} defer tr.Close() counter, tracer := newPacketTracer() ln, err := tr.ListenEarly( tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: true, Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return tracer }, }), ) require.NoError(t, err) defer ln.Close() clientTLSConf := dialAndReceiveTicket(t, ln, clientConn, nil) zeroRTTData := GeneratePRData(5 << 10) oneRTTData := PRData // now accept the second connection, and receive the 0-RTT data errChan := make(chan error, 1) firstStrDataChan := make(chan []byte, 1) secondStrDataChan := make(chan []byte, 1) go func() { defer close(errChan) conn, err := ln.Accept(context.Background()) if err != nil { errChan <- err return } str, err := conn.AcceptUniStream(context.Background()) if err != nil { errChan <- err return } data, err := io.ReadAll(str) if err != nil { errChan <- err return } firstStrDataChan <- data str, err = conn.AcceptUniStream(context.Background()) if err != nil { errChan <- err 
return } data, err = io.ReadAll(str) if err != nil { errChan <- err return } secondStrDataChan <- data <-conn.Context().Done() }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.DialEarly( ctx, clientConn, serverConn.LocalAddr(), clientTLSConf, getQuicConfig(nil), ) require.NoError(t, err) firstStr, err := conn.OpenUniStream() require.NoError(t, err) _, err = firstStr.Write(zeroRTTData) require.NoError(t, err) require.NoError(t, firstStr.Close()) // wait for the handshake to complete select { case <-conn.HandshakeComplete(): case <-time.After(time.Second): t.Fatal("handshake did not complete in time") } str, err := conn.OpenUniStream() require.NoError(t, err) _, err = str.Write(PRData) require.NoError(t, err) require.NoError(t, str.Close()) select { case data := <-firstStrDataChan: require.Equal(t, zeroRTTData, data) case <-time.After(time.Second): t.Fatal("timeout waiting for first stream data") } select { case data := <-secondStrDataChan: require.Equal(t, oneRTTData, data) case <-time.After(time.Second): t.Fatal("timeout waiting for second stream data") } conn.CloseWithError(0, "") select { case err := <-errChan: require.NoError(t, err, "server error") case <-time.After(time.Second): t.Fatal("timeout waiting for connection to close") } // check that 0-RTT packets only contain STREAM frames for the first stream var num0RTT int for _, p := range counter.getRcvdLongHeaderPackets() { if p.hdr.PacketType != qlog.PacketType0RTT { continue } for _, f := range p.frames { sf, ok := f.Frame.(*qlog.StreamFrame) if !ok { continue } num0RTT++ require.Equal(t, firstStr.StreamID(), sf.StreamID) } } t.Logf("received %d STREAM frames in 0-RTT packets", num0RTT) require.NotZero(t, num0RTT) }) } func Test0RTTDataLoss(t *testing.T) { synctest.Test(t, func(t *testing.T) { const rtt = 5 * time.Millisecond tlsConf := getTLSConfig() var num0RTTPackets, numDropped atomic.Uint32 router := &droppingRouter{ Drop: func(p simnet.Packet) 
bool { if !wire.IsLongHeaderPacket(p.Data[0]) { return false } hdr, _, _, _ := wire.ParsePacket(p.Data) if hdr.Type == protocol.PacketType0RTT { count := num0RTTPackets.Add(1) // drop 25% of the 0-RTT packets drop := count%4 == 0 if drop { numDropped.Add(1) } return drop } return false }, } clientConn, serverConn, closeFn := newSimnetLinkWithRouter(t, rtt, router) defer closeFn(t) tr := &quic.Transport{Conn: serverConn} defer tr.Close() counter, tracer := newPacketTracer() ln, err := tr.ListenEarly( tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: true, Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return tracer }, }), ) require.NoError(t, err) defer ln.Close() clientTLSConf := dialAndReceiveTicket(t, ln, clientConn, nil) transfer0RTTData(t, ln, clientConn, clientTLSConf, getQuicConfig(nil), PRData) num0RTT := num0RTTPackets.Load() dropped := numDropped.Load() t.Logf("sent %d 0-RTT packets, dropped %d of those.", num0RTT, dropped) require.NotZero(t, num0RTT) require.NotZero(t, dropped) require.NotEmpty(t, counter.getRcvd0RTTPacketNumbers()) }) } func Test0RTTRetransmitOnRetry(t *testing.T) { t.Run("no retry", func(t *testing.T) { test0RTTRetransmitOnRetry(t, false) }) t.Run("with retry", func(t *testing.T) { test0RTTRetransmitOnRetry(t, true) }) } func test0RTTRetransmitOnRetry(t *testing.T, useRetry bool) { synctest.Test(t, func(t *testing.T) { const rtt = 5 * time.Millisecond tlsConf := getTLSConfig() type connIDCounter struct { connID protocol.ConnectionID bytes protocol.ByteCount } var mutex sync.Mutex var connIDToCounter []*connIDCounter countZeroRTTBytes := func(data []byte) (n protocol.ByteCount) { for len(data) > 0 { hdr, _, rest, err := wire.ParsePacket(data) if err != nil { return } data = rest if hdr.Type == protocol.PacketType0RTT { n += hdr.Length - 16 /* AEAD tag */ } } return } router := &zeroRTTCountingRouter{ Router: &callbackRouter{ Router: &zeroRTTCountingRouter{Router: &simnet.PerfectRouter{}}, OnSendPacket: func(p 
simnet.Packet) { if l := countZeroRTTBytes(p.Data); l > 0 { mutex.Lock() defer mutex.Unlock() connID, err := wire.ParseConnectionID(p.Data, 0) if err != nil { panic("failed to parse connection ID") } var found bool for _, c := range connIDToCounter { if c.connID == connID { c.bytes += l found = true break } } if !found { connIDToCounter = append(connIDToCounter, &connIDCounter{connID: connID, bytes: l}) } } }, }, } clientConn, serverConn, closeFn := newSimnetLinkWithRouter(t, rtt, router) defer closeFn(t) tr := &quic.Transport{ Conn: serverConn, VerifySourceAddress: func(net.Addr) bool { return useRetry }, } defer tr.Close() counter, tracer := newPacketTracer() ln, err := tr.ListenEarly( tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: true, Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return tracer }, }), ) require.NoError(t, err) defer ln.Close() clientTLSConf := dialAndReceiveTicket(t, ln, clientConn, nil) require.Empty(t, connIDToCounter) transfer0RTTData(t, ln, clientConn, clientTLSConf, getQuicConfig(nil), GeneratePRData(5000)) // ~5 packets mutex.Lock() defer mutex.Unlock() if !useRetry { require.Len(t, connIDToCounter, 1) return } require.Len(t, connIDToCounter, 2) require.InDelta(t, 5000+100 /* framing overhead */, int(connIDToCounter[0].bytes), 100) // the FIN bit might be sent extra require.InDelta(t, int(connIDToCounter[0].bytes), int(connIDToCounter[1].bytes), 20) zeroRTTPackets := counter.getRcvd0RTTPacketNumbers() require.GreaterOrEqual(t, len(zeroRTTPackets), 5) require.GreaterOrEqual(t, zeroRTTPackets[0], protocol.PacketNumber(5)) }) } func Test0RTTWithIncreasedStreamLimit(t *testing.T) { synctest.Test(t, func(t *testing.T) { const rtt = 10 * time.Millisecond router := &zeroRTTCountingRouter{Router: &simnet.PerfectRouter{}} clientConn, serverConn, closeFn := newSimnetLinkWithRouter(t, rtt, router) defer closeFn(t) tlsConf := getTLSConfig() tr := &quic.Transport{Conn: serverConn} defer tr.Close() ln, err := 
tr.ListenEarly(tlsConf, getQuicConfig(&quic.Config{Allow0RTT: true, MaxIncomingUniStreams: 1})) require.NoError(t, err) clientTLSConf := dialAndReceiveTicket(t, ln, clientConn, nil) require.Zero(t, router.Num0RTTPackets()) require.NoError(t, ln.Close()) time.Sleep(time.Hour) synctest.Wait() ln, err = tr.ListenEarly(tlsConf, getQuicConfig(&quic.Config{Allow0RTT: true, MaxIncomingUniStreams: 2})) require.NoError(t, err) defer ln.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.DialEarly( ctx, clientConn, ln.Addr(), clientTLSConf, getQuicConfig(nil), ) require.NoError(t, err) require.False(t, conn.ConnectionState().TLS.HandshakeComplete) str, err := conn.OpenUniStream() require.NoError(t, err) _, err = str.Write([]byte("foobar")) require.NoError(t, err) require.NoError(t, str.Close()) // the client remembers the old limit and refuses to open a new stream _, err = conn.OpenUniStream() require.ErrorIs(t, err, &quic.StreamLimitReachedError{}) // after handshake completion, the new limit applies select { case <-conn.HandshakeComplete(): case <-time.After(time.Second): t.Fatal("handshake did not complete in time") } _, err = conn.OpenUniStream() require.NoError(t, err) require.True(t, conn.ConnectionState().Used0RTT) require.NoError(t, conn.CloseWithError(0, "")) require.NotZero(t, router.Num0RTTPackets()) }) } func check0RTTRejected(t *testing.T, ln *quic.EarlyListener, clientPacketConn net.PacketConn, addr net.Addr, conf *tls.Config, sendData bool, ) (clientConn, serverConn *quic.Conn) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.DialEarly(ctx, clientPacketConn, addr, conf, getQuicConfig(nil)) require.NoError(t, err) require.False(t, conn.ConnectionState().TLS.HandshakeComplete) if sendData { str, err := conn.OpenUniStream() require.NoError(t, err) _, err = str.Write(make([]byte, 3000)) require.NoError(t, err) require.NoError(t, str.Close()) } select 
{ case <-conn.HandshakeComplete(): case <-time.After(time.Second): t.Fatal("handshake did not complete in time") } require.False(t, conn.ConnectionState().Used0RTT) // make sure the server doesn't process the data ctx, cancel = context.WithTimeout(context.Background(), time.Second) defer cancel() serverConn, err = ln.Accept(ctx) require.NoError(t, err) require.False(t, serverConn.ConnectionState().Used0RTT) if sendData { _, err = serverConn.AcceptUniStream(ctx) require.Equal(t, context.DeadlineExceeded, err) } ctx, cancel = context.WithTimeout(context.Background(), time.Second) defer cancel() nextConn, err := conn.NextConnection(ctx) require.NoError(t, err) require.True(t, nextConn.ConnectionState().TLS.HandshakeComplete) require.False(t, nextConn.ConnectionState().Used0RTT) return nextConn, serverConn } func Test0RTTRejectedOnStreamLimitDecrease(t *testing.T) { synctest.Test(t, func(t *testing.T) { const rtt = 5 * time.Millisecond const ( maxBidiStreams = 42 maxUniStreams = 10 newMaxBidiStreams = maxBidiStreams - 1 newMaxUniStreams = maxUniStreams - 1 ) router := &zeroRTTCountingRouter{Router: &simnet.PerfectRouter{}} clientConn, serverConn, closeFn := newSimnetLinkWithRouter(t, rtt, router) defer closeFn(t) tlsConf := getTLSConfig() tr := &quic.Transport{Conn: serverConn} defer tr.Close() ln, err := tr.ListenEarly( tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: true, MaxIncomingStreams: maxBidiStreams, MaxIncomingUniStreams: maxUniStreams, }), ) require.NoError(t, err) clientTLSConf := dialAndReceiveTicket(t, ln, clientConn, nil) require.NoError(t, ln.Close()) time.Sleep(time.Hour) synctest.Wait() counter, tracer := newPacketTracer() ln, err = tr.ListenEarly( tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: true, MaxIncomingStreams: newMaxBidiStreams, MaxIncomingUniStreams: newMaxUniStreams, Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return tracer }, }), ) require.NoError(t, err) defer ln.Close() conn, sconn := check0RTTRejected(t, 
ln, clientConn, ln.Addr(), clientTLSConf, true) defer conn.CloseWithError(0, "") // It should now be possible to open new bidirectional streams up to the new limit... for range newMaxBidiStreams { _, err = conn.OpenStream() require.NoError(t, err) } // ... but not beyond it. _, err = conn.OpenStream() require.ErrorIs(t, err, &quic.StreamLimitReachedError{}) // It should now be possible to open new unidirectional streams up to the new limit... for range newMaxUniStreams { _, err = conn.OpenUniStream() require.NoError(t, err) } // ... but not beyond it. _, err = conn.OpenUniStream() require.ErrorIs(t, err, &quic.StreamLimitReachedError{}) sconn.CloseWithError(0, "") // The client should send 0-RTT packets, but the server doesn't process them. n := router.Num0RTTPackets() t.Logf("sent %d 0-RTT packets", n) require.NotZero(t, n) require.Empty(t, counter.getRcvd0RTTPacketNumbers()) }) } func Test0RTTRejectedOnConnectionWindowDecrease(t *testing.T) { synctest.Test(t, func(t *testing.T) { const rtt = 5 * time.Millisecond const ( connFlowControlWindow = 100 newConnFlowControlWindow = connFlowControlWindow - 1 ) router := &zeroRTTCountingRouter{Router: &simnet.PerfectRouter{}} clientConn, serverConn, closeFn := newSimnetLinkWithRouter(t, rtt, router) defer closeFn(t) tlsConf := getTLSConfig() tr := &quic.Transport{Conn: serverConn} defer tr.Close() ln, err := tr.ListenEarly( tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: true, InitialConnectionReceiveWindow: connFlowControlWindow, }), ) require.NoError(t, err) clientTLSConf := dialAndReceiveTicket(t, ln, clientConn, nil) require.NoError(t, ln.Close()) time.Sleep(time.Hour) synctest.Wait() ln, err = tr.ListenEarly( tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: true, InitialConnectionReceiveWindow: newConnFlowControlWindow, }), ) require.NoError(t, err) conn, sconn := check0RTTRejected(t, ln, clientConn, ln.Addr(), clientTLSConf, false) defer conn.CloseWithError(0, "") defer sconn.CloseWithError(0, "") str, err := 
conn.OpenStream() require.NoError(t, err) str.SetWriteDeadline(time.Now().Add(time.Second)) n, err := str.Write(make([]byte, 2000)) require.ErrorIs(t, err, os.ErrDeadlineExceeded) require.Equal(t, newConnFlowControlWindow, n) // make sure that only 99 bytes were received ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() serverStr, err := sconn.AcceptStream(ctx) require.NoError(t, err) serverStr.SetReadDeadline(time.Now().Add(time.Second)) n, err = io.ReadFull(serverStr, make([]byte, newConnFlowControlWindow)) require.NoError(t, err) require.Equal(t, newConnFlowControlWindow, n) _, err = serverStr.Read([]byte{0}) require.ErrorIs(t, err, os.ErrDeadlineExceeded) }) } func Test0RTTRejectedOnALPNChanged(t *testing.T) { synctest.Test(t, func(t *testing.T) { const rtt = 5 * time.Millisecond router := &zeroRTTCountingRouter{Router: &simnet.PerfectRouter{}} clientConn, serverConn, closeFn := newSimnetLinkWithRouter(t, rtt, router) defer closeFn(t) tlsConf := getTLSConfig() tr := &quic.Transport{Conn: serverConn} defer tr.Close() ln, err := tr.ListenEarly(tlsConf, getQuicConfig(&quic.Config{Allow0RTT: true})) require.NoError(t, err) clientTLSConf := dialAndReceiveTicket(t, ln, clientConn, nil) require.NoError(t, ln.Close()) time.Sleep(time.Hour) synctest.Wait() // switch to different ALPN on the server side tlsConf.NextProtos = []string{"new-alpn"} // Append to the client's ALPN. 
// crypto/tls will attempt to resume with the ALPN from the original connection clientTLSConf.NextProtos = append(clientTLSConf.NextProtos, "new-alpn") counter, tracer := newPacketTracer() ln, err = tr.ListenEarly( tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: true, Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return tracer }, }), ) require.NoError(t, err) defer ln.Close() conn, sconn := check0RTTRejected(t, ln, clientConn, ln.Addr(), clientTLSConf, true) defer conn.CloseWithError(0, "") require.Equal(t, "new-alpn", conn.ConnectionState().TLS.NegotiatedProtocol) sconn.CloseWithError(0, "") // The client should send 0-RTT packets, but the server doesn't process them. num0RTT := router.Num0RTTPackets() t.Logf("Sent %d 0-RTT packets.", num0RTT) require.NotZero(t, num0RTT) require.Empty(t, counter.getRcvd0RTTPacketNumbers()) }) } func Test0RTTRejectedWhenDisabled(t *testing.T) { synctest.Test(t, func(t *testing.T) { const rtt = 5 * time.Millisecond router := &zeroRTTCountingRouter{Router: &simnet.PerfectRouter{}} clientConn, serverConn, closeFn := newSimnetLinkWithRouter(t, rtt, router) defer closeFn(t) tlsConf := getTLSConfig() tr := &quic.Transport{Conn: serverConn} defer tr.Close() ln, err := tr.ListenEarly(tlsConf, getQuicConfig(&quic.Config{Allow0RTT: true})) require.NoError(t, err) clientTLSConf := dialAndReceiveTicket(t, ln, clientConn, nil) require.NoError(t, ln.Close()) time.Sleep(time.Hour) synctest.Wait() counter, tracer := newPacketTracer() ln, err = tr.ListenEarly( tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: false, Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return tracer }, }), ) require.NoError(t, err) defer ln.Close() conn, sconn := check0RTTRejected(t, ln, clientConn, ln.Addr(), clientTLSConf, true) defer conn.CloseWithError(0, "") sconn.CloseWithError(0, "") // The client should send 0-RTT packets, but the server doesn't process them. 
num0RTT := router.Num0RTTPackets() t.Logf("Sent %d 0-RTT packets.", num0RTT) require.NotZero(t, num0RTT) require.Empty(t, counter.getRcvd0RTTPacketNumbers()) }) } func Test0RTTRejectedOnDatagramsDisabled(t *testing.T) { synctest.Test(t, func(t *testing.T) { const rtt = 5 * time.Millisecond router := &zeroRTTCountingRouter{Router: &simnet.PerfectRouter{}} clientConn, serverConn, closeFn := newSimnetLinkWithRouter(t, rtt, router) defer closeFn(t) tlsConf := getTLSConfig() tr := &quic.Transport{Conn: serverConn} defer tr.Close() ln, err := tr.ListenEarly(tlsConf, getQuicConfig(&quic.Config{Allow0RTT: true, EnableDatagrams: true})) require.NoError(t, err) clientTLSConf := dialAndReceiveTicket(t, ln, clientConn, nil) require.NoError(t, ln.Close()) time.Sleep(time.Hour) synctest.Wait() counter, tracer := newPacketTracer() ln, err = tr.ListenEarly( tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: true, EnableDatagrams: false, Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return tracer }, }), ) require.NoError(t, err) defer ln.Close() conn, sconn := check0RTTRejected(t, ln, clientConn, ln.Addr(), clientTLSConf, true) defer conn.CloseWithError(0, "") require.False(t, conn.ConnectionState().SupportsDatagrams.Remote) require.False(t, conn.ConnectionState().SupportsDatagrams.Local) sconn.CloseWithError(0, "") // The client should send 0-RTT packets, but the server doesn't process them. 
num0RTT := router.Num0RTTPackets() t.Logf("Sent %d 0-RTT packets.", num0RTT) require.NotZero(t, num0RTT) require.Empty(t, counter.getRcvd0RTTPacketNumbers()) }) } type metadataClientSessionCache struct { toAdd []byte restored func([]byte) cache tls.ClientSessionCache } func (m metadataClientSessionCache) Get(key string) (*tls.ClientSessionState, bool) { session, ok := m.cache.Get(key) if !ok || session == nil { return session, ok } ticket, state, err := session.ResumptionState() if err != nil { panic("failed to get resumption state: " + err.Error()) } if len(state.Extra) != 2 { // ours, and the quic-go's panic("expected 2 state entries" + fmt.Sprintf("%v", state.Extra)) } m.restored(state.Extra[1]) // as of Go 1.23, this function never returns an error session, err = tls.NewResumptionState(ticket, state) if err != nil { panic("failed to create resumption state: " + err.Error()) } return session, true } func (m metadataClientSessionCache) Put(key string, session *tls.ClientSessionState) { ticket, state, err := session.ResumptionState() if err != nil { panic("failed to get resumption state: " + err.Error()) } state.Extra = append(state.Extra, m.toAdd) session, err = tls.NewResumptionState(ticket, state) if err != nil { panic("failed to create resumption state: " + err.Error()) } m.cache.Put(key, session) } func Test0RTTWithSessionTicketData(t *testing.T) { const rtt = 5 * time.Millisecond t.Run("server", func(t *testing.T) { synctest.Test(t, func(t *testing.T) { tlsConf := getTLSConfig() tlsConf.WrapSession = func(cs tls.ConnectionState, ss *tls.SessionState) ([]byte, error) { ss.Extra = append(ss.Extra, []byte("foobar")) return tlsConf.EncryptTicket(cs, ss) } router := &zeroRTTCountingRouter{Router: &simnet.PerfectRouter{}} clientConn, serverConn, closeFn := newSimnetLinkWithRouter(t, rtt, router) defer closeFn(t) tr := &quic.Transport{Conn: serverConn} defer tr.Close() ln, err := tr.ListenEarly(tlsConf, getQuicConfig(&quic.Config{Allow0RTT: true})) 
require.NoError(t, err) clientTLSConf := dialAndReceiveTicket(t, ln, clientConn, nil) stateChan := make(chan *tls.SessionState, 1) tlsConf.UnwrapSession = func(identity []byte, cs tls.ConnectionState) (*tls.SessionState, error) { state, err := tlsConf.DecryptTicket(identity, cs) if err != nil { panic("failed to decrypt ticket") } stateChan <- state return state, nil } transfer0RTTData(t, ln, clientConn, clientTLSConf, getQuicConfig(nil), PRData) select { case state := <-stateChan: require.Len(t, state.Extra, 2) require.Equal(t, []byte("foobar"), state.Extra[1]) case <-time.After(time.Second): t.Fatal("timed out waiting for session state") } }) }) t.Run("client", func(t *testing.T) { synctest.Test(t, func(t *testing.T) { router := &zeroRTTCountingRouter{Router: &simnet.PerfectRouter{}} clientConn, serverConn, closeFn := newSimnetLinkWithRouter(t, rtt, router) defer closeFn(t) tr := &quic.Transport{Conn: serverConn} defer tr.Close() ln, err := tr.ListenEarly(getTLSConfig(), getQuicConfig(&quic.Config{Allow0RTT: true})) require.NoError(t, err) defer ln.Close() restoreChan := make(chan []byte, 1) clientTLSConf := dialAndReceiveTicket(t, ln, clientConn, &metadataClientSessionCache{ toAdd: []byte("foobar"), restored: func(b []byte) { restoreChan <- b }, cache: tls.NewLRUClientSessionCache(100), }, ) transfer0RTTData(t, ln, clientConn, clientTLSConf, getQuicConfig(nil), PRData) select { case b := <-restoreChan: require.Equal(t, []byte("foobar"), b) case <-time.After(time.Second): t.Fatal("timed out waiting for session state") } }) }) } func Test0RTTPacketQueueing(t *testing.T) { synctest.Test(t, func(t *testing.T) { const rtt = 5 * time.Millisecond n := &simnet.Simnet{Router: &simnet.PerfectRouter{}} serverAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.2"), Port: 9002} settings := simnet.NodeBiDiLinkSettings{ LatencyFunc: func(p simnet.Packet) time.Duration { if p.To.String() == serverAddr.String() { if wire.IsLongHeaderPacket(p.Data[0]) { hdr, _, _, err := 
wire.ParsePacket(p.Data) if err == nil && hdr.Type == protocol.PacketTypeInitial { return rtt * 3 / 2 } } } return rtt / 2 }, } clientConn := n.NewEndpoint(&net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001}, settings) serverConn := n.NewEndpoint(serverAddr, settings) require.NoError(t, n.Start()) defer func() { require.NoError(t, clientConn.Close()) require.NoError(t, serverConn.Close()) require.NoError(t, n.Close()) }() tr := &quic.Transport{Conn: serverConn} defer tr.Close() counter, tracer := newPacketTracer() ln, err := tr.ListenEarly( getTLSConfig(), getQuicConfig(&quic.Config{ Allow0RTT: true, Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return tracer }, }), ) require.NoError(t, err) defer ln.Close() clientTLSConf := dialAndReceiveTicket(t, ln, clientConn, nil) data := GeneratePRData(5000) // ~5 packets transfer0RTTData(t, ln, clientConn, clientTLSConf, getQuicConfig(nil), data) require.Equal(t, qlog.PacketTypeInitial, counter.getRcvdLongHeaderPackets()[0].hdr.PacketType) zeroRTTPackets := counter.getRcvd0RTTPacketNumbers() require.GreaterOrEqual(t, len(zeroRTTPackets), 5) // make sure the data wasn't retransmitted var dataSent protocol.ByteCount for _, p := range counter.getRcvdLongHeaderPackets() { for _, f := range p.frames { if sf, ok := f.Frame.(*qlog.StreamFrame); ok { dataSent += protocol.ByteCount(sf.Length) } } } for _, p := range counter.getRcvdShortHeaderPackets() { for _, f := range p.frames { if sf, ok := f.Frame.(*qlog.StreamFrame); ok { dataSent += protocol.ByteCount(sf.Length) } } } require.Less(t, int(dataSent), 6000) require.Equal(t, protocol.PacketNumber(0), zeroRTTPackets[0]) }) } func Test0RTTDatagrams(t *testing.T) { synctest.Test(t, func(t *testing.T) { const rtt = 5 * time.Millisecond router := &zeroRTTCountingRouter{Router: &simnet.PerfectRouter{}} clientConn, serverConn, closeFn := newSimnetLinkWithRouter(t, rtt, router) defer closeFn(t) tr := &quic.Transport{Conn: serverConn} defer tr.Close() counter, 
tracer := newPacketTracer() ln, err := tr.ListenEarly( getTLSConfig(), getQuicConfig(&quic.Config{ Allow0RTT: true, EnableDatagrams: true, Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return tracer }, }), ) require.NoError(t, err) defer ln.Close() clientTLSConf := dialAndReceiveTicket(t, ln, clientConn, nil) msg := GeneratePRData(100) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := quic.DialEarly(ctx, clientConn, ln.Addr(), clientTLSConf, getQuicConfig(&quic.Config{EnableDatagrams: true}), ) require.NoError(t, err) defer conn.CloseWithError(0, "") require.True(t, conn.ConnectionState().SupportsDatagrams.Remote) require.True(t, conn.ConnectionState().SupportsDatagrams.Local) require.NoError(t, conn.SendDatagram(msg)) select { case <-conn.HandshakeComplete(): case <-time.After(time.Second): t.Fatal("handshake did not complete in time") } sconn, err := ln.Accept(ctx) require.NoError(t, err) rcvdMsg, err := sconn.ReceiveDatagram(ctx) require.NoError(t, err) require.True(t, sconn.ConnectionState().Used0RTT) require.Equal(t, msg, rcvdMsg) num0RTT := router.Num0RTTPackets() t.Logf("sent %d 0-RTT packets", num0RTT) require.NotZero(t, num0RTT) sconn.CloseWithError(0, "") require.Len(t, counter.getRcvd0RTTPacketNumbers(), 1) }) } quic-go-0.59.0/integrationtests/tools/000077500000000000000000000000001513066070600177235ustar00rootroot00000000000000quic-go-0.59.0/integrationtests/tools/crypto.go000066400000000000000000000071341513066070600215770ustar00rootroot00000000000000package tools import ( "crypto" "crypto/ed25519" "crypto/rand" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "math/big" "net" "time" ) const ALPN = "quic-go integration tests" // use a very long validity period to cover the synthetic clock used in synctest var ( notBefore = time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC) notAfter = time.Date(2100, 1, 1, 0, 0, 0, 0, time.UTC) ) func GenerateCA() (*x509.Certificate, crypto.PrivateKey, 
error) { certTempl := &x509.Certificate{ SerialNumber: big.NewInt(2019), Subject: pkix.Name{}, NotBefore: notBefore, NotAfter: notAfter, IsCA: true, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, BasicConstraintsValid: true, } pub, priv, err := ed25519.GenerateKey(rand.Reader) if err != nil { return nil, nil, err } caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, pub, priv) if err != nil { return nil, nil, err } ca, err := x509.ParseCertificate(caBytes) if err != nil { return nil, nil, err } return ca, priv, nil } func GenerateLeafCert(ca *x509.Certificate, caPriv crypto.PrivateKey) (*x509.Certificate, crypto.PrivateKey, error) { certTempl := &x509.Certificate{ SerialNumber: big.NewInt(1), DNSNames: []string{"localhost"}, IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)}, NotBefore: notBefore, NotAfter: notAfter, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, KeyUsage: x509.KeyUsageDigitalSignature, } pub, priv, err := ed25519.GenerateKey(rand.Reader) if err != nil { return nil, nil, err } certBytes, err := x509.CreateCertificate(rand.Reader, certTempl, ca, pub, caPriv) if err != nil { return nil, nil, err } cert, err := x509.ParseCertificate(certBytes) if err != nil { return nil, nil, err } return cert, priv, nil } // GenerateTLSConfigWithLongCertChain generates a tls.Config that uses a long certificate chain. // The Root CA used is the same as for the config returned from getTLSConfig(). 
func GenerateTLSConfigWithLongCertChain(ca *x509.Certificate, caPrivateKey crypto.PrivateKey) (*tls.Config, error) { const chainLen = 16 certTempl := &x509.Certificate{ SerialNumber: big.NewInt(2019), Subject: pkix.Name{}, NotBefore: notBefore, NotAfter: notAfter, IsCA: true, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, BasicConstraintsValid: true, } lastCA := ca lastCAPrivKey := caPrivateKey _, priv, err := ed25519.GenerateKey(rand.Reader) if err != nil { return nil, err } certs := make([]*x509.Certificate, chainLen) for i := range chainLen { caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, lastCA, priv.Public(), lastCAPrivKey) if err != nil { return nil, err } ca, err := x509.ParseCertificate(caBytes) if err != nil { return nil, err } certs[i] = ca lastCA = ca lastCAPrivKey = priv } leafCert, leafPrivateKey, err := GenerateLeafCert(lastCA, lastCAPrivKey) if err != nil { return nil, err } rawCerts := make([][]byte, chainLen+1) for i, cert := range certs { rawCerts[chainLen-i] = cert.Raw } rawCerts[0] = leafCert.Raw return &tls.Config{ Certificates: []tls.Certificate{{ Certificate: rawCerts, PrivateKey: leafPrivateKey, }}, NextProtos: []string{ALPN}, }, nil } quic-go-0.59.0/integrationtests/tools/crypto_test.go000066400000000000000000000042471513066070600226400ustar00rootroot00000000000000package tools import ( "crypto/tls" "crypto/x509" "io" "net" "testing" "github.com/stretchr/testify/require" ) type countingConn struct { net.Conn BytesReceived int } func (c *countingConn) Read(b []byte) (int, error) { n, err := c.Conn.Read(b) c.BytesReceived += n return n, err } func TestGenerateTLSConfig(t *testing.T) { ca, caPriv, err := GenerateCA() require.NoError(t, err) certPool := x509.NewCertPool() certPool.AddCert(ca) clientConf := &tls.Config{ ServerName: "localhost", RootCAs: certPool, } t.Run("short chain", func(t *testing.T) { leaf, leafPriv, 
err := GenerateLeafCert(ca, caPriv) require.NoError(t, err) serverConf := &tls.Config{ Certificates: []tls.Certificate{{ Certificate: [][]byte{leaf.Raw}, PrivateKey: leafPriv, }}, } bytesReceived := testGenerateTLSConfig(t, serverConf, clientConf) t.Logf("bytes received: %d", bytesReceived) require.Less(t, bytesReceived, 2000) }) t.Run("long chain", func(t *testing.T) { serverConf, err := GenerateTLSConfigWithLongCertChain(ca, caPriv) require.NoError(t, err) bytesReceived := testGenerateTLSConfig(t, serverConf, clientConf) t.Logf("bytes received: %d", bytesReceived) require.Greater(t, bytesReceived, 5000) }) } func testGenerateTLSConfig(t *testing.T, serverConf, clientConf *tls.Config) int { ln, err := tls.Listen("tcp", "127.0.0.1:0", serverConf) require.NoError(t, err) defer ln.Close() type result struct { err error msg string } resultChan := make(chan result, 1) go func() { conn, err := ln.Accept() if err != nil { resultChan <- result{err: err} return } defer conn.Close() msg, err := io.ReadAll(conn) resultChan <- result{err: err, msg: string(msg)} }() tcpConn, err := net.Dial("tcp", ln.Addr().String()) require.NoError(t, err) defer tcpConn.Close() countingConn := &countingConn{Conn: tcpConn} tlsConn := tls.Client(countingConn, clientConf) require.NoError(t, tlsConn.Handshake()) _, err = tlsConn.Write([]byte("foobar")) require.NoError(t, err) require.NoError(t, tlsConn.Close()) res := <-resultChan require.NoError(t, res.err) require.Equal(t, "foobar", res.msg) return countingConn.BytesReceived } quic-go-0.59.0/integrationtests/tools/israce/000077500000000000000000000000001513066070600211715ustar00rootroot00000000000000quic-go-0.59.0/integrationtests/tools/israce/norace.go000066400000000000000000000001541513066070600227670ustar00rootroot00000000000000//go:build !race package israce // Enabled reports if the race detector is enabled. 
const Enabled = false quic-go-0.59.0/integrationtests/tools/israce/race.go000066400000000000000000000001521513066070600224300ustar00rootroot00000000000000//go:build race package israce // Enabled reports if the race detector is enabled. const Enabled = true quic-go-0.59.0/integrationtests/tools/proxy/000077500000000000000000000000001513066070600211045ustar00rootroot00000000000000quic-go-0.59.0/integrationtests/tools/proxy/proxy.go000066400000000000000000000223231513066070600226160ustar00rootroot00000000000000package quicproxy import ( "errors" "fmt" "net" "os" "slices" "sync" "time" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" ) // Connection is a UDP connection type connection struct { ClientAddr *net.UDPAddr // Address of the client ServerAddr *net.UDPAddr // Address of the server mx sync.Mutex ServerConn *net.UDPConn // UDP connection to server incomingPackets chan packetEntry Incoming *queue Outgoing *queue } func (c *connection) queuePacket(t monotime.Time, b []byte) { c.incomingPackets <- packetEntry{Time: t, Raw: b} } func (c *connection) SwitchConn(conn *net.UDPConn) { c.mx.Lock() defer c.mx.Unlock() old := c.ServerConn old.SetReadDeadline(time.Now()) c.ServerConn = conn } func (c *connection) GetServerConn() *net.UDPConn { c.mx.Lock() defer c.mx.Unlock() return c.ServerConn } // Direction is the direction a packet is sent. type Direction int const ( // DirectionIncoming is the direction from the client to the server. DirectionIncoming Direction = iota // DirectionOutgoing is the direction from the server to the client. 
DirectionOutgoing // DirectionBoth is both incoming and outgoing DirectionBoth ) type packetEntry struct { Time monotime.Time Raw []byte } type queue struct { sync.Mutex timer *time.Timer Packets []packetEntry // sorted by the packetEntry.Time } func newQueue() *queue { // there's no way to initialize a time.Timer that's not running return &queue{timer: time.NewTimer(24 * time.Hour)} } func (q *queue) Add(e packetEntry) { q.Lock() defer q.Unlock() if len(q.Packets) == 0 { q.Packets = append(q.Packets, e) q.timer.Reset(monotime.Until(e.Time)) return } // The packets slice is sorted by the packetEntry.Time. // We only need to insert the packet at the correct position. idx := slices.IndexFunc(q.Packets, func(p packetEntry) bool { return p.Time.After(e.Time) }) if idx == -1 { q.Packets = append(q.Packets, e) } else { q.Packets = slices.Insert(q.Packets, idx, e) } if idx == 0 { q.timer.Reset(monotime.Until(q.Packets[0].Time)) } } func (q *queue) Get() []byte { q.Lock() raw := q.Packets[0].Raw q.Packets = q.Packets[1:] if len(q.Packets) > 0 { q.timer.Reset(monotime.Until(q.Packets[0].Time)) } q.Unlock() return raw } func (q *queue) Timer() <-chan time.Time { return q.timer.C } func (q *queue) Close() { q.timer.Stop() } func (d Direction) String() string { switch d { case DirectionIncoming: return "Incoming" case DirectionOutgoing: return "Outgoing" case DirectionBoth: return "both" default: panic("unknown direction") } } // Is says if one direction matches another direction. // For example, incoming matches both incoming and both, but not outgoing. func (d Direction) Is(dir Direction) bool { if d == DirectionBoth || dir == DirectionBoth { return true } return d == dir } // DropCallback is a callback that determines which packet gets dropped. type DropCallback func(dir Direction, from, to net.Addr, packet []byte) bool // DelayCallback is a callback that determines how much delay to apply to a packet. 
type DelayCallback func(dir Direction, from, to net.Addr, packet []byte) time.Duration

// Proxy is a QUIC proxy that can drop and delay packets.
type Proxy struct {
	// Conn is the UDP socket that the proxy listens on for incoming packets from clients.
	Conn *net.UDPConn
	// ServerAddr is the address of the server that the proxy forwards packets to.
	ServerAddr *net.UDPAddr
	// DropPacket is a callback that determines which packet gets dropped.
	DropPacket DropCallback
	// DelayPacket is a callback that determines how much delay to apply to a packet.
	DelayPacket DelayCallback

	// closeChan is created by Start and closed by Close.
	closeChan chan struct{}
	// logger is the prefixed debug logger installed by Start.
	logger utils.Logger

	// mutex guards clientDict, which maps client addresses (as host:port)
	// to the per-client connection state.
	mutex      sync.Mutex
	clientDict map[string]*connection
}

// Start initializes the proxy's internal state, sizes the socket buffers of
// Conn and launches the background goroutine that reads packets from clients.
// If either buffer size cannot be set, the error is returned and no goroutine
// is started. Start is expected to be called before the proxy is used.
func (p *Proxy) Start() error {
	p.logger = utils.DefaultLogger.WithPrefix("proxy")
	p.closeChan = make(chan struct{})
	p.clientDict = make(map[string]*connection)

	// The write buffer is only configured once the read buffer succeeded.
	err := p.Conn.SetReadBuffer(protocol.DesiredReceiveBufferSize)
	if err == nil {
		err = p.Conn.SetWriteBuffer(protocol.DesiredSendBufferSize)
	}
	if err != nil {
		return err
	}

	p.logger.Debugf("Starting UDP Proxy %s <-> %s", p.Conn.LocalAddr(), p.ServerAddr)
	// The error that eventually ends the read loop is not observed here.
	go p.runProxy()
	return nil
}

// SwitchConn switches the connection for a client,
// identified by the address that the client is sending from.
func (p *Proxy) SwitchConn(clientAddr *net.UDPAddr, conn *net.UDPConn) error { if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil { return err } if err := conn.SetWriteBuffer(protocol.DesiredSendBufferSize); err != nil { return err } p.mutex.Lock() defer p.mutex.Unlock() c, ok := p.clientDict[clientAddr.String()] if !ok { return fmt.Errorf("client %s not found", clientAddr) } c.SwitchConn(conn) return nil } // Close stops the UDP Proxy func (p *Proxy) Close() error { p.mutex.Lock() defer p.mutex.Unlock() close(p.closeChan) for _, c := range p.clientDict { if err := c.GetServerConn().Close(); err != nil { return err } c.Incoming.Close() c.Outgoing.Close() } return nil } // LocalAddr is the address the proxy is listening on. func (p *Proxy) LocalAddr() net.Addr { return p.Conn.LocalAddr() } func (p *Proxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) { conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) if err != nil { return nil, err } if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil { return nil, err } if err := conn.SetWriteBuffer(protocol.DesiredSendBufferSize); err != nil { return nil, err } return &connection{ ClientAddr: cliAddr, ServerAddr: p.ServerAddr, incomingPackets: make(chan packetEntry, 10), Incoming: newQueue(), Outgoing: newQueue(), ServerConn: conn, }, nil } // runProxy listens on the proxy address and handles incoming packets. 
func (p *Proxy) runProxy() error { for { buffer := make([]byte, protocol.MaxPacketBufferSize) n, cliaddr, err := p.Conn.ReadFromUDP(buffer) if err != nil { return err } raw := buffer[:n] p.mutex.Lock() conn, ok := p.clientDict[cliaddr.String()] if !ok { conn, err = p.newConnection(cliaddr) if err != nil { p.mutex.Unlock() return err } p.clientDict[cliaddr.String()] = conn go p.runIncomingConnection(conn) go p.runOutgoingConnection(conn) } p.mutex.Unlock() if p.DropPacket != nil && p.DropPacket(DirectionIncoming, cliaddr, conn.ServerAddr, raw) { if p.logger.Debug() { p.logger.Debugf("dropping incoming packet(%d bytes)", n) } continue } var delay time.Duration if p.DelayPacket != nil { delay = p.DelayPacket(DirectionIncoming, cliaddr, conn.ServerAddr, raw) } if delay == 0 { if p.logger.Debug() { p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerAddr) } if _, err := conn.GetServerConn().WriteTo(raw, conn.ServerAddr); err != nil { return err } } else { now := monotime.Now() if p.logger.Debug() { p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", len(raw), conn.ServerAddr, delay) } conn.queuePacket(now.Add(delay), raw) } } } // runConnection handles packets from server to a single client func (p *Proxy) runOutgoingConnection(conn *connection) error { outgoingPackets := make(chan packetEntry, 10) go func() { for { buffer := make([]byte, protocol.MaxPacketBufferSize) n, addr, err := conn.GetServerConn().ReadFrom(buffer) if err != nil { // when the connection is switched out, we set a deadline on the old connection, // in order to return it immediately if errors.Is(err, os.ErrDeadlineExceeded) { continue } return } raw := buffer[0:n] if p.DropPacket != nil && p.DropPacket(DirectionOutgoing, addr, conn.ClientAddr, raw) { if p.logger.Debug() { p.logger.Debugf("dropping outgoing packet(%d bytes)", n) } continue } var delay time.Duration if p.DelayPacket != nil { delay = p.DelayPacket(DirectionOutgoing, addr, conn.ClientAddr, 
raw) } if delay == 0 { if p.logger.Debug() { p.logger.Debugf("forwarding outgoing packet (%d bytes) to %s", len(raw), conn.ClientAddr) } if _, err := p.Conn.WriteToUDP(raw, conn.ClientAddr); err != nil { return } } else { now := monotime.Now() if p.logger.Debug() { p.logger.Debugf("delaying outgoing packet (%d bytes) to %s by %s", len(raw), conn.ClientAddr, delay) } outgoingPackets <- packetEntry{Time: now.Add(delay), Raw: raw} } } }() for { select { case <-p.closeChan: return nil case e := <-outgoingPackets: conn.Outgoing.Add(e) case <-conn.Outgoing.Timer(): if _, err := p.Conn.WriteTo(conn.Outgoing.Get(), conn.ClientAddr); err != nil { return err } } } } func (p *Proxy) runIncomingConnection(conn *connection) error { for { select { case <-p.closeChan: return nil case e := <-conn.incomingPackets: // Send the packet to the server conn.Incoming.Add(e) case <-conn.Incoming.Timer(): if _, err := conn.GetServerConn().WriteTo(conn.Incoming.Get(), conn.ServerAddr); err != nil { return err } } } } quic-go-0.59.0/integrationtests/tools/proxy/proxy_test.go000066400000000000000000000345301513066070600236600ustar00rootroot00000000000000package quicproxy import ( "net" "strconv" "sync/atomic" "testing" "time" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/require" ) func TestPacketQueue(t *testing.T) { q := newQueue() getPackets := func() []string { packets := make([]string, 0, len(q.Packets)) for _, p := range q.Packets { packets = append(packets, string(p.Raw)) } return packets } require.Empty(t, getPackets()) now := monotime.Now() q.Add(packetEntry{Time: now, Raw: []byte("p3")}) require.Equal(t, []string{"p3"}, getPackets()) q.Add(packetEntry{Time: now.Add(time.Second), Raw: []byte("p4")}) require.Equal(t, []string{"p3", "p4"}, getPackets()) q.Add(packetEntry{Time: now.Add(-time.Second), Raw: []byte("p1")}) require.Equal(t, []string{"p1", "p3", "p4"}, 
getPackets()) q.Add(packetEntry{Time: now.Add(time.Second), Raw: []byte("p5")}) require.Equal(t, []string{"p1", "p3", "p4", "p5"}, getPackets()) q.Add(packetEntry{Time: now.Add(-time.Second), Raw: []byte("p2")}) require.Equal(t, []string{"p1", "p2", "p3", "p4", "p5"}, getPackets()) } func newUPDConnLocalhost(t testing.TB) *net.UDPConn { t.Helper() conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) require.NoError(t, err) t.Cleanup(func() { conn.Close() }) return conn } func makePacket(t *testing.T, p protocol.PacketNumber, payload []byte) []byte { t.Helper() hdr := wire.ExtendedHeader{ Header: wire.Header{ Type: protocol.PacketTypeInitial, Version: protocol.Version1, Length: 4 + protocol.ByteCount(len(payload)), DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}), SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}), }, PacketNumber: p, PacketNumberLen: protocol.PacketNumberLen4, } b, err := hdr.Append(nil, protocol.Version1) require.NoError(t, err) b = append(b, payload...) return b } func readPacketNumber(t *testing.T, b []byte) protocol.PacketNumber { t.Helper() hdr, data, _, err := wire.ParsePacket(b) require.NoError(t, err) require.Equal(t, protocol.PacketTypeInitial, hdr.Type) extHdr, err := hdr.ParseExtended(data) require.NoError(t, err) return extHdr.PacketNumber } // Set up a dumb UDP server. // In production this would be a QUIC server. 
// It returns the server's address and a channel that receives every datagram
// the server reads; each datagram is also echoed back to its sender.
func runServer(t *testing.T) (*net.UDPAddr, chan []byte) {
	done := make(chan struct{})
	t.Cleanup(func() {
		select {
		case <-done:
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
	})

	serverConn := newUPDConnLocalhost(t)
	serverReceivedPackets := make(chan []byte, 100)
	go func() {
		defer close(done)
		for {
			buf := make([]byte, protocol.MaxPacketBufferSize)
			// the ReadFromUDP will error as soon as the UDP conn is closed
			n, addr, err := serverConn.ReadFromUDP(buf)
			if err != nil {
				return
			}
			serverReceivedPackets <- buf[:n]
			// echo the packet
			if _, err := serverConn.WriteToUDP(buf[:n], addr); err != nil {
				return
			}
		}
	}()

	return serverConn.LocalAddr().(*net.UDPAddr), serverReceivedPackets
}

func TestProxyingBackAndForth(t *testing.T) {
	serverAddr, _ := runServer(t)
	proxy := Proxy{
		Conn:       newUPDConnLocalhost(t),
		ServerAddr: serverAddr,
	}
	require.NoError(t, proxy.Start())
	defer proxy.Close()

	clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
	require.NoError(t, err)
	// send the first packet
	_, err = clientConn.Write(makePacket(t, 1, []byte("foobar")))
	require.NoError(t, err)
	// send the second packet
	_, err = clientConn.Write(makePacket(t, 2, []byte("decafbad")))
	require.NoError(t, err)

	buf := make([]byte, 1024)
	n, err := clientConn.Read(buf)
	require.NoError(t, err)
	require.Contains(t, string(buf[:n]), "foobar")
	n, err = clientConn.Read(buf)
	require.NoError(t, err)
	require.Contains(t, string(buf[:n]), "decafbad")
}

func TestDropIncomingPackets(t *testing.T) {
	const numPackets = 6
	serverAddr, serverReceivedPackets := runServer(t)
	var counter atomic.Int32
	var fromAddr, toAddr atomic.Pointer[net.Addr]
	proxy := Proxy{
		Conn:       newUPDConnLocalhost(t),
		ServerAddr: serverAddr,
		// drop every other incoming packet (the 1st, 3rd, 5th, ...),
		// so exactly half of them reach the server
		DropPacket: func(d Direction, from, to net.Addr, _ []byte) bool {
			if d != DirectionIncoming {
				return false
			}
			fromAddr.Store(&from)
			toAddr.Store(&to)
			return counter.Add(1)%2 == 1
		},
	}
	require.NoError(t, proxy.Start())
	defer proxy.Close()

	clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
	require.NoError(t, err)

	for i := 1; i <= numPackets; i++ {
		_, err := clientConn.Write(makePacket(t, protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
		require.NoError(t, err)
	}
	for i := 0; i < numPackets/2; i++ {
		select {
		case <-serverReceivedPackets:
		case <-time.After(time.Second):
			t.Fatalf("timeout")
		}
	}
	select {
	case <-serverReceivedPackets:
		t.Fatalf("received unexpected packet")
	case <-time.After(100 * time.Millisecond):
	}
	require.Equal(t, *fromAddr.Load(), clientConn.LocalAddr())
	require.Equal(t, *toAddr.Load(), serverAddr)
}

func TestDropOutgoingPackets(t *testing.T) {
	const numPackets = 6
	serverAddr, serverReceivedPackets := runServer(t)
	var counter atomic.Int32
	var fromAddr, toAddr atomic.Pointer[net.Addr]
	proxy := Proxy{
		Conn:       newUPDConnLocalhost(t),
		ServerAddr: serverAddr,
		// drop every other outgoing (server-to-client) packet
		DropPacket: func(d Direction, from, to net.Addr, _ []byte) bool {
			if d != DirectionOutgoing {
				return false
			}
			fromAddr.Store(&from)
			toAddr.Store(&to)
			return counter.Add(1)%2 == 1
		},
	}
	require.NoError(t, proxy.Start())
	defer proxy.Close()

	clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
	require.NoError(t, err)
	clientReceivedPackets := make(chan struct{}, numPackets)
	// receive the packets echoed by the server on client side
	go func() {
		for {
			buf := make([]byte, protocol.MaxPacketBufferSize)
			if _, _, err := clientConn.ReadFromUDP(buf); err != nil {
				return
			}
			clientReceivedPackets <- struct{}{}
		}
	}()

	for i := 1; i <= numPackets; i++ {
		_, err := clientConn.Write(makePacket(t, protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
		require.NoError(t, err)
	}
	for i := 0; i < numPackets/2; i++ {
		select {
		case <-clientReceivedPackets:
		case <-time.After(time.Second):
			t.Fatalf("timeout")
		}
	}
	select {
	case <-clientReceivedPackets:
		t.Fatalf("received unexpected packet")
	case <-time.After(100 * time.Millisecond):
	}
	require.Len(t, serverReceivedPackets, numPackets)
	require.Equal(t, *fromAddr.Load(), serverAddr)
	require.Equal(t, *toAddr.Load(), clientConn.LocalAddr())
}

func TestDelayIncomingPackets(t *testing.T) {
	const numPackets = 3
	const delay = 200 * time.Millisecond
	serverAddr, serverReceivedPackets := runServer(t)
	var counter atomic.Int32
	proxy := Proxy{
		Conn:       newUPDConnLocalhost(t),
		ServerAddr: serverAddr,
		DelayPacket: func(d Direction, _, _ net.Addr, _ []byte) time.Duration {
			// delay packet 1 by 200 ms
			// delay packet 2 by 400 ms
			// ...
			if d == DirectionOutgoing {
				return 0
			}
			p := counter.Add(1)
			return time.Duration(p) * delay
		},
	}
	require.NoError(t, proxy.Start())
	defer proxy.Close()

	clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
	require.NoError(t, err)
	start := time.Now()
	for i := 1; i <= numPackets; i++ {
		_, err := clientConn.Write(makePacket(t, protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
		require.NoError(t, err)
	}

	for i := 1; i <= numPackets; i++ {
		select {
		case data := <-serverReceivedPackets:
			require.WithinDuration(t, start.Add(time.Duration(i)*delay), time.Now(), delay/2)
			require.Equal(t, protocol.PacketNumber(i), readPacketNumber(t, data))
		case <-time.After(time.Second):
			t.Fatalf("timeout waiting for packet %d", i)
		}
	}
}

func TestPacketReordering(t *testing.T) {
	const delay = 200 * time.Millisecond
	// expectDelay asserts that now lies in [startTime + numRTTs*delay, startTime + numRTTs*delay + delay/2).
	expectDelay := func(startTime time.Time, numRTTs int) {
		expectedReceiveTime := startTime.Add(time.Duration(numRTTs) * delay)
		now := time.Now()
		require.True(t, now.After(expectedReceiveTime) || now.Equal(expectedReceiveTime))
		require.True(t, now.Before(expectedReceiveTime.Add(delay/2)))
	}

	serverAddr, serverReceivedPackets := runServer(t)
	var counter atomic.Int32
	proxy := Proxy{
		Conn:       newUPDConnLocalhost(t),
		ServerAddr: serverAddr,
		DelayPacket: func(d Direction, _, _ net.Addr, _ []byte) time.Duration {
			// delay packet 1 by 600 ms
			// delay packet 2 by 400 ms
			// delay packet 3 by 200 ms
			if d == DirectionOutgoing {
				return 0
			}
			p := counter.Add(1)
			return 600*time.Millisecond - time.Duration(p-1)*delay
		},
	}
	require.NoError(t, proxy.Start())
	defer proxy.Close()

	clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
	require.NoError(t, err)

	// send 3 packets
	start := time.Now()
	for i := 1; i <= 3; i++ {
		_, err := clientConn.Write(makePacket(t, protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
		require.NoError(t, err)
	}
	for i := 1; i <= 3; i++ {
		select {
		case packet := <-serverReceivedPackets:
			expectDelay(start, i)
			expectedPacketNumber := protocol.PacketNumber(4 - i) // 3, 2, 1 in reverse order
			require.Equal(t, expectedPacketNumber, readPacketNumber(t, packet))
		case <-time.After(time.Second):
			t.Fatalf("timeout waiting for packet %d", i)
		}
	}
}

func TestConstantDelay(t *testing.T) {
	// no reordering expected here
	serverAddr, serverReceivedPackets := runServer(t)
	proxy := Proxy{
		Conn:       newUPDConnLocalhost(t),
		ServerAddr: serverAddr,
		DelayPacket: func(d Direction, _, _ net.Addr, _ []byte) time.Duration {
			if d == DirectionOutgoing {
				return 0
			}
			return 100 * time.Millisecond
		},
	}
	require.NoError(t, proxy.Start())
	defer proxy.Close()

	clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
	require.NoError(t, err)

	// send 100 packets
	for i := 0; i < 100; i++ {
		_, err := clientConn.Write(makePacket(t, protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
		require.NoError(t, err)
	}

	require.Eventually(t, func() bool { return len(serverReceivedPackets) == 100 }, 5*time.Second, 10*time.Millisecond)
	timeout := time.After(5 * time.Second)
	for i := 0; i < 100; i++ {
		select {
		case packet := <-serverReceivedPackets:
			require.Equal(t, protocol.PacketNumber(i), readPacketNumber(t, packet))
		case <-timeout:
			t.Fatalf("timeout waiting for packet %d", i)
		}
	}
}

func TestDelayOutgoingPackets(t *testing.T) {
	const numPackets = 3
	const delay = 200 * time.Millisecond

	serverAddr, serverReceivedPackets := runServer(t)
	var counter atomic.Int32
	proxy := Proxy{
		Conn:       newUPDConnLocalhost(t),
		ServerAddr: serverAddr,
		DelayPacket: func(d Direction, _, _ net.Addr, _ []byte) time.Duration {
			// delay packet 1 by 200 ms
			// delay packet 2 by 400 ms
			// ...
			if d == DirectionIncoming {
				return 0
			}
			p := counter.Add(1)
			return time.Duration(p) * delay
		},
	}
	require.NoError(t, proxy.Start())
	defer proxy.Close()

	clientConn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
	require.NoError(t, err)
	clientReceivedPackets := make(chan []byte, numPackets)
	// receive the packets echoed by the server on client side
	go func() {
		for {
			buf := make([]byte, protocol.MaxPacketBufferSize)
			n, _, err := clientConn.ReadFromUDP(buf)
			if err != nil {
				return
			}
			clientReceivedPackets <- buf[:n]
		}
	}()

	start := time.Now()
	for i := 1; i <= numPackets; i++ {
		_, err := clientConn.Write(makePacket(t, protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
		require.NoError(t, err)
	}
	// the packets should have arrived immediately at the server
	for i := 0; i < numPackets; i++ {
		select {
		case <-serverReceivedPackets:
		case <-time.After(time.Second):
			t.Fatalf("timeout")
		}
	}
	require.WithinDuration(t, start, time.Now(), delay/2)

	for i := 1; i <= numPackets; i++ {
		select {
		case packet := <-clientReceivedPackets:
			require.Equal(t, protocol.PacketNumber(i), readPacketNumber(t, packet))
			require.WithinDuration(t, start.Add(time.Duration(i)*delay), time.Now(), delay/2)
		case <-time.After(time.Second):
			t.Fatalf("timeout waiting for packet %d", i)
		}
	}
}

func TestProxySwitchConn(t *testing.T) {
	serverConn := newUPDConnLocalhost(t)

	type packet struct {
		Data []byte
		Addr *net.UDPAddr
	}

	serverReceivedPackets := make(chan packet, 1)
	go func() {
		for {
			buf := make([]byte, 1000)
			n, addr, err := serverConn.ReadFromUDP(buf)
			if err != nil {
				return
			}
			serverReceivedPackets <- packet{Data: buf[:n], Addr: addr}
		}
	}()

	proxy := Proxy{
		Conn:       newUPDConnLocalhost(t),
		ServerAddr: serverConn.LocalAddr().(*net.UDPAddr),
	}
	require.NoError(t, proxy.Start())
	defer proxy.Close()

	clientConn := newUPDConnLocalhost(t)
	_, err := clientConn.WriteToUDP([]byte("hello"), proxy.LocalAddr().(*net.UDPAddr))
	require.NoError(t, err)
	clientConn.SetReadDeadline(time.Now().Add(time.Second))

	var firstConnAddr *net.UDPAddr
	select {
	case p := <-serverReceivedPackets:
		require.Equal(t, "hello", string(p.Data))
		require.NotEqual(t, clientConn.LocalAddr(), p.Addr)
		firstConnAddr = p.Addr
	case <-time.After(time.Second):
		t.Fatalf("timeout")
	}

	_, err = serverConn.WriteToUDP([]byte("hi"), firstConnAddr)
	require.NoError(t, err)

	buf := make([]byte, 1000)
	n, addr, err := clientConn.ReadFromUDP(buf)
	require.NoError(t, err)
	require.Equal(t, "hi", string(buf[:n]))
	require.Equal(t, proxy.LocalAddr(), addr)

	newConn := newUPDConnLocalhost(t)
	require.NoError(t, proxy.SwitchConn(clientConn.LocalAddr().(*net.UDPAddr), newConn))

	_, err = clientConn.WriteToUDP([]byte("foobar"), proxy.LocalAddr().(*net.UDPAddr))
	require.NoError(t, err)

	select {
	case p := <-serverReceivedPackets:
		require.Equal(t, "foobar", string(p.Data))
		require.NotEqual(t, clientConn.LocalAddr(), p.Addr)
		require.NotEqual(t, firstConnAddr, p.Addr)
		require.Equal(t, newConn.LocalAddr(), p.Addr)
	case <-time.After(time.Second):
		t.Fatalf("timeout")
	}

	// the old connection doesn't deliver any packets to the client anymore
	_, err = serverConn.WriteTo([]byte("invalid"), firstConnAddr)
	require.NoError(t, err)
	_, err = serverConn.WriteTo([]byte("foobaz"), newConn.LocalAddr())
	require.NoError(t, err)
	n, addr, err = clientConn.ReadFromUDP(buf)
	require.NoError(t, err)
	require.Equal(t, "foobaz", string(buf[:n])) // "invalid" is not delivered
	require.Equal(t, proxy.LocalAddr(), addr)
}
quic-go-0.59.0/integrationtests/tools/qlog.go000066400000000000000000000026631513066070600212230ustar00rootroot00000000000000package tools

import (
	"bufio"
	"context"
	"fmt"
	"io"
	"log"
	"os"
	"time"

	"github.com/quic-go/quic-go"
	h3qlog "github.com/quic-go/quic-go/http3/qlog"
	"github.com/quic-go/quic-go/internal/utils"
	"github.com/quic-go/quic-go/qlog"
	"github.com/quic-go/quic-go/qlogwriter"
)

// QlogTracer creates a qlog file named after the current time and returns a
// trace that writes to it. The file name is announced on logger.
// If the file cannot be created, the process is terminated via log.Fatalf.
func QlogTracer(logger io.Writer) qlogwriter.Trace {
	filename := fmt.Sprintf("log_%s_transport.qlog", time.Now().Format("2006-01-02T15:04:05"))
	fmt.Fprintf(logger, "Creating %s.\n", filename)
	f, err := os.Create(filename)
	if err != nil {
		log.Fatalf("failed to create qlog file: %s", err)
		return nil // unreachable: log.Fatalf exits the process
	}
	bw := bufio.NewWriter(f)
	fileSeq := qlogwriter.NewFileSeq(utils.NewBufferedWriteCloser(bw, f))
	go fileSeq.Run()
	return fileSeq
}

// NewQlogConnectionTracer returns a tracer constructor that writes one qlog file
// per connection, named after the connection ID and the perspective
// ("client" or "server"). File names are announced on logger.
func NewQlogConnectionTracer(logger io.Writer) func(ctx context.Context, isClient bool, connID quic.ConnectionID) qlogwriter.Trace {
	return func(_ context.Context, isClient bool, connID quic.ConnectionID) qlogwriter.Trace {
		pers := "server"
		if isClient {
			pers = "client"
		}
		filename := fmt.Sprintf("log_%s_%s.qlog", connID, pers)
		fmt.Fprintf(logger, "Creating %s.\n", filename)
		f, err := os.Create(filename)
		if err != nil {
			log.Fatalf("failed to create qlog file: %s", err)
			return nil // unreachable: log.Fatalf exits the process
		}
		fileSeq := qlogwriter.NewConnectionFileSeq(
			utils.NewBufferedWriteCloser(bufio.NewWriter(f), f),
			isClient,
			connID,
			[]string{qlog.EventSchema, h3qlog.EventSchema},
		)
		go fileSeq.Run()
		return fileSeq
	}
}
quic-go-0.59.0/integrationtests/versionnegotiation/000077500000000000000000000000001513066070600225115ustar00rootroot00000000000000quic-go-0.59.0/integrationtests/versionnegotiation/handshake_test.go000066400000000000000000000144141513066070600260310ustar00rootroot00000000000000package versionnegotiation

import (
	"context"
	"errors"
	"fmt"
	"net"
	"testing"
	"time"

	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/internal/protocol"
	"github.com/quic-go/quic-go/qlog"
	"github.com/quic-go/quic-go/qlogwriter"
	"github.com/quic-go/quic-go/testutils/events"

	"github.com/stretchr/testify/require"
)

func TestServerSupportsMoreVersionsThanClient(t *testing.T) {
	// Temporarily extend the global list of supported versions; restored on return.
	supportedVersions := append([]quic.Version{}, protocol.SupportedVersions...)
	protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.Version{7, 8, 9, 10}...)
	defer func() { protocol.SupportedVersions = supportedVersions }()

	var serverEventTracer events.Recorder
	serverConfig := &quic.Config{
		Versions: []protocol.Version{7, 8, protocol.SupportedVersions[0], 9},
		Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace {
			return &events.Trace{Recorder: &serverEventTracer}
		},
	}
	server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
	require.NoError(t, err)
	defer server.Close()

	var clientEventTracer events.Recorder
	conn, err := quic.DialAddr(
		context.Background(),
		fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
		getTLSClientConfig(),
		maybeAddQLOGTracer(&quic.Config{Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace {
			return &events.Trace{Recorder: &clientEventTracer}
		}}),
	)
	require.NoError(t, err)

	expectedVersion := protocol.SupportedVersions[0]
	sconn, err := server.Accept(context.Background())
	require.NoError(t, err)
	require.Equal(t, expectedVersion, sconn.ConnectionState().Version)
	require.Equal(t, expectedVersion, conn.ConnectionState().Version)

	require.NoError(t, conn.CloseWithError(0, ""))
	select {
	case <-sconn.Context().Done():
		// Expected behavior
	case <-time.After(5 * time.Second):
		t.Fatal("Timeout waiting for connection to close")
	}

	require.Empty(t, clientEventTracer.Events(qlog.VersionNegotiationReceived{}))
	require.Equal(t,
		[]qlogwriter.Event{
			qlog.VersionInformation{
				ClientVersions: protocol.SupportedVersions,
				ChosenVersion:  expectedVersion,
			},
		},
		clientEventTracer.Events(qlog.VersionInformation{}),
	)
	require.Equal(t,
		[]qlogwriter.Event{
			qlog.VersionInformation{
				ServerVersions: serverConfig.Versions,
				ChosenVersion:  expectedVersion,
			},
		},
		serverEventTracer.Events(qlog.VersionInformation{}),
	)
}

func TestClientSupportsMoreVersionsThanServer(t *testing.T) {
	// Temporarily extend the global list of supported versions; restored on return.
	supportedVersions := append([]quic.Version{}, protocol.SupportedVersions...)
	protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.Version{7, 8, 9, 10}...)
	defer func() { protocol.SupportedVersions = supportedVersions }()

	expectedVersion := protocol.SupportedVersions[0]
	// The server only supports the original set of versions, while the client prefers
	// versions 7, 8 and 9, which the server doesn't speak. The client therefore
	// receives a Version Negotiation packet and falls back to a common version.
	var serverEventTracer events.Recorder
	serverConfig := &quic.Config{
		Versions: supportedVersions,
		Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace {
			return &events.Trace{Recorder: &serverEventTracer}
		},
	}
	server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
	require.NoError(t, err)
	defer server.Close()

	clientVersions := []protocol.Version{7, 8, 9, protocol.SupportedVersions[0], 10}
	var clientEventTracer events.Recorder
	conn, err := quic.DialAddr(
		context.Background(),
		fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
		getTLSClientConfig(),
		maybeAddQLOGTracer(&quic.Config{
			Versions: clientVersions,
			Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace {
				return &events.Trace{Recorder: &clientEventTracer}
			},
		}),
	)
	require.NoError(t, err)

	sconn, err := server.Accept(context.Background())
	require.NoError(t, err)
	require.Equal(t, expectedVersion, sconn.ConnectionState().Version)
	require.Equal(t, protocol.SupportedVersions[0], conn.ConnectionState().Version)

	require.NoError(t, conn.CloseWithError(0, ""))
	select {
	case <-sconn.Context().Done():
		// Expected behavior
	case <-time.After(5 * time.Second):
		t.Fatal("Timeout waiting for connection to close")
	}

	require.Len(t, clientEventTracer.Events(qlog.VersionNegotiationReceived{}), 1)
	supportedVersionInclGreased := clientEventTracer.Events(qlog.VersionNegotiationReceived{})[0].(qlog.VersionNegotiationReceived).SupportedVersions
	require.Equal(t,
		[]qlogwriter.Event{
			qlog.VersionInformation{
				ClientVersions: clientVersions,
				ServerVersions: supportedVersionInclGreased,
				ChosenVersion:  expectedVersion,
			},
		},
		clientEventTracer.Events(qlog.VersionInformation{}),
	)
	require.Equal(t,
		[]qlogwriter.Event{
			qlog.VersionInformation{
				ServerVersions: supportedVersions,
				ChosenVersion:  expectedVersion,
			},
		},
		serverEventTracer.Events(qlog.VersionInformation{}),
	)
}

func TestServerDisablesVersionNegotiation(t *testing.T) {
	// The server only supports Version1 and has Version Negotiation packets disabled.
	// A client that only offers Version2 therefore never receives a Version Negotiation
	// packet, and the dial is expected to fail with a timeout.
	var serverEventTracer events.Recorder
	serverConfig := &quic.Config{
		Versions: []protocol.Version{quic.Version1},
		Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace {
			return &events.Trace{Recorder: &serverEventTracer}
		},
	}
	conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
	require.NoError(t, err)
	tr := &quic.Transport{
		Conn:                             conn,
		DisableVersionNegotiationPackets: true,
	}
	ln, err := tr.Listen(getTLSConfig(), serverConfig)
	require.NoError(t, err)
	defer ln.Close()

	var clientEventTracer events.Recorder
	_, err = quic.DialAddr(
		context.Background(),
		fmt.Sprintf("localhost:%d", conn.LocalAddr().(*net.UDPAddr).Port),
		getTLSClientConfig(),
		maybeAddQLOGTracer(&quic.Config{
			Versions: []protocol.Version{quic.Version2},
			Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace {
				return &events.Trace{Recorder: &clientEventTracer}
			},
			HandshakeIdleTimeout: 100 * time.Millisecond,
		}),
	)
	require.Error(t, err)
	var nerr net.Error
	require.True(t, errors.As(err, &nerr))
	require.True(t, nerr.Timeout())
	require.Empty(t, clientEventTracer.Events(qlog.VersionNegotiationReceived{}))
}
quic-go-0.59.0/integrationtests/versionnegotiation/rtt_test.go000066400000000000000000000031141513066070600247070ustar00rootroot00000000000000package versionnegotiation

import (
	"context"
	"net"
	"testing"
	"time"

	"github.com/quic-go/quic-go"
	quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
	"github.com/quic-go/quic-go/internal/protocol"

	"github.com/stretchr/testify/require"
)

const rtt = 400 * time.Millisecond

// expectDurationInRTTs asserts that the time elapsed since startTime is at least
// num RTTs, but less than num+1 RTTs.
func expectDurationInRTTs(t *testing.T, startTime time.Time, num int) {
	t.Helper()
	testDuration := time.Since(startTime)
	rtts := float32(testDuration) / float32(rtt)
	require.GreaterOrEqual(t, rtts, float32(num))
	require.Less(t, rtts, float32(num+1))
}

func TestVersionNegotiationFailure(t *testing.T) {
	if len(protocol.SupportedVersions) == 1 {
		t.Fatal("Test requires at least 2 supported versions.")
	}

	// The server and the client each support a single, different version,
	// so the handshake must fail after one round trip through the delaying proxy.
	serverConfig := &quic.Config{}
	serverConfig.Versions = protocol.SupportedVersions[:1]
	ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
	require.NoError(t, err)
	defer ln.Close()

	proxyConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
	require.NoError(t, err)
	defer proxyConn.Close()

	// start the proxy
	proxy := quicproxy.Proxy{
		Conn:        proxyConn,
		ServerAddr:  ln.Addr().(*net.UDPAddr),
		DelayPacket: func(quicproxy.Direction, net.Addr, net.Addr, []byte) time.Duration { return rtt / 2 },
	}
	require.NoError(t, proxy.Start())
	defer proxy.Close()

	startTime := time.Now()
	_, err = quic.DialAddr(
		context.Background(),
		proxy.LocalAddr().String(),
		getTLSClientConfig(),
		maybeAddQLOGTracer(&quic.Config{Versions: protocol.SupportedVersions[1:2]}),
	)
	require.Error(t, err)
	expectDurationInRTTs(t, startTime, 1)
}
quic-go-0.59.0/integrationtests/versionnegotiation/test_helper_test.go000066400000000000000000000047041513066070600264220ustar00rootroot00000000000000package versionnegotiation

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"flag"
	"os"

	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/integrationtests/tools"
	"github.com/quic-go/quic-go/qlogwriter"
)

var (
	enableQlog bool

	tlsConfig       *tls.Config
	tlsClientConfig *tls.Config
)

// init registers the -qlog flag and builds the TLS configurations used by the tests:
// the server presents a leaf certificate signed by a freshly generated CA,
// which the client trusts as its only root.
func init() {
	flag.BoolVar(&enableQlog, "qlog", false, "enable qlog")

	ca, caPrivateKey, err := tools.GenerateCA()
	if err != nil {
		panic(err)
	}
	leafCert, leafPrivateKey, err := tools.GenerateLeafCert(ca, caPrivateKey)
	if err != nil {
		panic(err)
	}
	tlsConfig = &tls.Config{
		Certificates: []tls.Certificate{{
			Certificate: [][]byte{leafCert.Raw},
			PrivateKey:  leafPrivateKey,
		}},
		NextProtos: []string{tools.ALPN},
	}

	root := x509.NewCertPool()
	root.AddCert(ca)
	tlsClientConfig = &tls.Config{
		ServerName: "localhost",
		RootCAs:    root,
		NextProtos: []string{tools.ALPN},
	}
}

func getTLSConfig() *tls.Config       { return tlsConfig }
func getTLSClientConfig() *tls.Config { return tlsClientConfig }

// multiplexedRecorder forwards every recorded event to all of its Recorders.
type multiplexedRecorder struct {
	Recorders []qlogwriter.Recorder
}

var _ qlogwriter.Recorder = &multiplexedRecorder{}

// Close closes all underlying recorders. Their errors are ignored.
func (r *multiplexedRecorder) Close() error {
	for _, recorder := range r.Recorders {
		recorder.Close()
	}
	return nil
}

func (r *multiplexedRecorder) RecordEvent(ev qlogwriter.Event) {
	for _, recorder := range r.Recorders {
		recorder.RecordEvent(ev)
	}
}

// multiplexedTrace combines several traces into one.
// It reports support for every schema.
type multiplexedTrace struct {
	Traces []qlogwriter.Trace
}

var _ qlogwriter.Trace = &multiplexedTrace{}

func (t *multiplexedTrace) SupportsSchemas(schema string) bool { return true }

// AddProducer adds a producer to every underlying trace and returns a recorder
// that fans events out to all of them.
func (t *multiplexedTrace) AddProducer() qlogwriter.Recorder {
	recorders := make([]qlogwriter.Recorder, 0, len(t.Traces))
	for _, tr := range t.Traces {
		recorders = append(recorders, tr.AddProducer())
	}
	return &multiplexedRecorder{Recorders: recorders}
}

// maybeAddQLOGTracer returns c, or a new Config if c is nil.
// When the -qlog flag is set, the config's Tracer additionally writes qlog files
// to disk. If c already had a Tracer, both are combined via multiplexedTrace.
func maybeAddQLOGTracer(c *quic.Config) *quic.Config {
	if c == nil {
		c = &quic.Config{}
	}
	if !enableQlog {
		return c
	}
	qlogger := tools.NewQlogConnectionTracer(os.Stdout)
	if c.Tracer == nil {
		c.Tracer = qlogger
	} else if qlogger != nil {
		origTracer := c.Tracer
		c.Tracer = func(ctx context.Context, p bool, connID quic.ConnectionID) qlogwriter.Trace {
			var traces []qlogwriter.Trace
			if origTracer != nil {
				traces = append(traces, origTracer(ctx, p, connID))
			}
			if qlogger != nil {
				traces = append(traces, qlogger(ctx, p, connID))
			}
			return &multiplexedTrace{Traces: traces}
		}
	}
	return c
}
quic-go-0.59.0/interface.go000066400000000000000000000237451513066070600154570ustar00rootroot00000000000000package quic

import (
	"context"
	"crypto/tls"
	"errors"
	"net"
	"slices"
	"time"

	"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/qlogwriter" ) // The StreamID is the ID of a QUIC stream. type StreamID = protocol.StreamID // A Version is a QUIC version number. type Version = protocol.Version const ( // Version1 is RFC 9000 Version1 = protocol.Version1 // Version2 is RFC 9369 Version2 = protocol.Version2 ) // SupportedVersions returns the support versions, sorted in descending order of preference. func SupportedVersions() []Version { // clone the slice to prevent the caller from modifying the slice return slices.Clone(protocol.SupportedVersions) } // A ClientToken is a token received by the client. // It can be used to skip address validation on future connection attempts. type ClientToken struct { data []byte rtt time.Duration } type TokenStore interface { // Pop searches for a ClientToken associated with the given key. // Since tokens are not supposed to be reused, it must remove the token from the cache. // It returns nil when no token is found. Pop(key string) (token *ClientToken) // Put adds a token to the cache with the given key. It might get called // multiple times in a connection. Put(key string, token *ClientToken) } // Err0RTTRejected is the returned from: // - Open{Uni}Stream{Sync} // - Accept{Uni}Stream // - Stream.Read and Stream.Write // // when the server rejects a 0-RTT connection attempt. var Err0RTTRejected = errors.New("0-RTT rejected") // QUICVersionContextKey can be used to find out the QUIC version of a TLS handshake from the // context returned by tls.Config.ClientInfo.Context. var QUICVersionContextKey = handshake.QUICVersionContextKey // StatelessResetKey is a key used to derive stateless reset tokens. type StatelessResetKey [32]byte // TokenGeneratorKey is a key used to encrypt session resumption tokens. type TokenGeneratorKey = handshake.TokenProtectorKey // A ConnectionID is a QUIC Connection ID, as defined in RFC 9000. 
// It is not able to handle QUIC Connection IDs longer than 20 bytes,
// even though RFC 8999 allows them.
type ConnectionID = protocol.ConnectionID

// ConnectionIDFromBytes interprets b as a [ConnectionID]. It panics if b is
// longer than 20 bytes.
func ConnectionIDFromBytes(b []byte) ConnectionID {
	return protocol.ParseConnectionID(b)
}

// A ConnectionIDGenerator allows the application to take control over the generation of Connection IDs.
// Connection IDs generated by an implementation must be of constant length.
type ConnectionIDGenerator interface {
	// GenerateConnectionID generates a new Connection ID.
	// Generated Connection IDs must be unique and observers should not be able to correlate two Connection IDs.
	GenerateConnectionID() (ConnectionID, error)

	// ConnectionIDLen returns the length of Connection IDs generated by this implementation.
	// Implementations must return constant-length Connection IDs with lengths between 0 and 20 bytes.
	// A length of 0 can only be used when an endpoint doesn't need to multiplex connections during migration.
	ConnectionIDLen() int
}

// Config contains all configuration data needed for a QUIC server or client.
type Config struct {
	// GetConfigForClient is called for incoming connections.
	// If the error is not nil, the connection attempt is refused.
	GetConfigForClient func(info *ClientInfo) (*Config, error)
	// The QUIC versions that can be negotiated.
	// If not set, it uses all versions available.
	Versions []Version
	// HandshakeIdleTimeout is the idle timeout before completion of the handshake.
	// If we don't receive any packet from the peer within this time, the connection attempt is aborted.
	// Additionally, if the handshake doesn't complete in twice this time, the connection attempt is also aborted.
	// If this value is zero, the timeout is set to 5 seconds.
	HandshakeIdleTimeout time.Duration
	// MaxIdleTimeout is the maximum duration that may pass without any incoming network activity.
	// The actual value for the idle timeout is the minimum of this value and the peer's.
	// This value only applies after the handshake has completed.
	// If the timeout is exceeded, the connection is closed.
	// If this value is zero, the timeout is set to 30 seconds.
	MaxIdleTimeout time.Duration
	// The TokenStore stores tokens received from the server.
	// Tokens are used to skip address validation on future connection attempts.
	// The key used to store tokens is the ServerName from the tls.Config, if set;
	// otherwise the token is associated with the server's IP address.
	TokenStore TokenStore
	// InitialStreamReceiveWindow is the initial size of the stream-level flow control window for receiving data.
	// If the application is consuming data quickly enough, the flow control auto-tuning algorithm
	// will increase the window up to MaxStreamReceiveWindow.
	// If this value is zero, it will default to 512 KB.
	// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
	InitialStreamReceiveWindow uint64
	// MaxStreamReceiveWindow is the maximum stream-level flow control window for receiving data.
	// If this value is zero, it will default to 6 MB.
	// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
	MaxStreamReceiveWindow uint64
	// InitialConnectionReceiveWindow is the initial size of the connection-level flow control window for receiving data.
	// If the application is consuming data quickly enough, the flow control auto-tuning algorithm
	// will increase the window up to MaxConnectionReceiveWindow.
	// If this value is zero, it will default to 512 KB.
	// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
	InitialConnectionReceiveWindow uint64
	// MaxConnectionReceiveWindow is the connection-level flow control window for receiving data.
	// If this value is zero, it will default to 15 MB.
	// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
	MaxConnectionReceiveWindow uint64
	// AllowConnectionWindowIncrease is called every time the connection flow controller attempts
	// to increase the connection flow control window.
	// If set, the caller can prevent an increase of the window. Typically, it would do so to
	// limit the memory usage.
	// To avoid deadlocks, it is not valid to call other functions on the connection or on streams
	// in this callback.
	AllowConnectionWindowIncrease func(conn *Conn, delta uint64) bool
	// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
	// If not set, it will default to 100.
	// If set to a negative value, it doesn't allow any bidirectional streams.
	// Values larger than 2^60 will be clipped to that value.
	MaxIncomingStreams int64
	// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
	// If not set, it will default to 100.
	// If set to a negative value, it doesn't allow any unidirectional streams.
	// Values larger than 2^60 will be clipped to that value.
	MaxIncomingUniStreams int64
	// KeepAlivePeriod defines whether this peer will periodically send a packet to keep the connection alive.
	// If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most
	// every half of MaxIdleTimeout, whichever is smaller).
	KeepAlivePeriod time.Duration
	// InitialPacketSize is the initial size (and the lower limit) for packets sent.
	// Under most circumstances, it is not necessary to manually set this value,
	// since path MTU discovery quickly finds the path's MTU.
	// If set too high, the path might not support packets of that size, leading to a timeout of the QUIC handshake.
	// Values below 1200 are invalid.
	InitialPacketSize uint16
	// DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899).
	// This allows the sending of QUIC packets that fully utilize the available MTU of the path.
	// Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit.
	DisablePathMTUDiscovery bool
	// Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted.
	// Only valid for the server.
	Allow0RTT bool
	// Enable QUIC datagram support (RFC 9221).
	EnableDatagrams bool
	// Enable QUIC Stream Resets with Partial Delivery.
	// See https://datatracker.ietf.org/doc/html/draft-ietf-quic-reliable-stream-reset-07.
	EnableStreamResetPartialDelivery bool
	// Tracer is an optional callback returning the qlogwriter.Trace used for a connection.
	// It receives a context, a flag telling whether the local endpoint is the client,
	// and a ConnectionID.
	// NOTE(review): which connection ID is passed, and whether a nil return disables
	// tracing, is decided by the caller — confirm before relying on it.
	Tracer func(ctx context.Context, isClient bool, connID ConnectionID) qlogwriter.Trace
}

// ClientInfo contains information about an incoming connection attempt.
type ClientInfo struct {
	// RemoteAddr is the remote address on the Initial packet.
	// Unless AddrVerified is set, the address is not yet verified, and could be a spoofed IP address.
	RemoteAddr net.Addr
	// AddrVerified says if the remote address was verified using QUIC's Retry mechanism.
	// Note that the Retry mechanism costs one network roundtrip,
	// and is not performed unless Transport.MaxUnvalidatedHandshakes is surpassed.
	AddrVerified bool
}

// ConnectionState records basic details about a QUIC connection.
type ConnectionState struct {
	// TLS contains information about the TLS connection state, incl. the tls.ConnectionState.
	TLS tls.ConnectionState
	// SupportsDatagrams indicates support for QUIC datagrams (RFC 9221).
	SupportsDatagrams struct {
		// Remote is true if the peer advertised datagram support.
		// Local is true if datagram support was enabled via Config.EnableDatagrams.
		Remote, Local bool
	}
	// SupportsStreamResetPartialDelivery indicates support for QUIC Stream Resets with Partial Delivery.
	SupportsStreamResetPartialDelivery struct {
		// Remote is true if the peer advertised support.
		// Local is true if support was enabled via Config.EnableStreamResetPartialDelivery.
		Remote, Local bool
	}
	// Used0RTT says if 0-RTT resumption was used.
	Used0RTT bool
	// Version is the QUIC version of the QUIC connection.
	Version Version
	// GSO says if generic segmentation offload is used.
	GSO bool
}
quic-go-0.59.0/internal/000077500000000000000000000000001513066070600147715ustar00rootroot00000000000000quic-go-0.59.0/internal/ackhandler/000077500000000000000000000000001513066070600170655ustar00rootroot00000000000000quic-go-0.59.0/internal/ackhandler/ack_eliciting.go000066400000000000000000000016121513066070600222010ustar00rootroot00000000000000package ackhandler

import "github.com/quic-go/quic-go/internal/wire"

// IsFrameTypeAckEliciting returns true if frames of type t are ack-eliciting.
// ACK / ACK_ECN and CONNECTION_CLOSE (transport and application variants) are not;
// every other frame type is treated as ack-eliciting.
func IsFrameTypeAckEliciting(t wire.FrameType) bool {
	//nolint:exhaustive // The default case catches the rest.
	switch t {
	case wire.FrameTypeAck, wire.FrameTypeAckECN:
		return false
	case wire.FrameTypeConnectionClose, wire.FrameTypeApplicationClose:
		return false
	default:
		return true
	}
}

// IsFrameAckEliciting returns true if the frame is ack-eliciting,
// i.e. if it is neither a *wire.AckFrame nor a *wire.ConnectionCloseFrame.
func IsFrameAckEliciting(f wire.Frame) bool {
	_, isAck := f.(*wire.AckFrame)
	_, isConnectionClose := f.(*wire.ConnectionCloseFrame)
	return !isAck && !isConnectionClose
}

// HasAckElicitingFrames returns true if at least one frame is ack-eliciting.
// It stops scanning at the first ack-eliciting frame.
func HasAckElicitingFrames(fs []Frame) bool {
	for _, f := range fs {
		if IsFrameAckEliciting(f.Frame) {
			return true
		}
	}
	return false
}
quic-go-0.59.0/internal/ackhandler/ack_eliciting_test.go000066400000000000000000000047671513066070600232560ustar00rootroot00000000000000package ackhandler

import (
	"testing"

	"github.com/quic-go/quic-go/internal/wire"
	"github.com/stretchr/testify/require"
)

// TestIsFrameTypeAckEliciting checks the ack-eliciting classification of every listed frame type.
func TestIsFrameTypeAckEliciting(t *testing.T) {
	testCases := map[wire.FrameType]bool{
		wire.FrameTypePing:               true,
		wire.FrameTypeAck:                false,
		wire.FrameTypeAckECN:             false,
		wire.FrameTypeResetStream:        true,
		wire.FrameTypeStopSending:        true,
		wire.FrameTypeCrypto:             true,
		wire.FrameTypeNewToken:           true,
		wire.FrameType(0x08):             true,
		wire.FrameType(0x09):             true,
		wire.FrameType(0x0a):             true,
		wire.FrameType(0x0b):             true,
		wire.FrameType(0x0c):             true,
		wire.FrameType(0x0d):             true,
		wire.FrameType(0x0e):             true,
		wire.FrameType(0x0f):             true,
		wire.FrameTypeMaxData:            true,
		wire.FrameTypeMaxStreamData:      true,
		wire.FrameTypeBidiMaxStreams:     true,
		wire.FrameTypeUniMaxStreams:      true,
		wire.FrameTypeDataBlocked:        true,
		wire.FrameTypeStreamDataBlocked:  true,
		wire.FrameTypeBidiStreamBlocked:  true,
		wire.FrameTypeUniStreamBlocked:   true,
		wire.FrameTypeNewConnectionID:    true,
		wire.FrameTypeRetireConnectionID: true,
		wire.FrameTypePathChallenge:      true,
		wire.FrameTypePathResponse:       true,
		wire.FrameTypeConnectionClose:    false,
		wire.FrameTypeApplicationClose:   false,
		wire.FrameTypeHandshakeDone:      true,
		wire.FrameTypeResetStreamAt:      true,
		wire.FrameTypeDatagramNoLength:   true,
		wire.FrameTypeDatagramWithLength: true,
		wire.FrameTypeAckFrequency:       true,
		wire.FrameTypeImmediateAck:       true,
	}

	for ft, expected := range testCases {
		require.Equal(t, expected, IsFrameTypeAckEliciting(ft), "unexpected result for frame type 0x%x", ft)
	}
}

// TestAckElicitingFrames checks IsFrameAckEliciting and HasAckElicitingFrames on concrete frames.
func TestAckElicitingFrames(t *testing.T) {
	testCases := map[wire.Frame]bool{
		&wire.AckFrame{}:             false,
		&wire.ConnectionCloseFrame{}: false,
		&wire.DataBlockedFrame{}:     true,
		&wire.PingFrame{}:            true,
		&wire.ResetStreamFrame{}:     true,
		&wire.StreamFrame{}:          true,
		&wire.DatagramFrame{}:        true,
		&wire.MaxDataFrame{}:         true,
		&wire.MaxStreamDataFrame{}:   true,
		&wire.StopSendingFrame{}:     true,
		&wire.AckFrequencyFrame{}:    true,
		&wire.ImmediateAckFrame{}:    true,
	}

	for f, expected := range testCases {
		require.Equal(t, expected, IsFrameAckEliciting(f))
		// a single-frame slice must be classified the same way as the frame itself
		require.Equal(t, expected, HasAckElicitingFrames([]Frame{{Frame: f}}))
	}
}
quic-go-0.59.0/internal/ackhandler/ecn.go000066400000000000000000000260521513066070600201660ustar00rootroot00000000000000package ackhandler

import (
	"fmt"

	"github.com/quic-go/quic-go/internal/protocol"
	"github.com/quic-go/quic-go/internal/utils"
	"github.com/quic-go/quic-go/qlog"
	"github.com/quic-go/quic-go/qlogwriter"
)

// ecnState is the state of the ECN validation state machine.
// Transitions: initial -> testing (first call to Mode) -> unknown (all testing packets sent)
// -> capable / failed. capable can also be reached directly from testing.
type ecnState uint8

const (
	ecnStateInitial ecnState = iota
	ecnStateTesting
	ecnStateUnknown
	ecnStateCapable
	ecnStateFailed
)

// Triggers recorded in the qlog ECNStateUpdated event when validation fails.
const (
	// ecnFailedNoECNCounts is emitted when an ACK acknowledges ECN-marked packets,
	// but doesn't contain any ECN counts
	ecnFailedNoECNCounts = "ACK doesn't contain ECN marks"
	// ecnFailedDecreasedECNCounts is emitted when an ACK frame decreases ECN counts
	ecnFailedDecreasedECNCounts = "ACK decreases ECN counts"
	// ecnFailedLostAllTestingPackets is emitted when all ECN testing packets are declared lost
	ecnFailedLostAllTestingPackets = "all ECN testing packets declared lost"
	// ecnFailedMoreECNCountsThanSent is emitted when an ACK contains more ECN counts than ECN-marked packets were sent
	ecnFailedMoreECNCountsThanSent = "ACK contains more ECN counts than ECN-marked packets sent"
	// ecnFailedTooFewECNCounts is emitted when an ACK contains fewer ECN counts than it acknowledges packets
	ecnFailedTooFewECNCounts = "ACK contains fewer new ECN counts than acknowledged ECN-marked packets"
	// ecnFailedManglingDetected is emitted when the path marks all ECN-marked packets as CE
	ecnFailedManglingDetected = "ECN mangling detected"
)

// numECNTestingPackets is the number of ECN-marked packets sent during the testing phase.
// It must fit into a uint8, otherwise numSentTesting and numLostTesting must have a larger type.
const numECNTestingPackets = 10

// ecnHandler is the interface implemented by ecnTracker.
type ecnHandler interface {
	SentPacket(protocol.PacketNumber, protocol.ECN)
	Mode() protocol.ECN
	HandleNewlyAcked(packets []packetWithPacketNumber, ect0, ect1, ecnce int64) (congested bool)
	LostPacket(protocol.PacketNumber)
}

// The ecnTracker performs ECN validation of a path.
// Once failed, it doesn't do any re-validation of the path.
// It is designed to only work for 1-RTT packets; it doesn't handle multiple packet number spaces.
// In order to avoid revealing any internal state to on-path observers,
// callers should make sure to start using ECN (i.e. calling Mode) for the very first 1-RTT packet sent.
// The validation logic implemented here strictly follows the algorithm described in RFC 9000 section 13.4.2 and A.4.
type ecnTracker struct {
	state                          ecnState
	numSentTesting, numLostTesting uint8

	// Packet-number boundaries of the testing and capable phases;
	// protocol.InvalidPacketNumber until the respective phase has started.
	firstTestingPacket protocol.PacketNumber
	lastTestingPacket  protocol.PacketNumber
	firstCapablePacket protocol.PacketNumber

	numSentECT0, numSentECT1                  int64
	numAckedECT0, numAckedECT1, numAckedECNCE int64

	qlogger qlogwriter.Recorder
	logger  utils.Logger
}

var _ ecnHandler = &ecnTracker{}

// newECNTracker returns a tracker in the initial state. qlogger may be nil.
func newECNTracker(logger utils.Logger, qlogger qlogwriter.Recorder) *ecnTracker {
	return &ecnTracker{
		firstTestingPacket: protocol.InvalidPacketNumber,
		lastTestingPacket:  protocol.InvalidPacketNumber,
		firstCapablePacket: protocol.InvalidPacketNumber,
		state:              ecnStateInitial,
		logger:             logger,
		qlogger:            qlogger,
	}
}

// SentPacket records that packet pn was sent with the ECN marking ecn.
// Packets sent without ECN marks are ignored. Once numECNTestingPackets ECT-marked
// packets have been sent in the testing state, the tracker moves to the unknown state.
func (e *ecnTracker) SentPacket(pn protocol.PacketNumber, ecn protocol.ECN) {
	//nolint:exhaustive // These are the only ones we need to take care of.
	switch ecn {
	case protocol.ECNNon:
		return
	case protocol.ECT0:
		e.numSentECT0++
	case protocol.ECT1:
		e.numSentECT1++
	case protocol.ECNUnsupported:
		if e.state != ecnStateFailed {
			panic("didn't expect ECN to be unsupported")
		}
	default:
		panic(fmt.Sprintf("sent packet with unexpected ECN marking: %s", ecn))
	}

	// remember where the capable phase starts, so ecnMarking can attribute ACKs correctly
	if e.state == ecnStateCapable && e.firstCapablePacket == protocol.InvalidPacketNumber {
		e.firstCapablePacket = pn
	}

	if e.state != ecnStateTesting {
		return
	}

	e.numSentTesting++
	if e.firstTestingPacket == protocol.InvalidPacketNumber {
		e.firstTestingPacket = pn
	}
	if e.numSentECT0+e.numSentECT1 >= numECNTestingPackets {
		if e.qlogger != nil {
			e.qlogger.RecordEvent(qlog.ECNStateUpdated{
				State: qlog.ECNStateUnknown,
			})
		}
		e.state = ecnStateUnknown
		// this packet was itself ECN-marked: it is the last (inclusive) testing packet
		e.lastTestingPacket = pn
	}
}

// Mode returns the ECN marking to use for the next packet.
// The first call moves the tracker from the initial to the testing state.
func (e *ecnTracker) Mode() protocol.ECN {
	switch e.state {
	case ecnStateInitial:
		if e.qlogger != nil {
			e.qlogger.RecordEvent(qlog.ECNStateUpdated{
				State: qlog.ECNStateTesting,
			})
		}
		e.state = ecnStateTesting
		return e.Mode()
	case ecnStateTesting, ecnStateCapable:
		return protocol.ECT0
	case ecnStateUnknown, ecnStateFailed:
		return protocol.ECNNon
	default:
		panic(fmt.Sprintf("unknown ECN state: %d", e.state))
	}
}

// LostPacket records that packet pn was declared lost.
// Only losses of testing packets, while validation is still in progress, are relevant.
func (e *ecnTracker) LostPacket(pn protocol.PacketNumber) {
	if e.state != ecnStateTesting && e.state != ecnStateUnknown {
		return
	}
	if !e.isTestingPacket(pn) {
		return
	}
	e.numLostTesting++
	// Only proceed if we have sent all 10 testing packets.
	if e.state != ecnStateUnknown {
		return
	}
	if e.numLostTesting >= e.numSentTesting {
		e.logger.Debugf("Disabling ECN. All testing packets were lost.")
		if e.qlogger != nil {
			e.qlogger.RecordEvent(qlog.ECNStateUpdated{
				State:   qlog.ECNStateFailed,
				Trigger: ecnFailedLostAllTestingPackets,
			})
		}
		e.state = ecnStateFailed
		return
	}
	// Path validation also fails if some testing packets are lost, and all other testing packets were CE-marked
	e.failIfMangled()
}

// HandleNewlyAcked handles the ECN counts on an ACK frame.
// It must only be called for ACK frames that increase the largest acknowledged packet number, // see section 13.4.2.1 of RFC 9000. func (e *ecnTracker) HandleNewlyAcked(packets []packetWithPacketNumber, ect0, ect1, ecnce int64) (congested bool) { if e.state == ecnStateFailed { return false } // ECN validation can fail if the received total count for either ECT(0) or ECT(1) exceeds // the total number of packets sent with each corresponding ECT codepoint. if ect0 > e.numSentECT0 || ect1 > e.numSentECT1 { e.logger.Debugf("Disabling ECN. Received more ECT(0) / ECT(1) acknowledgements than packets sent.") if e.qlogger != nil { e.qlogger.RecordEvent(qlog.ECNStateUpdated{ State: qlog.ECNStateFailed, Trigger: ecnFailedMoreECNCountsThanSent, }) } e.state = ecnStateFailed return false } // Count ECT0 and ECT1 marks that we used when sending the packets that are now being acknowledged. var ackedECT0, ackedECT1 int64 for _, p := range packets { //nolint:exhaustive // We only ever send ECT(0) and ECT(1). switch e.ecnMarking(p.PacketNumber) { case protocol.ECT0: ackedECT0++ case protocol.ECT1: ackedECT1++ } } // If an ACK frame newly acknowledges a packet that the endpoint sent with either the ECT(0) or ECT(1) // codepoint set, ECN validation fails if the corresponding ECN counts are not present in the ACK frame. // This check detects: // * paths that bleach all ECN marks, and // * peers that don't report any ECN counts if (ackedECT0 > 0 || ackedECT1 > 0) && ect0 == 0 && ect1 == 0 && ecnce == 0 { e.logger.Debugf("Disabling ECN. ECN-marked packet acknowledged, but no ECN counts on ACK frame.") if e.qlogger != nil { e.qlogger.RecordEvent(qlog.ECNStateUpdated{ State: qlog.ECNStateFailed, Trigger: ecnFailedNoECNCounts, }) } e.state = ecnStateFailed return false } // Determine the increase in ECT0, ECT1 and ECNCE marks newECT0 := ect0 - e.numAckedECT0 newECT1 := ect1 - e.numAckedECT1 newECNCE := ecnce - e.numAckedECNCE // We're only processing ACKs that increase the Largest Acked. 
// Therefore, the ECN counters should only ever increase. // Any decrease means that the peer's counting logic is broken. if newECT0 < 0 || newECT1 < 0 || newECNCE < 0 { e.logger.Debugf("Disabling ECN. ECN counts decreased unexpectedly.") if e.qlogger != nil { e.qlogger.RecordEvent(qlog.ECNStateUpdated{ State: qlog.ECNStateFailed, Trigger: ecnFailedDecreasedECNCounts, }) } e.state = ecnStateFailed return false } // ECN validation also fails if the sum of the increase in ECT(0) and ECN-CE counts is less than the number // of newly acknowledged packets that were originally sent with an ECT(0) marking. // This could be the result of (partial) bleaching. if newECT0+newECNCE < ackedECT0 { e.logger.Debugf("Disabling ECN. Received less ECT(0) + ECN-CE than packets sent with ECT(0).") if e.qlogger != nil { e.qlogger.RecordEvent(qlog.ECNStateUpdated{ State: qlog.ECNStateFailed, Trigger: ecnFailedTooFewECNCounts, }) } e.state = ecnStateFailed return false } // Similarly, ECN validation fails if the sum of the increases to ECT(1) and ECN-CE counts is less than // the number of newly acknowledged packets sent with an ECT(1) marking. if newECT1+newECNCE < ackedECT1 { e.logger.Debugf("Disabling ECN. Received less ECT(1) + ECN-CE than packets sent with ECT(1).") if e.qlogger != nil { e.qlogger.RecordEvent(qlog.ECNStateUpdated{ State: qlog.ECNStateFailed, Trigger: ecnFailedTooFewECNCounts, }) } e.state = ecnStateFailed return false } // update our counters e.numAckedECT0 = ect0 e.numAckedECT1 = ect1 e.numAckedECNCE = ecnce // Detect mangling (a path remarking all ECN-marked testing packets as CE), // once all 10 testing packets have been sent out. 
if e.state == ecnStateUnknown { e.failIfMangled() if e.state == ecnStateFailed { return false } } if e.state == ecnStateTesting || e.state == ecnStateUnknown { var ackedTestingPacket bool for _, p := range packets { if e.isTestingPacket(p.PacketNumber) { ackedTestingPacket = true break } } // This check won't succeed if the path is mangling ECN-marks (i.e. rewrites all ECN-marked packets to CE). if ackedTestingPacket && (newECT0 > 0 || newECT1 > 0) { e.logger.Debugf("ECN capability confirmed.") if e.qlogger != nil { e.qlogger.RecordEvent(qlog.ECNStateUpdated{ State: qlog.ECNStateCapable, }) } e.state = ecnStateCapable } } // Don't trust CE marks before having confirmed ECN capability of the path. // Otherwise, mangling would be misinterpreted as actual congestion. return e.state == ecnStateCapable && newECNCE > 0 } // failIfMangled fails ECN validation if all testing packets are lost or CE-marked. func (e *ecnTracker) failIfMangled() { numAckedECNCE := e.numAckedECNCE + int64(e.numLostTesting) if e.numSentECT0+e.numSentECT1 > numAckedECNCE { return } if e.qlogger != nil { e.qlogger.RecordEvent(qlog.ECNStateUpdated{ State: qlog.ECNStateFailed, Trigger: ecnFailedManglingDetected, }) } e.state = ecnStateFailed } func (e *ecnTracker) ecnMarking(pn protocol.PacketNumber) protocol.ECN { if pn < e.firstTestingPacket || e.firstTestingPacket == protocol.InvalidPacketNumber { return protocol.ECNNon } if pn < e.lastTestingPacket || e.lastTestingPacket == protocol.InvalidPacketNumber { return protocol.ECT0 } if pn < e.firstCapablePacket || e.firstCapablePacket == protocol.InvalidPacketNumber { return protocol.ECNNon } // We don't need to deal with the case when ECN validation fails, // since we're ignoring any ECN counts reported in ACK frames in that case. 
return protocol.ECT0 } func (e *ecnTracker) isTestingPacket(pn protocol.PacketNumber) bool { if e.firstTestingPacket == protocol.InvalidPacketNumber { return false } return pn >= e.firstTestingPacket && (pn <= e.lastTestingPacket || e.lastTestingPacket == protocol.InvalidPacketNumber) } quic-go-0.59.0/internal/ackhandler/ecn_test.go000066400000000000000000000304701513066070600212240ustar00rootroot00000000000000package ackhandler import ( "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/quic-go/testutils/events" "github.com/stretchr/testify/require" ) func getAckedPackets(pns ...protocol.PacketNumber) []packetWithPacketNumber { var packets []packetWithPacketNumber for _, p := range pns { packets = append(packets, packetWithPacketNumber{PacketNumber: p}) } return packets } // sendECNTestingPackets sends 10 ECT(0) packets, and then one more packet // Packet numbers: 0 through 9. func sendECNTestingPackets(t *testing.T, ecnTracker *ecnTracker, recorder *events.Recorder) { t.Helper() for i := range protocol.PacketNumber(9) { require.Equal(t, protocol.ECT0, ecnTracker.Mode()) // do this twice to make sure only sent packets are counted require.Equal(t, protocol.ECT0, ecnTracker.Mode()) ecnTracker.SentPacket(i, protocol.ECT0) } require.Equal(t, []qlogwriter.Event{qlog.ECNStateUpdated{State: qlog.ECNStateTesting}}, recorder.Events(), ) require.Equal(t, protocol.ECT0, ecnTracker.Mode()) recorder.Clear() ecnTracker.SentPacket(9, protocol.ECT0) require.Equal(t, []qlogwriter.Event{qlog.ECNStateUpdated{State: qlog.ECNStateUnknown}}, recorder.Events(), ) recorder.Clear() // in unknown state, packets shouldn't be ECN-marked require.Equal(t, protocol.ECNNon, ecnTracker.Mode()) } // ECN validation fails if *all* ECN testing packets are lost. 
func TestECNTestingPacketsLoss(t *testing.T) {
	var eventRecorder events.Recorder
	ecnTracker := newECNTracker(utils.DefaultLogger, &eventRecorder)
	sendECNTestingPackets(t, ecnTracker, &eventRecorder)

	// send non-testing packets
	for i := range protocol.PacketNumber(10) {
		require.Equal(t, protocol.ECNNon, ecnTracker.Mode())
		ecnTracker.SentPacket(10+i, protocol.ECNNon)
	}

	// lose all but one packet
	for pn := range protocol.PacketNumber(10) {
		if pn == 4 {
			continue
		}
		ecnTracker.LostPacket(pn)
	}
	// loss of non-testing packets doesn't matter
	ecnTracker.LostPacket(13)
	ecnTracker.LostPacket(14)

	// no failure yet: one testing packet (pn 4) is still outstanding
	require.Empty(t, eventRecorder.Events())
	eventRecorder.Clear()
	// now lose the last testing packet
	ecnTracker.LostPacket(4)
	require.Equal(t,
		[]qlogwriter.Event{
			qlog.ECNStateUpdated{State: qlog.ECNStateFailed, Trigger: ecnFailedLostAllTestingPackets},
		},
		eventRecorder.Events(),
	)
}

// ECN support is validated once an acknowledgment for any testing packet is received.
// This applies even if that happens before all testing packets have been sent out.
func TestECNValidationInTestingState(t *testing.T) {
	var eventRecorder events.Recorder
	ecnTracker := newECNTracker(utils.DefaultLogger, &eventRecorder)

	// only 5 of the 10 testing packets are sent, so the tracker stays in the testing state
	for i := range 5 {
		require.Equal(t, protocol.ECT0, ecnTracker.Mode())
		ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECT0)
	}
	require.Equal(t,
		[]qlogwriter.Event{qlog.ECNStateUpdated{State: qlog.ECNStateTesting}},
		eventRecorder.Events(),
	)
	eventRecorder.Clear()

	require.False(t, ecnTracker.HandleNewlyAcked(getAckedPackets(3), 1, 0, 0))
	require.Equal(t,
		[]qlogwriter.Event{qlog.ECNStateUpdated{State: qlog.ECNStateCapable}},
		eventRecorder.Events(),
	)

	// make sure we continue sending ECT(0) packets
	for i := 5; i < 100; i++ {
		require.Equal(t, protocol.ECT0, ecnTracker.Mode())
		ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECT0)
	}
}

// ECN is also validated after all testing packets have been sent out,
// once an acknowledgment for any testing packet is received.
func TestECNValidationInUnknownState(t *testing.T) {
	var eventRecorder events.Recorder
	ecnTracker := newECNTracker(utils.DefaultLogger, &eventRecorder)
	sendECNTestingPackets(t, ecnTracker, &eventRecorder)

	for i := range protocol.PacketNumber(10) {
		require.Equal(t, protocol.ECNNon, ecnTracker.Mode())
		pn := 10 + i
		ecnTracker.SentPacket(pn, protocol.ECNNon)
		// lose some packets to make sure this doesn't influence the outcome.
		if i%2 == 0 {
			ecnTracker.LostPacket(pn)
		}
	}
	require.Empty(t, eventRecorder.Events())
	require.False(t, ecnTracker.HandleNewlyAcked(getAckedPackets(7), 1, 0, 0))
	require.Equal(t,
		[]qlogwriter.Event{qlog.ECNStateUpdated{State: qlog.ECNStateCapable}},
		eventRecorder.Events(),
	)
}

// TestECNValidationFailures covers ACKs whose ECN counts are inconsistent with what was sent.
func TestECNValidationFailures(t *testing.T) {
	t.Run("ECN bleaching", func(t *testing.T) {
		// this ACK doesn't contain any ECN counts
		testECNValidationFailure(t, getAckedPackets(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), 0, 0, 0, ecnFailedNoECNCounts)
	})

	t.Run("wrong ECN code point", func(t *testing.T) {
		// we sent ECT(0), but this ACK acknowledges ECT(1)
		testECNValidationFailure(t, getAckedPackets(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), 0, 1, 0, ecnFailedMoreECNCountsThanSent)
	})

	t.Run("more ECN counts than sent packets", func(t *testing.T) {
		// only 10 ECT(0) packets were sent, but the ACK claims to have received 12 of them
		testECNValidationFailure(t, getAckedPackets(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), 12, 0, 0, ecnFailedMoreECNCountsThanSent)
	})
}

// testECNValidationFailure sends the testing packets plus 10 unmarked packets, processes a single ACK
// with the given ECN counts, and expects validation to fail with expectedTrigger.
func testECNValidationFailure(
	t *testing.T,
	ackedPackets []packetWithPacketNumber,
	ect0, ect1, ecnce int64,
	expectedTrigger string,
) {
	var eventRecorder events.Recorder
	ecnTracker := newECNTracker(utils.DefaultLogger, &eventRecorder)
	sendECNTestingPackets(t, ecnTracker, &eventRecorder)

	for i := 10; i < 20; i++ {
		require.Equal(t, protocol.ECNNon, ecnTracker.Mode())
		ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon)
	}

	require.False(t, ecnTracker.HandleNewlyAcked(ackedPackets, ect0, ect1, ecnce))
	require.Equal(t,
		[]qlogwriter.Event{qlog.ECNStateUpdated{State: qlog.ECNStateFailed, Trigger: expectedTrigger}},
		eventRecorder.Events(),
	)
}

func TestECNValidationNotEnoughECNCounts(t *testing.T) {
	var eventRecorder events.Recorder
	ecnTracker := newECNTracker(utils.DefaultLogger, &eventRecorder)
	sendECNTestingPackets(t, ecnTracker, &eventRecorder)

	for i := 10; i < 20; i++ {
		require.Equal(t, protocol.ECNNon, ecnTracker.Mode())
		ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon)
	}
	require.Empty(t, eventRecorder.Events())

	// First only acknowledge some packets sent with ECN marks.
	require.True(t, ecnTracker.HandleNewlyAcked(getAckedPackets(1, 2, 3, 12), 2, 0, 1))
	require.Equal(t,
		[]qlogwriter.Event{qlog.ECNStateUpdated{State: qlog.ECNStateCapable}},
		eventRecorder.Events(),
	)
	eventRecorder.Clear()

	// Now acknowledge some more packets sent with ECN marks, but don't increase the counters enough.
	// This ACK acknowledges 3 more ECN-marked packets, but the counters only increase by 2.
	require.False(t, ecnTracker.HandleNewlyAcked(getAckedPackets(4, 5, 6, 15), 3, 0, 2))
	require.Equal(t,
		[]qlogwriter.Event{qlog.ECNStateUpdated{State: qlog.ECNStateFailed, Trigger: ecnFailedTooFewECNCounts}},
		eventRecorder.Events(),
	)
}

func TestECNNonsensicalECNCountDecrease(t *testing.T) {
	var eventRecorder events.Recorder
	ecnTracker := newECNTracker(utils.DefaultLogger, &eventRecorder)
	sendECNTestingPackets(t, ecnTracker, &eventRecorder)

	for i := 10; i < 20; i++ {
		require.Equal(t, protocol.ECNNon, ecnTracker.Mode())
		ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon)
	}
	require.Empty(t, eventRecorder.Events())

	require.False(t, ecnTracker.HandleNewlyAcked(getAckedPackets(1, 2, 3, 12), 3, 0, 0))
	require.Equal(t,
		[]qlogwriter.Event{qlog.ECNStateUpdated{State: qlog.ECNStateCapable}},
		eventRecorder.Events(),
	)
	eventRecorder.Clear()

	// Now acknowledge some more packets, but decrease the ECN counts. Obviously, this doesn't make any sense.
	require.False(t, ecnTracker.HandleNewlyAcked(getAckedPackets(4, 5, 6, 13), 2, 0, 0))
	require.Equal(t,
		[]qlogwriter.Event{qlog.ECNStateUpdated{State: qlog.ECNStateFailed, Trigger: ecnFailedDecreasedECNCounts}},
		eventRecorder.Events(),
	)
	eventRecorder.Clear()

	// make sure that new ACKs are ignored
	require.False(t, ecnTracker.HandleNewlyAcked(getAckedPackets(7, 8, 9, 14), 5, 0, 0))
	require.Empty(t, eventRecorder.Events())
}

func TestECNACKReordering(t *testing.T) {
	var eventRecorder events.Recorder
	ecnTracker := newECNTracker(utils.DefaultLogger, &eventRecorder)
	sendECNTestingPackets(t, ecnTracker, &eventRecorder)

	for i := 10; i < 20; i++ {
		require.Equal(t, protocol.ECNNon, ecnTracker.Mode())
		ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon)
	}
	require.Empty(t, eventRecorder.Events())

	// The ACK contains more ECN counts than it acknowledges packets.
	// This can happen if ACKs are lost / reordered.
	require.False(t, ecnTracker.HandleNewlyAcked(getAckedPackets(1, 2, 3, 12), 8, 0, 0))
	require.Equal(t,
		[]qlogwriter.Event{qlog.ECNStateUpdated{State: qlog.ECNStateCapable}},
		eventRecorder.Events(),
	)
}

// Mangling is detected if all testing packets are marked CE.
func TestECNManglingAllPacketsMarkedCE(t *testing.T) {
	var eventRecorder events.Recorder
	ecnTracker := newECNTracker(utils.DefaultLogger, &eventRecorder)
	sendECNTestingPackets(t, ecnTracker, &eventRecorder)

	for i := 10; i < 20; i++ {
		require.Equal(t, protocol.ECNNon, ecnTracker.Mode())
		ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon)
	}

	// ECN capability not confirmed yet, therefore CE marks are not regarded as congestion events
	require.False(t, ecnTracker.HandleNewlyAcked(getAckedPackets(0, 1, 2, 3), 0, 0, 4))
	require.False(t, ecnTracker.HandleNewlyAcked(getAckedPackets(4, 5, 6, 10, 11, 12), 0, 0, 7))
	require.Empty(t, eventRecorder.Events())

	// With the next ACK, all testing packets will now have been marked CE.
	require.False(t, ecnTracker.HandleNewlyAcked(getAckedPackets(7, 8, 9, 13), 0, 0, 10))
	require.Equal(t,
		[]qlogwriter.Event{qlog.ECNStateUpdated{State: qlog.ECNStateFailed, Trigger: ecnFailedManglingDetected}},
		eventRecorder.Events(),
	)
}

// Mangling is also detected if some testing packets are lost, and then others are marked CE.
func TestECNManglingSomePacketsLostSomeMarkedCE(t *testing.T) {
	t.Run("packet loss first", func(t *testing.T) {
		testECNManglingSomePacketsLostSomeMarkedCE(t, true)
	})
	t.Run("CE marking first", func(t *testing.T) {
		testECNManglingSomePacketsLostSomeMarkedCE(t, false)
	})
}

// testECNManglingSomePacketsLostSomeMarkedCE loses testing packets 0-2 and CE-marks 3-9.
// packetLossFirst controls whether the losses are reported before or after the ACKs.
func testECNManglingSomePacketsLostSomeMarkedCE(t *testing.T, packetLossFirst bool) {
	var eventRecorder events.Recorder
	ecnTracker := newECNTracker(utils.DefaultLogger, &eventRecorder)
	sendECNTestingPackets(t, ecnTracker, &eventRecorder)

	for i := 10; i < 20; i++ {
		require.Equal(t, protocol.ECNNon, ecnTracker.Mode())
		ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon)
	}

	// Lose a few packets.
	if packetLossFirst {
		ecnTracker.LostPacket(0)
		ecnTracker.LostPacket(1)
		ecnTracker.LostPacket(2)
	}
	// ECN capability not confirmed yet, therefore CE marks are not regarded as congestion events
	require.False(t, ecnTracker.HandleNewlyAcked(getAckedPackets(3, 4, 5, 6, 7, 8), 0, 0, 6))
	require.Empty(t, eventRecorder.Events())

	// By CE-marking the last unacknowledged testing packets, we should detect the mangling.
	require.False(t, ecnTracker.HandleNewlyAcked(getAckedPackets(9), 0, 0, 7))
	if packetLossFirst {
		require.Equal(t,
			[]qlogwriter.Event{qlog.ECNStateUpdated{State: qlog.ECNStateFailed, Trigger: ecnFailedManglingDetected}},
			eventRecorder.Events(),
		)
	} else {
		require.Empty(t, eventRecorder.Events())
	}

	// when the losses are reported last, the final loss triggers the mangling detection
	if !packetLossFirst {
		ecnTracker.LostPacket(0)
		ecnTracker.LostPacket(1)
		ecnTracker.LostPacket(2)
		require.Equal(t,
			[]qlogwriter.Event{qlog.ECNStateUpdated{State: qlog.ECNStateFailed, Trigger: ecnFailedManglingDetected}},
			eventRecorder.Events(),
		)
	}
}

func TestECNCongestionDetection(t *testing.T) {
	var eventRecorder events.Recorder
	ecnTracker := newECNTracker(utils.DefaultLogger, &eventRecorder)
	sendECNTestingPackets(t, ecnTracker, &eventRecorder)

	for i := 10; i < 20; i++ {
		require.Equal(t, protocol.ECNNon, ecnTracker.Mode())
		ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon)
	}

	// Receive one CE count.
	require.True(t, ecnTracker.HandleNewlyAcked(getAckedPackets(1, 2, 3, 12), 2, 0, 1))
	require.Equal(t,
		[]qlogwriter.Event{qlog.ECNStateUpdated{State: qlog.ECNStateCapable}},
		eventRecorder.Events(),
	)

	// No increase in CE. No congestion.
	require.False(t, ecnTracker.HandleNewlyAcked(getAckedPackets(4, 5, 6, 13), 5, 0, 1))
	eventRecorder.Clear()

	// Increase in CE. More congestion.
	require.True(t, ecnTracker.HandleNewlyAcked(getAckedPackets(7, 8, 9, 14), 7, 0, 2))
	require.Empty(t, eventRecorder.Events())
}
quic-go-0.59.0/internal/ackhandler/frame.go000066400000000000000000000006551513066070600205140ustar00rootroot00000000000000package ackhandler

import (
	"github.com/quic-go/quic-go/internal/wire"
)

// FrameHandler handles the acknowledgement and the loss of a frame.
type FrameHandler interface { OnAcked(wire.Frame) OnLost(wire.Frame) } type Frame struct { Frame wire.Frame // nil if the frame has already been acknowledged in another packet Handler FrameHandler } type StreamFrame struct { Frame *wire.StreamFrame Handler FrameHandler } quic-go-0.59.0/internal/ackhandler/interfaces.go000066400000000000000000000034621513066070600215440ustar00rootroot00000000000000package ackhandler import ( "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" ) // SentPacketHandler handles ACKs received for outgoing packets type SentPacketHandler interface { // SentPacket may modify the packet SentPacket(t monotime.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket, isPathProbePacket bool) // ReceivedAck processes an ACK frame. // It does not store a copy of the frame. ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime monotime.Time) (bool /* 1-RTT packet acked */, error) ReceivedPacket(protocol.EncryptionLevel, monotime.Time) ReceivedBytes(_ protocol.ByteCount, rcvTime monotime.Time) DropPackets(_ protocol.EncryptionLevel, rcvTime monotime.Time) ResetForRetry(rcvTime monotime.Time) // The SendMode determines if and what kind of packets can be sent. SendMode(now monotime.Time) SendMode // TimeUntilSend is the time when the next packet should be sent. // It is used for pacing packets. 
TimeUntilSend() monotime.Time SetMaxDatagramSize(count protocol.ByteCount) // only to be called once the handshake is complete QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */ ECNMode(isShortHeaderPacket bool) protocol.ECN // isShortHeaderPacket should only be true for non-coalesced 1-RTT packets PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber GetLossDetectionTimeout() monotime.Time OnLossDetectionTimeout(now monotime.Time) error MigratedPath(now monotime.Time, initialMaxPacketSize protocol.ByteCount) } quic-go-0.59.0/internal/ackhandler/lost_packet_tracker.go000066400000000000000000000034031513066070600234370ustar00rootroot00000000000000package ackhandler import ( "iter" "slices" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" ) type lostPacket struct { PacketNumber protocol.PacketNumber SendTime monotime.Time } type lostPacketTracker struct { maxLength int lostPackets []lostPacket } func newLostPacketTracker(maxLength int) *lostPacketTracker { return &lostPacketTracker{ maxLength: maxLength, // Preallocate a small slice only. // Hopefully we won't lose many packets. lostPackets: make([]lostPacket, 0, 4), } } func (t *lostPacketTracker) Add(p protocol.PacketNumber, sendTime monotime.Time) { if len(t.lostPackets) == t.maxLength { t.lostPackets = t.lostPackets[1:] } t.lostPackets = append(t.lostPackets, lostPacket{ PacketNumber: p, SendTime: sendTime, }) } // Delete deletes a packet from the lost packet tracker. // This function is not optimized for performance if many packets are lost, // but it is only used when a spurious loss is detected, which is rare. 
func (t *lostPacketTracker) Delete(pn protocol.PacketNumber) { t.lostPackets = slices.DeleteFunc(t.lostPackets, func(p lostPacket) bool { return p.PacketNumber == pn }) } func (t *lostPacketTracker) All() iter.Seq2[protocol.PacketNumber, monotime.Time] { return func(yield func(protocol.PacketNumber, monotime.Time) bool) { for _, p := range t.lostPackets { if !yield(p.PacketNumber, p.SendTime) { return } } } } func (t *lostPacketTracker) DeleteBefore(ti monotime.Time) { if len(t.lostPackets) == 0 { return } if !t.lostPackets[0].SendTime.Before(ti) { return } var idx int for ; idx < len(t.lostPackets); idx++ { if !t.lostPackets[idx].SendTime.Before(ti) { break } } t.lostPackets = slices.Delete(t.lostPackets, 0, idx) } quic-go-0.59.0/internal/ackhandler/lost_packet_tracker_test.go000066400000000000000000000037541513066070600245070ustar00rootroot00000000000000package ackhandler import ( "maps" "testing" "time" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestLostPacketTracker(t *testing.T) { lt := newLostPacketTracker(4) start := monotime.Now() lt.Add(1, start) lt.Add(5, start.Add(time.Second)) lt.Add(8, start.Add(2*time.Second)) require.Equal(t, map[protocol.PacketNumber]monotime.Time{ 1: start, 5: start.Add(time.Second), 8: start.Add(2 * time.Second), }, maps.Collect(lt.All())) // Lose 2 more packets. The first one should be removed. 
lt.Add(10, start.Add(3*time.Second)) lt.Add(11, start.Add(4*time.Second)) require.Equal(t, map[protocol.PacketNumber]monotime.Time{ 5: start.Add(time.Second), 8: start.Add(2 * time.Second), 10: start.Add(3 * time.Second), 11: start.Add(4 * time.Second), }, maps.Collect(lt.All())) lt.Delete(5) lt.Delete(10) require.Equal(t, map[protocol.PacketNumber]monotime.Time{ 8: start.Add(2 * time.Second), 11: start.Add(4 * time.Second), }, maps.Collect(lt.All())) } func TestLostPacketTrackerDeleteBefore(t *testing.T) { lt := newLostPacketTracker(4) trackedPackets := func(lt *lostPacketTracker) []protocol.PacketNumber { var pns []protocol.PacketNumber for pn := range lt.All() { pns = append(pns, pn) } return pns } start := monotime.Now() lt.Add(1, start) lt.Add(5, start.Add(time.Second)) lt.Add(8, start.Add(2*time.Second)) lt.Add(10, start.Add(3*time.Second)) require.Equal(t, []protocol.PacketNumber{1, 5, 8, 10}, trackedPackets(lt)) lt.DeleteBefore(start) // this should be a no-op require.Equal(t, []protocol.PacketNumber{1, 5, 8, 10}, trackedPackets(lt)) lt.DeleteBefore(start.Add(2 * time.Second)) require.Equal(t, []protocol.PacketNumber{8, 10}, trackedPackets(lt)) lt.DeleteBefore(start.Add(time.Second * 5 / 2)) require.Equal(t, []protocol.PacketNumber{10}, trackedPackets(lt)) lt.DeleteBefore(start.Add(time.Hour)) require.Empty(t, trackedPackets(lt)) } quic-go-0.59.0/internal/ackhandler/mock_ecn_handler_test.go000066400000000000000000000140041513066070600237250ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/internal/ackhandler (interfaces: ECNHandler) // // Generated by this command: // // mockgen -typed -build_flags=-tags=gomock -package ackhandler -destination mock_ecn_handler_test.go github.com/quic-go/quic-go/internal/ackhandler ECNHandler // // Package ackhandler is a generated GoMock package. 
package ackhandler import ( reflect "reflect" protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" ) // MockECNHandler is a mock of ECNHandler interface. type MockECNHandler struct { ctrl *gomock.Controller recorder *MockECNHandlerMockRecorder isgomock struct{} } // MockECNHandlerMockRecorder is the mock recorder for MockECNHandler. type MockECNHandlerMockRecorder struct { mock *MockECNHandler } // NewMockECNHandler creates a new mock instance. func NewMockECNHandler(ctrl *gomock.Controller) *MockECNHandler { mock := &MockECNHandler{ctrl: ctrl} mock.recorder = &MockECNHandlerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockECNHandler) EXPECT() *MockECNHandlerMockRecorder { return m.recorder } // HandleNewlyAcked mocks base method. func (m *MockECNHandler) HandleNewlyAcked(packets []packetWithPacketNumber, ect0, ect1, ecnce int64) bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "HandleNewlyAcked", packets, ect0, ect1, ecnce) ret0, _ := ret[0].(bool) return ret0 } // HandleNewlyAcked indicates an expected call of HandleNewlyAcked. 
func (mr *MockECNHandlerMockRecorder) HandleNewlyAcked(packets, ect0, ect1, ecnce any) *MockECNHandlerHandleNewlyAckedCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleNewlyAcked", reflect.TypeOf((*MockECNHandler)(nil).HandleNewlyAcked), packets, ect0, ect1, ecnce) return &MockECNHandlerHandleNewlyAckedCall{Call: call} } // MockECNHandlerHandleNewlyAckedCall wrap *gomock.Call type MockECNHandlerHandleNewlyAckedCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockECNHandlerHandleNewlyAckedCall) Return(congested bool) *MockECNHandlerHandleNewlyAckedCall { c.Call = c.Call.Return(congested) return c } // Do rewrite *gomock.Call.Do func (c *MockECNHandlerHandleNewlyAckedCall) Do(f func([]packetWithPacketNumber, int64, int64, int64) bool) *MockECNHandlerHandleNewlyAckedCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockECNHandlerHandleNewlyAckedCall) DoAndReturn(f func([]packetWithPacketNumber, int64, int64, int64) bool) *MockECNHandlerHandleNewlyAckedCall { c.Call = c.Call.DoAndReturn(f) return c } // LostPacket mocks base method. func (m *MockECNHandler) LostPacket(arg0 protocol.PacketNumber) { m.ctrl.T.Helper() m.ctrl.Call(m, "LostPacket", arg0) } // LostPacket indicates an expected call of LostPacket. 
func (mr *MockECNHandlerMockRecorder) LostPacket(arg0 any) *MockECNHandlerLostPacketCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockECNHandler)(nil).LostPacket), arg0) return &MockECNHandlerLostPacketCall{Call: call} } // MockECNHandlerLostPacketCall wrap *gomock.Call type MockECNHandlerLostPacketCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockECNHandlerLostPacketCall) Return() *MockECNHandlerLostPacketCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockECNHandlerLostPacketCall) Do(f func(protocol.PacketNumber)) *MockECNHandlerLostPacketCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockECNHandlerLostPacketCall) DoAndReturn(f func(protocol.PacketNumber)) *MockECNHandlerLostPacketCall { c.Call = c.Call.DoAndReturn(f) return c } // Mode mocks base method. func (m *MockECNHandler) Mode() protocol.ECN { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Mode") ret0, _ := ret[0].(protocol.ECN) return ret0 } // Mode indicates an expected call of Mode. 
func (mr *MockECNHandlerMockRecorder) Mode() *MockECNHandlerModeCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mode", reflect.TypeOf((*MockECNHandler)(nil).Mode)) return &MockECNHandlerModeCall{Call: call} } // MockECNHandlerModeCall wrap *gomock.Call type MockECNHandlerModeCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockECNHandlerModeCall) Return(arg0 protocol.ECN) *MockECNHandlerModeCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockECNHandlerModeCall) Do(f func() protocol.ECN) *MockECNHandlerModeCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockECNHandlerModeCall) DoAndReturn(f func() protocol.ECN) *MockECNHandlerModeCall { c.Call = c.Call.DoAndReturn(f) return c } // SentPacket mocks base method. func (m *MockECNHandler) SentPacket(arg0 protocol.PacketNumber, arg1 protocol.ECN) { m.ctrl.T.Helper() m.ctrl.Call(m, "SentPacket", arg0, arg1) } // SentPacket indicates an expected call of SentPacket. 
func (mr *MockECNHandlerMockRecorder) SentPacket(arg0, arg1 any) *MockECNHandlerSentPacketCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockECNHandler)(nil).SentPacket), arg0, arg1) return &MockECNHandlerSentPacketCall{Call: call} } // MockECNHandlerSentPacketCall wrap *gomock.Call type MockECNHandlerSentPacketCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockECNHandlerSentPacketCall) Return() *MockECNHandlerSentPacketCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockECNHandlerSentPacketCall) Do(f func(protocol.PacketNumber, protocol.ECN)) *MockECNHandlerSentPacketCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockECNHandlerSentPacketCall) DoAndReturn(f func(protocol.PacketNumber, protocol.ECN)) *MockECNHandlerSentPacketCall { c.Call = c.Call.DoAndReturn(f) return c } quic-go-0.59.0/internal/ackhandler/mockgen.go000066400000000000000000000004201513066070600210330ustar00rootroot00000000000000//go:build gomock || generate package ackhandler //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_ecn_handler_test.go github.com/quic-go/quic-go/internal/ackhandler ECNHandler" type ECNHandler = ecnHandler quic-go-0.59.0/internal/ackhandler/packet.go000066400000000000000000000031051513066070600206620ustar00rootroot00000000000000package ackhandler import ( "sync" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" ) type packetWithPacketNumber struct { PacketNumber protocol.PacketNumber *packet } // A Packet is a packet type packet struct { SendTime monotime.Time StreamFrames []StreamFrame Frames []Frame LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK Length protocol.ByteCount EncryptionLevel protocol.EncryptionLevel IsPathMTUProbePacket bool // 
We don't report the loss of Path MTU probe packets to the congestion controller. includedInBytesInFlight bool isPathProbePacket bool } func (p *packet) Outstanding() bool { return !p.IsPathMTUProbePacket && !p.isPathProbePacket && p.IsAckEliciting() } func (p *packet) IsAckEliciting() bool { return len(p.StreamFrames) > 0 || len(p.Frames) > 0 } var packetPool = sync.Pool{New: func() any { return &packet{} }} func getPacket() *packet { p := packetPool.Get().(*packet) p.StreamFrames = nil p.Frames = nil p.LargestAcked = 0 p.Length = 0 p.EncryptionLevel = protocol.EncryptionLevel(0) p.SendTime = 0 p.IsPathMTUProbePacket = false p.includedInBytesInFlight = false p.isPathProbePacket = false return p } // We currently only return Packets back into the pool when they're acknowledged (not when they're lost). // This simplifies the code, and gives the vast majority of the performance benefit we can gain from using the pool. func putPacket(p *packet) { p.Frames = nil p.StreamFrames = nil packetPool.Put(p) } quic-go-0.59.0/internal/ackhandler/packet_number_generator.go000066400000000000000000000045101513066070600243010ustar00rootroot00000000000000package ackhandler import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" ) type packetNumberGenerator interface { Peek() protocol.PacketNumber // Pop pops the packet number. // It reports if the packet number (before the one just popped) was skipped. // It never skips more than one packet number in a row. 
Pop() (skipped bool, _ protocol.PacketNumber) } type sequentialPacketNumberGenerator struct { next protocol.PacketNumber } var _ packetNumberGenerator = &sequentialPacketNumberGenerator{} func newSequentialPacketNumberGenerator(initial protocol.PacketNumber) packetNumberGenerator { return &sequentialPacketNumberGenerator{next: initial} } func (p *sequentialPacketNumberGenerator) Peek() protocol.PacketNumber { return p.next } func (p *sequentialPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) { next := p.next p.next++ return false, next } // The skippingPacketNumberGenerator generates the packet number for the next packet // it randomly skips a packet number every averagePeriod packets (on average). // It is guaranteed to never skip two consecutive packet numbers. type skippingPacketNumberGenerator struct { period protocol.PacketNumber maxPeriod protocol.PacketNumber next protocol.PacketNumber nextToSkip protocol.PacketNumber rng utils.Rand } var _ packetNumberGenerator = &skippingPacketNumberGenerator{} func newSkippingPacketNumberGenerator(initial, initialPeriod, maxPeriod protocol.PacketNumber) packetNumberGenerator { g := &skippingPacketNumberGenerator{ next: initial, period: initialPeriod, maxPeriod: maxPeriod, } g.generateNewSkip() return g } func (p *skippingPacketNumberGenerator) Peek() protocol.PacketNumber { if p.next == p.nextToSkip { return p.next + 1 } return p.next } func (p *skippingPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) { next := p.next if p.next == p.nextToSkip { next++ p.next += 2 p.generateNewSkip() return true, next } p.next++ // generate a new packet number for the next packet return false, next } func (p *skippingPacketNumberGenerator) generateNewSkip() { // make sure that there are never two consecutive packet numbers that are skipped p.nextToSkip = p.next + 3 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period))) p.period = min(2*p.period, p.maxPeriod) } 
quic-go-0.59.0/internal/ackhandler/packet_number_generator_test.go000066400000000000000000000047631513066070600253520ustar00rootroot00000000000000package ackhandler import ( "math" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestSequentialPacketNumberGenerator(t *testing.T) { const initialPN protocol.PacketNumber = 123 png := newSequentialPacketNumberGenerator(initialPN) for i := initialPN; i < initialPN+1000; i++ { require.Equal(t, i, png.Peek()) require.Equal(t, i, png.Peek()) skipNext, pn := png.Pop() require.False(t, skipNext) require.Equal(t, i, pn) } } func TestSkippingPacketNumberGenerator(t *testing.T) { // the maximum period must be sufficiently small such that using a 32-bit random number is ok require.Less(t, 2*protocol.SkipPacketMaxPeriod, protocol.PacketNumber(math.MaxInt32)) const initialPeriod protocol.PacketNumber = 25 const maxPeriod protocol.PacketNumber = 300 png := newSkippingPacketNumberGenerator(100, initialPeriod, maxPeriod) require.Equal(t, protocol.PacketNumber(100), png.Peek()) require.Equal(t, protocol.PacketNumber(100), png.Peek()) require.Equal(t, protocol.PacketNumber(100), png.Peek()) _, pn := png.Pop() require.Equal(t, protocol.PacketNumber(100), pn) var last protocol.PacketNumber var skipped bool for i := range maxPeriod { didSkip, num := png.Pop() if didSkip { skipped = true _, nextNum := png.Pop() require.Equal(t, num+1, nextNum) break } if i != 0 { require.Equal(t, num, last+1) } last = num } require.True(t, skipped) } func TestSkippingPacketNumberGeneratorPeriods(t *testing.T) { const initialPN protocol.PacketNumber = 8 const initialPeriod protocol.PacketNumber = 25 const maxPeriod protocol.PacketNumber = 300 const rep = 2500 periods := make([][]protocol.PacketNumber, rep) expectedPeriods := []protocol.PacketNumber{25, 50, 100, 200, 300, 300, 300} for i := range rep { png := newSkippingPacketNumberGenerator(initialPN, initialPeriod, maxPeriod) lastSkip := initialPN for 
len(periods[i]) < len(expectedPeriods) { skipNext, next := png.Pop() if skipNext { skipped := next + 1 require.Greater(t, skipped, lastSkip+1) periods[i] = append(periods[i], skipped-lastSkip-1) lastSkip = skipped } } } for j := range expectedPeriods { var average float64 for i := range rep { average += float64(periods[i][j]) / float64(len(periods)) } t.Logf("Period %d: %.2f (expected %d)\n", j, average, expectedPeriods[j]) require.InDelta(t, float64(expectedPeriods[j]+1), average, float64(max(protocol.PacketNumber(5), expectedPeriods[j]/10)), ) } } quic-go-0.59.0/internal/ackhandler/received_packet_handler.go000066400000000000000000000075221513066070600242340ustar00rootroot00000000000000package ackhandler import ( "fmt" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" ) type ReceivedPacketHandler struct { initialPackets *receivedPacketTracker handshakePackets *receivedPacketTracker appDataPackets appDataReceivedPacketTracker lowest1RTTPacket protocol.PacketNumber } func NewReceivedPacketHandler(logger utils.Logger) *ReceivedPacketHandler { return &ReceivedPacketHandler{ initialPackets: newReceivedPacketTracker(), handshakePackets: newReceivedPacketTracker(), appDataPackets: *newAppDataReceivedPacketTracker(logger), lowest1RTTPacket: protocol.InvalidPacketNumber, } } func (h *ReceivedPacketHandler) ReceivedPacket( pn protocol.PacketNumber, ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime monotime.Time, ackEliciting bool, ) error { switch encLevel { case protocol.EncryptionInitial: return h.initialPackets.ReceivedPacket(pn, ecn, ackEliciting) case protocol.EncryptionHandshake: // The Handshake packet number space might already have been dropped as a result // of processing the CRYPTO frame that was contained in this packet. 
if h.handshakePackets == nil { return nil } return h.handshakePackets.ReceivedPacket(pn, ecn, ackEliciting) case protocol.Encryption0RTT: if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket { return fmt.Errorf("received packet number %d on a 0-RTT packet after receiving %d on a 1-RTT packet", pn, h.lowest1RTTPacket) } return h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting) case protocol.Encryption1RTT: if h.lowest1RTTPacket == protocol.InvalidPacketNumber || pn < h.lowest1RTTPacket { h.lowest1RTTPacket = pn } return h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting) default: panic(fmt.Sprintf("received packet with unknown encryption level: %s", encLevel)) } } func (h *ReceivedPacketHandler) IgnorePacketsBelow(pn protocol.PacketNumber) { h.appDataPackets.IgnoreBelow(pn) } func (h *ReceivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { //nolint:exhaustive // 1-RTT packet number space is never dropped. switch encLevel { case protocol.EncryptionInitial: h.initialPackets = nil case protocol.EncryptionHandshake: h.handshakePackets = nil case protocol.Encryption0RTT: // Nothing to do here. // If we are rejecting 0-RTT, no 0-RTT packets will have been decrypted. default: panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel)) } } func (h *ReceivedPacketHandler) GetAlarmTimeout() monotime.Time { return h.appDataPackets.GetAlarmTimeout() } func (h *ReceivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, now monotime.Time, onlyIfQueued bool) *wire.AckFrame { //nolint:exhaustive // 0-RTT packets can't contain ACK frames. 
switch encLevel { case protocol.EncryptionInitial: if h.initialPackets != nil { return h.initialPackets.GetAckFrame() } return nil case protocol.EncryptionHandshake: if h.handshakePackets != nil { return h.handshakePackets.GetAckFrame() } return nil case protocol.Encryption1RTT: return h.appDataPackets.GetAckFrame(now, onlyIfQueued) default: // 0-RTT packets can't contain ACK frames return nil } } func (h *ReceivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool { switch encLevel { case protocol.EncryptionInitial: if h.initialPackets != nil { return h.initialPackets.IsPotentiallyDuplicate(pn) } case protocol.EncryptionHandshake: if h.handshakePackets != nil { return h.handshakePackets.IsPotentiallyDuplicate(pn) } case protocol.Encryption0RTT, protocol.Encryption1RTT: return h.appDataPackets.IsPotentiallyDuplicate(pn) } panic("unexpected encryption level") } quic-go-0.59.0/internal/ackhandler/received_packet_handler_test.go000066400000000000000000000137441513066070600252760ustar00rootroot00000000000000package ackhandler import ( "testing" "time" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/require" ) func TestGenerateACKsForPacketNumberSpaces(t *testing.T) { handler := NewReceivedPacketHandler(utils.DefaultLogger) now := monotime.Now() sendTime := now.Add(-time.Second) require.NoError(t, handler.ReceivedPacket(2, protocol.ECT0, protocol.EncryptionInitial, sendTime, true)) require.NoError(t, handler.ReceivedPacket(1, protocol.ECT1, protocol.EncryptionHandshake, sendTime, true)) require.NoError(t, handler.ReceivedPacket(5, protocol.ECNCE, protocol.Encryption1RTT, sendTime, true)) require.NoError(t, handler.ReceivedPacket(3, protocol.ECT0, protocol.EncryptionInitial, sendTime, true)) require.NoError(t, handler.ReceivedPacket(2, protocol.ECT1, 
protocol.EncryptionHandshake, sendTime, true)) require.NoError(t, handler.ReceivedPacket(4, protocol.ECNCE, protocol.Encryption1RTT, sendTime, true)) // Initial initialAck := handler.GetAckFrame(protocol.EncryptionInitial, now, true) require.NotNil(t, initialAck) require.Equal(t, []wire.AckRange{{Smallest: 2, Largest: 3}}, initialAck.AckRanges) require.Zero(t, initialAck.DelayTime) require.EqualValues(t, 2, initialAck.ECT0) require.Zero(t, initialAck.ECT1) require.Zero(t, initialAck.ECNCE) // Handshake handshakeAck := handler.GetAckFrame(protocol.EncryptionHandshake, now, true) require.NotNil(t, handshakeAck) require.Equal(t, []wire.AckRange{{Smallest: 1, Largest: 2}}, handshakeAck.AckRanges) require.Zero(t, handshakeAck.DelayTime) require.Zero(t, handshakeAck.ECT0) require.EqualValues(t, 2, handshakeAck.ECT1) require.Zero(t, handshakeAck.ECNCE) // 1-RTT oneRTTAck := handler.GetAckFrame(protocol.Encryption1RTT, now, true) require.NotNil(t, oneRTTAck) require.Equal(t, []wire.AckRange{{Smallest: 4, Largest: 5}}, oneRTTAck.AckRanges) require.Equal(t, time.Second, oneRTTAck.DelayTime) require.Zero(t, oneRTTAck.ECT0) require.Zero(t, oneRTTAck.ECT1) require.EqualValues(t, 2, oneRTTAck.ECNCE) } func TestReceive0RTTAnd1RTT(t *testing.T) { handler := NewReceivedPacketHandler(utils.DefaultLogger) sendTime := monotime.Now().Add(-time.Second) require.NoError(t, handler.ReceivedPacket(2, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)) require.NoError(t, handler.ReceivedPacket(3, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)) ack := handler.GetAckFrame(protocol.Encryption1RTT, monotime.Now(), true) require.NotNil(t, ack) require.Equal(t, []wire.AckRange{{Smallest: 2, Largest: 3}}, ack.AckRanges) // 0-RTT packets with higher packet numbers than 1-RTT packets are rejected... require.Error(t, handler.ReceivedPacket(4, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)) // ... 
but reordered 0-RTT packets are allowed require.NoError(t, handler.ReceivedPacket(1, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)) } func TestDropPackets(t *testing.T) { handler := NewReceivedPacketHandler(utils.DefaultLogger) sendTime := monotime.Now().Add(-time.Second) require.NoError(t, handler.ReceivedPacket(2, protocol.ECNNon, protocol.EncryptionInitial, sendTime, true)) require.NoError(t, handler.ReceivedPacket(1, protocol.ECNNon, protocol.EncryptionHandshake, sendTime, true)) require.NoError(t, handler.ReceivedPacket(2, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)) // Initial require.NotNil(t, handler.GetAckFrame(protocol.EncryptionInitial, monotime.Now(), true)) handler.DropPackets(protocol.EncryptionInitial) require.Nil(t, handler.GetAckFrame(protocol.EncryptionInitial, monotime.Now(), true)) // Handshake require.NotNil(t, handler.GetAckFrame(protocol.EncryptionHandshake, monotime.Now(), true)) handler.DropPackets(protocol.EncryptionHandshake) require.Nil(t, handler.GetAckFrame(protocol.EncryptionHandshake, monotime.Now(), true)) // 1-RTT require.NotNil(t, handler.GetAckFrame(protocol.Encryption1RTT, monotime.Now(), true)) // 0-RTT is a no-op handler.DropPackets(protocol.Encryption0RTT) } func TestAckRangePruning(t *testing.T) { handler := NewReceivedPacketHandler(utils.DefaultLogger) sendTime := monotime.Now() require.NoError(t, handler.ReceivedPacket(1, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)) require.NoError(t, handler.ReceivedPacket(2, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)) ack := handler.GetAckFrame(protocol.Encryption1RTT, monotime.Now(), true) require.NotNil(t, ack) require.Equal(t, []wire.AckRange{{Smallest: 1, Largest: 2}}, ack.AckRanges) require.NoError(t, handler.ReceivedPacket(3, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)) handler.IgnorePacketsBelow(2) require.NoError(t, handler.ReceivedPacket(4, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)) ack = 
handler.GetAckFrame(protocol.Encryption1RTT, monotime.Now(), true) require.NotNil(t, ack) require.Equal(t, []wire.AckRange{{Smallest: 2, Largest: 4}}, ack.AckRanges) } func TestPacketDuplicateDetection(t *testing.T) { handler := NewReceivedPacketHandler(utils.DefaultLogger) sendTime := monotime.Now() // 1-RTT is tested separately at the end encLevels := []protocol.EncryptionLevel{ protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption0RTT, } for _, encLevel := range encLevels { // first, packet 3 is not a duplicate require.False(t, handler.IsPotentiallyDuplicate(3, encLevel)) require.NoError(t, handler.ReceivedPacket(3, protocol.ECNNon, encLevel, sendTime, true)) // now packet 3 is considered a duplicate require.True(t, handler.IsPotentiallyDuplicate(3, encLevel)) } // 1-RTT require.True(t, handler.IsPotentiallyDuplicate(3, protocol.Encryption1RTT)) require.False(t, handler.IsPotentiallyDuplicate(4, protocol.Encryption1RTT)) require.NoError(t, handler.ReceivedPacket(4, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)) require.True(t, handler.IsPotentiallyDuplicate(4, protocol.Encryption1RTT)) } quic-go-0.59.0/internal/ackhandler/received_packet_history.go000066400000000000000000000105661513066070600243220ustar00rootroot00000000000000package ackhandler import ( "iter" "slices" "github.com/quic-go/quic-go/internal/protocol" ) // interval is an interval from one PacketNumber to the other type interval struct { Start protocol.PacketNumber End protocol.PacketNumber } // The receivedPacketHistory stores if a packet number has already been received. // It generates ACK ranges which can be used to assemble an ACK frame. // It does not store packet contents. 
type receivedPacketHistory struct {
	// ranges holds the received packet numbers as non-overlapping intervals,
	// kept in ascending order (Backward yields them highest-first).
	ranges []interval // maximum length: protocol.MaxNumAckRanges

	// deletedBelow is the limit set by DeleteBelow: packets below it are no
	// longer tracked. It is protocol.InvalidPacketNumber until DeleteBelow is called.
	deletedBelow protocol.PacketNumber
}

// newReceivedPacketHistory returns an empty history for which nothing has been deleted yet.
func newReceivedPacketHistory() *receivedPacketHistory {
	return &receivedPacketHistory{
		deletedBelow: protocol.InvalidPacketNumber,
	}
}

// ReceivedPacket registers a packet with PacketNumber p and updates the ranges.
// It returns false if p was already recorded or lies below the DeleteBelow limit.
// If more than protocol.MaxNumAckRanges ranges are tracked afterwards, the
// lowest ranges are discarded.
func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ {
	// ignore delayed packets, if we already deleted the range
	if p < h.deletedBelow {
		return false
	}

	isNew := h.addToRanges(p)
	// Delete old ranges, if we're tracking too many of them.
	// This is a DoS defense against a peer that sends us too many gaps.
	// NOTE(review): pruning does not advance deletedBelow, so a packet from a
	// pruned range is treated as new if it is received again — confirm this is intended.
	if len(h.ranges) > protocol.MaxNumAckRanges {
		h.ranges = slices.Delete(h.ranges, 0, len(h.ranges)-protocol.MaxNumAckRanges)
	}
	return isNew
}

// addToRanges inserts p into h.ranges, keeping the ranges sorted and
// non-overlapping. p either extends an adjacent range (merging two ranges when
// it fills the gap between them) or becomes a new single-packet range.
// The scan starts at the highest range, so a packet extending the newest
// range is handled in the first iteration.
func (h *receivedPacketHistory) addToRanges(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ {
	if len(h.ranges) == 0 {
		h.ranges = append(h.ranges, interval{Start: p, End: p})
		return true
	}

	for i := len(h.ranges) - 1; i >= 0; i-- {
		// p already included in an existing range. Nothing to do here
		if p >= h.ranges[i].Start && p <= h.ranges[i].End {
			return false
		}

		// No merge check is needed here: a range starting at p+1 would have
		// been visited (and merged) in an earlier iteration of this loop.
		if h.ranges[i].End == p-1 { // extend a range at the end
			h.ranges[i].End = p
			return true
		}
		if h.ranges[i].Start == p+1 { // extend a range at the beginning
			h.ranges[i].Start = p

			if i > 0 && h.ranges[i-1].End+1 == h.ranges[i].Start { // merge two ranges
				h.ranges[i-1].End = h.ranges[i].End
				h.ranges = slices.Delete(h.ranges, i, i+1)
			}
			return true
		}

		// create a new range after the current one
		if p > h.ranges[i].End {
			h.ranges = slices.Insert(h.ranges, i+1, interval{Start: p, End: p})
			return true
		}
	}

	// create a new range at the beginning
	h.ranges = slices.Insert(h.ranges, 0, interval{Start: p, End: p})
	return true
}

// DeleteBelow deletes all entries below (but not including) p.
// Calls with a p lower than the current limit are ignored. A range that
// straddles p is truncated so that it starts at p.
func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
	if p < h.deletedBelow {
		return
	}
	h.deletedBelow = p

	if len(h.ranges) == 0 {
		return
	}

	// idx is the index of the last range that lies entirely below p.
	idx := -1
	for i := 0; i < len(h.ranges); i++ {
		if h.ranges[i].End < p { // delete a whole range
			idx = i
		} else if p > h.ranges[i].Start && p <= h.ranges[i].End {
			h.ranges[i].Start = p
			break
		} else { // no ranges affected. Nothing to do
			break
		}
	}
	if idx >= 0 {
		h.ranges = slices.Delete(h.ranges, 0, idx+1)
	}
}

// Backward returns an iterator over the ranges in reverse order
func (h *receivedPacketHistory) Backward() iter.Seq[interval] {
	return func(yield func(interval) bool) {
		for i := len(h.ranges) - 1; i >= 0; i-- {
			if !yield(h.ranges[i]) {
				return
			}
		}
	}
}

// HighestMissingUpTo returns the highest packet number <= p that has not been
// received and is not below the DeleteBelow limit. Packet numbers above the
// largest received packet are not considered: p is first clamped to it.
// It returns protocol.InvalidPacketNumber if no packet has been received yet,
// if p is below the DeleteBelow limit, or if the gap found lies below that limit.
func (h *receivedPacketHistory) HighestMissingUpTo(p protocol.PacketNumber) protocol.PacketNumber {
	if len(h.ranges) == 0 || (h.deletedBelow != protocol.InvalidPacketNumber && p < h.deletedBelow) {
		return protocol.InvalidPacketNumber
	}
	p = min(h.ranges[len(h.ranges)-1].End, p)
	for i := len(h.ranges) - 1; i >= 0; i-- {
		r := h.ranges[i]
		if p >= r.Start && p <= r.End { // p is contained in this range
			highest := r.Start - 1 // highest packet in the gap before this range
			if h.deletedBelow != protocol.InvalidPacketNumber && highest < h.deletedBelow {
				return protocol.InvalidPacketNumber
			}
			return highest
		}
		if i >= 1 && p > h.ranges[i-1].End && p <= r.Start {
			// p is in the gap between the previous range and this range
			return p
		}
	}
	// p lies below the lowest tracked range, so it was never received.
	return p
}

// IsPotentiallyDuplicate reports whether p may have been received before:
// true if p lies within a tracked range, or if p is below the DeleteBelow
// limit (where the history no longer has that information).
func (h *receivedPacketHistory) IsPotentiallyDuplicate(p protocol.PacketNumber) bool {
	if p < h.deletedBelow {
		return true
	}
	// Iterating over the slices is faster than using a binary search (using slices.BinarySearchFunc).
	// Ranges are sorted, so once p is above a range's End it cannot be in any lower range.
	for i := len(h.ranges) - 1; i >= 0; i-- {
		if p > h.ranges[i].End {
			return false
		}
		if p <= h.ranges[i].End && p >= h.ranges[i].Start {
			return true
		}
	}
	return false
}
quic-go-0.59.0/internal/ackhandler/received_packet_history_test.go000066400000000000000000000216761513066070600253650ustar00rootroot00000000000000package ackhandler import ( "math/rand/v2" "slices" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestReceivedPacketHistorySingleRange(t *testing.T) { hist := newReceivedPacketHistory() require.True(t, hist.ReceivedPacket(4)) require.Equal(t, []interval{{Start: 4, End: 4}}, slices.Collect(hist.Backward())) // add a duplicate packet require.False(t, hist.ReceivedPacket(4)) require.Equal(t, []interval{{Start: 4, End: 4}}, slices.Collect(hist.Backward())) // add a few more packets to extend the range require.True(t, hist.ReceivedPacket(5)) require.True(t, hist.ReceivedPacket(6)) require.Equal(t, []interval{{Start: 4, End: 6}}, slices.Collect(hist.Backward())) // add a duplicate within this range require.False(t, hist.ReceivedPacket(5)) require.Equal(t, []interval{{Start: 4, End: 6}}, slices.Collect(hist.Backward())) // extend the range at the front require.True(t, hist.ReceivedPacket(3)) require.Equal(t, []interval{{Start: 3, End: 6}}, slices.Collect(hist.Backward())) } func TestReceivedPacketHistoryRanges(t *testing.T) { hist := newReceivedPacketHistory() require.Equal(t, protocol.InvalidPacketNumber, hist.HighestMissingUpTo(1000)) require.True(t, hist.ReceivedPacket(4)) require.Equal(t, protocol.PacketNumber(3), hist.HighestMissingUpTo(1000)) require.Equal(t, protocol.PacketNumber(3), hist.HighestMissingUpTo(4)) require.Equal(t, protocol.PacketNumber(3), hist.HighestMissingUpTo(3)) require.Equal(t, protocol.PacketNumber(2), hist.HighestMissingUpTo(2)) require.True(t, hist.ReceivedPacket(10)) require.Equal(t, protocol.PacketNumber(9), hist.HighestMissingUpTo(1000)) require.Equal(t, []interval{ {Start: 10,
End: 10}, {Start: 4, End: 4}, }, slices.Collect(hist.Backward())) // create a new range in the middle require.True(t, hist.ReceivedPacket(7)) require.Equal(t, []interval{ {Start: 10, End: 10}, {Start: 7, End: 7}, {Start: 4, End: 4}, }, slices.Collect(hist.Backward())) // create a new range at the front require.True(t, hist.ReceivedPacket(1)) require.Equal(t, []interval{ {Start: 10, End: 10}, {Start: 7, End: 7}, {Start: 4, End: 4}, {Start: 1, End: 1}, }, slices.Collect(hist.Backward())) // extend an existing range at the end require.True(t, hist.ReceivedPacket(8)) require.Equal(t, []interval{ {Start: 10, End: 10}, {Start: 7, End: 8}, {Start: 4, End: 4}, {Start: 1, End: 1}, }, slices.Collect(hist.Backward())) // extend an existing range at the front require.True(t, hist.ReceivedPacket(6)) require.Equal(t, []interval{ {Start: 10, End: 10}, {Start: 6, End: 8}, {Start: 4, End: 4}, {Start: 1, End: 1}, }, slices.Collect(hist.Backward())) // close a range require.True(t, hist.ReceivedPacket(9)) require.Equal(t, []interval{ {Start: 6, End: 10}, {Start: 4, End: 4}, {Start: 1, End: 1}, }, slices.Collect(hist.Backward())) } func TestReceivedPacketHistoryMaxNumAckRanges(t *testing.T) { hist := newReceivedPacketHistory() for i := range protocol.MaxNumAckRanges { require.True(t, hist.ReceivedPacket(protocol.PacketNumber(2*i))) } require.Len(t, hist.ranges, protocol.MaxNumAckRanges) require.Equal(t, interval{Start: 0, End: 0}, hist.ranges[0]) hist.ReceivedPacket(2*protocol.MaxNumAckRanges + 1000) // check that the oldest ACK range was deleted require.Len(t, hist.ranges, protocol.MaxNumAckRanges) require.Equal(t, interval{Start: 2, End: 2}, hist.ranges[0]) } func TestReceivedPacketHistoryDeleteBelow(t *testing.T) { hist := newReceivedPacketHistory() hist.DeleteBelow(2) require.Empty(t, slices.Collect(hist.Backward())) require.True(t, hist.ReceivedPacket(2)) require.True(t, hist.ReceivedPacket(4)) require.True(t, hist.ReceivedPacket(5)) require.True(t, hist.ReceivedPacket(6))
require.True(t, hist.ReceivedPacket(10)) require.Equal(t, protocol.PacketNumber(3), hist.HighestMissingUpTo(6)) hist.DeleteBelow(6) require.Equal(t, protocol.InvalidPacketNumber, hist.HighestMissingUpTo(6)) require.Equal(t, protocol.PacketNumber(9), hist.HighestMissingUpTo(10)) require.Equal(t, []interval{ {Start: 10, End: 10}, {Start: 6, End: 6}, }, slices.Collect(hist.Backward())) // deleting from an existing range require.True(t, hist.ReceivedPacket(7)) require.True(t, hist.ReceivedPacket(8)) hist.DeleteBelow(7) require.Equal(t, []interval{ {Start: 10, End: 10}, {Start: 7, End: 8}, }, slices.Collect(hist.Backward())) // keep a one-packet range hist.DeleteBelow(10) require.Equal(t, []interval{{Start: 10, End: 10}}, slices.Collect(hist.Backward())) // delayed packets below deleted ranges are ignored require.False(t, hist.ReceivedPacket(5)) require.Equal(t, []interval{{Start: 10, End: 10}}, slices.Collect(hist.Backward())) } func TestReceivedPacketHistoryDuplicateDetection(t *testing.T) { hist := newReceivedPacketHistory() require.False(t, hist.IsPotentiallyDuplicate(5)) require.True(t, hist.ReceivedPacket(4)) require.True(t, hist.ReceivedPacket(5)) require.True(t, hist.ReceivedPacket(6)) require.True(t, hist.ReceivedPacket(8)) require.True(t, hist.ReceivedPacket(9)) require.False(t, hist.IsPotentiallyDuplicate(3)) require.True(t, hist.IsPotentiallyDuplicate(4)) require.True(t, hist.IsPotentiallyDuplicate(5)) require.True(t, hist.IsPotentiallyDuplicate(6)) require.False(t, hist.IsPotentiallyDuplicate(7)) require.True(t, hist.IsPotentiallyDuplicate(8)) require.True(t, hist.IsPotentiallyDuplicate(9)) require.False(t, hist.IsPotentiallyDuplicate(10)) // delete and check for potential duplicates hist.DeleteBelow(8) require.True(t, hist.IsPotentiallyDuplicate(7)) require.True(t, hist.IsPotentiallyDuplicate(8)) require.True(t, hist.IsPotentiallyDuplicate(9)) require.False(t, hist.IsPotentiallyDuplicate(10)) } func TestReceivedPacketHistoryRandomized(t *testing.T) { hist
:= newReceivedPacketHistory() packets := make(map[protocol.PacketNumber]struct{}) const num = 2 * protocol.MaxNumAckRanges numLostPackets := rand.IntN(protocol.MaxNumAckRanges) numRcvdPackets := num - numLostPackets for i := range num { packets[protocol.PacketNumber(i)] = struct{}{} } lostPackets := make([]protocol.PacketNumber, 0, numLostPackets) for len(lostPackets) < numLostPackets { p := protocol.PacketNumber(rand.IntN(num - 1)) // lose a random packet, but not the last one if _, ok := packets[p]; ok { lostPackets = append(lostPackets, p) delete(packets, p) } } slices.Sort(lostPackets) t.Logf("Losing packets: %v", lostPackets) ordered := make([]protocol.PacketNumber, 0, numRcvdPackets) for p := range packets { ordered = append(ordered, p) } rand.Shuffle(len(ordered), func(i, j int) { ordered[i], ordered[j] = ordered[j], ordered[i] }) t.Logf("Receiving packets: %v", ordered) for i, p := range ordered { require.True(t, hist.ReceivedPacket(p)) // sometimes receive a duplicate if i > 0 && rand.Int()%5 == 0 { require.False(t, hist.ReceivedPacket(ordered[rand.IntN(i)])) } } var counter int ackRanges := slices.Collect(hist.Backward()) t.Logf("ACK ranges: %v", ackRanges) require.LessOrEqual(t, len(ackRanges), numLostPackets+1) for _, ackRange := range ackRanges { for p := ackRange.Start; p <= ackRange.End; p++ { counter++ require.Contains(t, packets, p) } } require.Equal(t, numRcvdPackets, counter) deletedBelow := protocol.PacketNumber(rand.IntN(num * 2 / 3)) t.Logf("Deleting below %d", deletedBelow) hist.DeleteBelow(deletedBelow) for pn := range protocol.PacketNumber(num) { if pn < deletedBelow { require.Equal(t, protocol.InvalidPacketNumber, hist.HighestMissingUpTo(pn)) continue } expected := protocol.InvalidPacketNumber for _, lost := range lostPackets { if lost < deletedBelow { continue } if lost > pn { break } expected = lost } hm := hist.HighestMissingUpTo(pn) require.Equalf(t, expected, hm, "highest missing up to %d: %d", pn, hm) } } func
BenchmarkHistoryReceiveSequentialPackets(b *testing.B) { hist := newReceivedPacketHistory() var pn protocol.PacketNumber for b.Loop() { hist.ReceivedPacket(pn) pn++ } } // Packets are received sequentially, with occasional gaps func BenchmarkHistoryReceiveCommonCase(b *testing.B) { hist := newReceivedPacketHistory() var pn protocol.PacketNumber for b.Loop() { hist.ReceivedPacket(pn) pn++ if pn%2000 == 0 { pn += 4 } } } func BenchmarkHistoryReceiveSequentialPacketsWithGaps(b *testing.B) { hist := newReceivedPacketHistory() var pn protocol.PacketNumber for b.Loop() { hist.ReceivedPacket(pn) pn += 2 } } func BenchmarkHistoryReceiveReversePacketsWithGaps(b *testing.B) { hist := newReceivedPacketHistory() for i := 0; i < b.N; i++ { hist.ReceivedPacket(protocol.PacketNumber(2 * (b.N - i))) } } func BenchmarkHistoryIsDuplicate(b *testing.B) { b.ReportAllocs() hist := newReceivedPacketHistory() var pn protocol.PacketNumber for range protocol.MaxNumAckRanges { for range 5 { hist.ReceivedPacket(pn) pn++ } pn += 5 // create a gap } var p protocol.PacketNumber for b.Loop() { hist.IsPotentiallyDuplicate(p % pn) p++ } } quic-go-0.59.0/internal/ackhandler/received_packet_tracker.go000066400000000000000000000153311513066070600242470ustar00rootroot00000000000000package ackhandler import ( "fmt" "time" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" ) const reorderingThreshold = 1 // The receivedPacketTracker tracks packets for the Initial and Handshake packet number space. // Every received packet is acknowledged immediately.
type receivedPacketTracker struct { ect0, ect1, ecnce uint64 packetHistory receivedPacketHistory lastAck *wire.AckFrame hasNewAck bool // true as soon as we received an ack-eliciting new packet } func newReceivedPacketTracker() *receivedPacketTracker { return &receivedPacketTracker{packetHistory: *newReceivedPacketHistory()} } func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, ackEliciting bool) error { if isNew := h.packetHistory.ReceivedPacket(pn); !isNew { return fmt.Errorf("receivedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", pn) } //nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECN-CE. switch ecn { case protocol.ECT0: h.ect0++ case protocol.ECT1: h.ect1++ case protocol.ECNCE: h.ecnce++ } if !ackEliciting { return nil } h.hasNewAck = true return nil } func (h *receivedPacketTracker) GetAckFrame() *wire.AckFrame { if !h.hasNewAck { return nil } // This function always returns the same ACK frame struct, filled with the most recent values. ack := h.lastAck if ack == nil { ack = &wire.AckFrame{} } ack.Reset() ack.ECT0 = h.ect0 ack.ECT1 = h.ect1 ack.ECNCE = h.ecnce for r := range h.packetHistory.Backward() { ack.AckRanges = append(ack.AckRanges, wire.AckRange{Smallest: r.Start, Largest: r.End}) } h.lastAck = ack h.hasNewAck = false return ack } func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool { return h.packetHistory.IsPotentiallyDuplicate(pn) } // number of ack-eliciting packets received before sending an ACK const packetsBeforeAck = 2 // The appDataReceivedPacketTracker tracks packets received in the Application Data packet number space. // It waits until at least 2 packets were received before queueing an ACK, or until the max_ack_delay was reached. 
type appDataReceivedPacketTracker struct { receivedPacketTracker largestObservedRcvdTime monotime.Time largestObserved protocol.PacketNumber ignoreBelow protocol.PacketNumber maxAckDelay time.Duration ackQueued bool // true if we need send a new ACK ackElicitingPacketsReceivedSinceLastAck int ackAlarm monotime.Time logger utils.Logger } func newAppDataReceivedPacketTracker(logger utils.Logger) *appDataReceivedPacketTracker { h := &appDataReceivedPacketTracker{ receivedPacketTracker: *newReceivedPacketTracker(), maxAckDelay: protocol.MaxAckDelay, logger: logger, } return h } func (h *appDataReceivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime monotime.Time, ackEliciting bool) error { if err := h.receivedPacketTracker.ReceivedPacket(pn, ecn, ackEliciting); err != nil { return err } if pn >= h.largestObserved { h.largestObserved = pn h.largestObservedRcvdTime = rcvTime } if !ackEliciting { return nil } h.ackElicitingPacketsReceivedSinceLastAck++ isMissing := h.isMissing(pn) if !h.ackQueued && h.shouldQueueACK(pn, ecn, isMissing) { h.ackQueued = true h.ackAlarm = 0 // cancel the ack alarm } if !h.ackQueued { // No ACK queued, but we'll need to acknowledge the packet after max_ack_delay. h.ackAlarm = rcvTime.Add(h.maxAckDelay) if h.logger.Debug() { h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay) } } return nil } // IgnoreBelow sets a lower limit for acknowledging packets. // Packets with packet numbers smaller than p will not be acked. func (h *appDataReceivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) { if pn <= h.ignoreBelow { return } h.ignoreBelow = pn h.packetHistory.DeleteBelow(pn) if h.logger.Debug() { h.logger.Debugf("\tIgnoring all packets below %d.", pn) } } // isMissing says if a packet was reported missing in the last ACK. 
func (h *appDataReceivedPacketTracker) isMissing(p protocol.PacketNumber) bool { if h.lastAck == nil || p < h.ignoreBelow { return false } return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p) } func (h *appDataReceivedPacketTracker) hasNewMissingPackets() bool { if h.lastAck == nil { return false } if h.largestObserved < reorderingThreshold { return false } highestMissing := h.packetHistory.HighestMissingUpTo(h.largestObserved - reorderingThreshold) if highestMissing == protocol.InvalidPacketNumber { return false } if highestMissing < h.lastAck.LargestAcked() { // the packet was already reported missing in the last ACK return false } return highestMissing > h.lastAck.LargestAcked()-reorderingThreshold } func (h *appDataReceivedPacketTracker) shouldQueueACK(pn protocol.PacketNumber, ecn protocol.ECN, wasMissing bool) bool { // Send an ACK if this packet was reported missing in an ACK sent before. // Ack decimation with reordering relies on the timer to send an ACK, but if // missing packets we reported in the previous ACK, send an ACK immediately. 
if wasMissing { if h.logger.Debug() { h.logger.Debugf("\tQueueing ACK because packet %d was missing before.", pn) } return true } // send an ACK every 2 ack-eliciting packets if h.ackElicitingPacketsReceivedSinceLastAck >= packetsBeforeAck { if h.logger.Debug() { h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.ackElicitingPacketsReceivedSinceLastAck, packetsBeforeAck) } return true } // queue an ACK if there are new missing packets to report if h.hasNewMissingPackets() { h.logger.Debugf("\tQueuing ACK because there's a new missing packet to report.") return true } // queue an ACK if the packet was ECN-CE marked if ecn == protocol.ECNCE { h.logger.Debugf("\tQueuing ACK because the packet was ECN-CE marked.") return true } return false } func (h *appDataReceivedPacketTracker) GetAckFrame(now monotime.Time, onlyIfQueued bool) *wire.AckFrame { if onlyIfQueued && !h.ackQueued { if h.ackAlarm.IsZero() || h.ackAlarm.After(now) { return nil } if h.logger.Debug() && !h.ackAlarm.IsZero() { h.logger.Debugf("Sending ACK because the ACK timer expired.") } } ack := h.receivedPacketTracker.GetAckFrame() if ack == nil { return nil } ack.DelayTime = max(0, now.Sub(h.largestObservedRcvdTime)) h.ackQueued = false h.ackAlarm = 0 h.ackElicitingPacketsReceivedSinceLastAck = 0 return ack } func (h *appDataReceivedPacketTracker) GetAlarmTimeout() monotime.Time { return h.ackAlarm } quic-go-0.59.0/internal/ackhandler/received_packet_tracker_test.go000066400000000000000000000150551513066070600253110ustar00rootroot00000000000000package ackhandler import ( "testing" "time" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/require" ) func TestReceivedPacketTrackerGenerateACKs(t *testing.T) { tracker := newReceivedPacketTracker() require.NoError(t, 
tracker.ReceivedPacket(protocol.PacketNumber(3), protocol.ECNNon, true)) ack := tracker.GetAckFrame() require.NotNil(t, ack) require.Equal(t, []wire.AckRange{{Smallest: 3, Largest: 3}}, ack.AckRanges) require.Zero(t, ack.DelayTime) require.NoError(t, tracker.ReceivedPacket(protocol.PacketNumber(4), protocol.ECNNon, true)) ack = tracker.GetAckFrame() require.NotNil(t, ack) require.Equal(t, []wire.AckRange{{Smallest: 3, Largest: 4}}, ack.AckRanges) require.Zero(t, ack.DelayTime) require.NoError(t, tracker.ReceivedPacket(protocol.PacketNumber(1), protocol.ECNNon, true)) ack = tracker.GetAckFrame() require.NotNil(t, ack) require.Equal(t, []wire.AckRange{ {Smallest: 3, Largest: 4}, {Smallest: 1, Largest: 1}, }, ack.AckRanges) require.Zero(t, ack.DelayTime) // non-ack-eliciting packets don't trigger ACKs require.NoError(t, tracker.ReceivedPacket(protocol.PacketNumber(10), protocol.ECNNon, false)) require.Nil(t, tracker.GetAckFrame()) require.NoError(t, tracker.ReceivedPacket(protocol.PacketNumber(11), protocol.ECNNon, true)) ack = tracker.GetAckFrame() require.NotNil(t, ack) require.Equal(t, []wire.AckRange{ {Smallest: 10, Largest: 11}, {Smallest: 3, Largest: 4}, {Smallest: 1, Largest: 1}, }, ack.AckRanges) } func TestAppDataReceivedPacketTrackerECN(t *testing.T) { tr := newAppDataReceivedPacketTracker(utils.DefaultLogger) require.NoError(t, tr.ReceivedPacket(0, protocol.ECT0, monotime.Now(), true)) pn := protocol.PacketNumber(1) for range 2 { require.NoError(t, tr.ReceivedPacket(pn, protocol.ECT1, monotime.Now(), true)) pn++ } for range 3 { require.NoError(t, tr.ReceivedPacket(pn, protocol.ECNCE, monotime.Now(), true)) pn++ } ack := tr.GetAckFrame(monotime.Now(), false) require.Equal(t, uint64(1), ack.ECT0) require.Equal(t, uint64(2), ack.ECT1) require.Equal(t, uint64(3), ack.ECNCE) } func TestAppDataReceivedPacketTrackerAckEverySecondPacket(t *testing.T) { tr := newAppDataReceivedPacketTracker(utils.DefaultLogger) require.Nil(t, tr.GetAckFrame(monotime.Now(), true)) 
for p := protocol.PacketNumber(1); p <= 20; p++ { require.NoError(t, tr.ReceivedPacket(p, protocol.ECNNon, monotime.Now(), true)) switch p % 2 { case 0: require.NotNil(t, tr.GetAckFrame(monotime.Now(), true)) case 1: require.Nil(t, tr.GetAckFrame(monotime.Now(), true)) } } } func TestAppDataReceivedPacketTrackerAlarmTimeout(t *testing.T) { tr := newAppDataReceivedPacketTracker(utils.DefaultLogger) now := monotime.Now() require.NoError(t, tr.ReceivedPacket(1, protocol.ECNNon, now, false)) require.Nil(t, tr.GetAckFrame(monotime.Now(), true)) require.Zero(t, tr.GetAlarmTimeout()) rcvTime := now.Add(10 * time.Millisecond) require.NoError(t, tr.ReceivedPacket(2, protocol.ECNNon, rcvTime, true)) require.Equal(t, rcvTime.Add(protocol.MaxAckDelay), tr.GetAlarmTimeout()) require.Nil(t, tr.GetAckFrame(monotime.Now(), true)) // no timeout after the ACK has been dequeued require.NotNil(t, tr.GetAckFrame(monotime.Now(), false)) require.Zero(t, tr.GetAlarmTimeout()) } func TestAppDataReceivedPacketTrackerQueuesECNCE(t *testing.T) { tr := newAppDataReceivedPacketTracker(utils.DefaultLogger) require.NoError(t, tr.ReceivedPacket(1, protocol.ECNCE, monotime.Now(), true)) ack := tr.GetAckFrame(monotime.Now(), true) require.NotNil(t, ack) require.Equal(t, protocol.PacketNumber(1), ack.LargestAcked()) require.EqualValues(t, 1, ack.ECNCE) } func TestAppDataReceivedPacketTrackerMissingPackets(t *testing.T) { tr := newAppDataReceivedPacketTracker(utils.DefaultLogger) now := monotime.Now() require.NoError(t, tr.ReceivedPacket(0, protocol.ECNNon, now, true)) require.Nil(t, tr.GetAckFrame(now, true)) require.NoError(t, tr.ReceivedPacket(5, protocol.ECNNon, now, true)) ack := tr.GetAckFrame(now, true) // ACK: 0 and 5, missing: 1, 2, 3, 4 require.NotNil(t, ack) require.Equal(t, []wire.AckRange{{Smallest: 5, Largest: 5}, {Smallest: 0, Largest: 0}}, ack.AckRanges) // now receive one of the missing packets require.NoError(t, tr.ReceivedPacket(3, protocol.ECNNon, now, true)) ack = 
tr.GetAckFrame(now, true) require.NotNil(t, ack) require.Equal(t, []wire.AckRange{ {Smallest: 5, Largest: 5}, {Smallest: 3, Largest: 3}, {Smallest: 0, Largest: 0}, }, ack.AckRanges) require.NoError(t, tr.ReceivedPacket(6, protocol.ECNNon, now, true)) require.Nil(t, tr.GetAckFrame(now, true)) require.NoError(t, tr.ReceivedPacket(8, protocol.ECNNon, now, true)) require.NotNil(t, tr.GetAckFrame(now, true)) } func TestAppDataReceivedPacketTrackerDelayTime(t *testing.T) { tr := newAppDataReceivedPacketTracker(utils.DefaultLogger) now := monotime.Now() require.NoError(t, tr.ReceivedPacket(1, protocol.ECNNon, now, true)) require.NoError(t, tr.ReceivedPacket(2, protocol.ECNNon, now.Add(-1337*time.Millisecond), true)) ack := tr.GetAckFrame(now, true) require.NotNil(t, ack) require.Equal(t, 1337*time.Millisecond, ack.DelayTime) // don't use a negative delay time require.NoError(t, tr.ReceivedPacket(3, protocol.ECNNon, now.Add(time.Hour), true)) ack = tr.GetAckFrame(now, false) require.NotNil(t, ack) require.Zero(t, ack.DelayTime) } func TestAppDataReceivedPacketTrackerIgnoreBelow(t *testing.T) { tr := newAppDataReceivedPacketTracker(utils.DefaultLogger) tr.IgnoreBelow(4) // check that packets below 7 are considered duplicates require.True(t, tr.IsPotentiallyDuplicate(3)) require.False(t, tr.IsPotentiallyDuplicate(4)) for i := 5; i <= 10; i++ { require.NoError(t, tr.ReceivedPacket(protocol.PacketNumber(i), protocol.ECNNon, monotime.Now(), true)) } ack := tr.GetAckFrame(monotime.Now(), true) require.NotNil(t, ack) require.Equal(t, []wire.AckRange{{Smallest: 5, Largest: 10}}, ack.AckRanges) tr.IgnoreBelow(7) require.NoError(t, tr.ReceivedPacket(11, protocol.ECNNon, monotime.Now(), true)) require.NoError(t, tr.ReceivedPacket(12, protocol.ECNNon, monotime.Now(), true)) ack = tr.GetAckFrame(monotime.Now(), true) require.NotNil(t, ack) require.Equal(t, []wire.AckRange{{Smallest: 7, Largest: 12}}, ack.AckRanges) // make sure that old packets are not accepted require.ErrorContains(t, 
tr.ReceivedPacket(4, protocol.ECNNon, monotime.Now(), true), "receivedPacketTracker BUG: ReceivedPacket called for old / duplicate packet 4", ) } quic-go-0.59.0/internal/ackhandler/send_mode.go000066400000000000000000000023311513066070600213500ustar00rootroot00000000000000package ackhandler import "fmt" // The SendMode says what kind of packets can be sent. type SendMode uint8 const ( // SendNone means that no packets should be sent SendNone SendMode = iota // SendAck means an ACK-only packet should be sent SendAck // SendPTOInitial means that an Initial probe packet should be sent SendPTOInitial // SendPTOHandshake means that a Handshake probe packet should be sent SendPTOHandshake // SendPTOAppData means that an Application data probe packet should be sent SendPTOAppData // SendPacingLimited means that the pacer doesn't allow sending of a packet right now, // but will do in a little while. // The timestamp when sending is allowed again can be obtained via the SentPacketHandler.TimeUntilSend. 
SendPacingLimited // SendAny means that any packet should be sent SendAny ) func (s SendMode) String() string { switch s { case SendNone: return "none" case SendAck: return "ack" case SendPTOInitial: return "pto (Initial)" case SendPTOHandshake: return "pto (Handshake)" case SendPTOAppData: return "pto (Application Data)" case SendAny: return "any" case SendPacingLimited: return "pacing limited" default: return fmt.Sprintf("invalid send mode: %d", s) } } quic-go-0.59.0/internal/ackhandler/send_mode_test.go000066400000000000000000000011071513066070600224070ustar00rootroot00000000000000package ackhandler import ( "testing" "github.com/stretchr/testify/require" ) func TestSendModeStringer(t *testing.T) { require.Equal(t, "none", SendNone.String()) require.Equal(t, "any", SendAny.String()) require.Equal(t, "pacing limited", SendPacingLimited.String()) require.Equal(t, "ack", SendAck.String()) require.Equal(t, "pto (Initial)", SendPTOInitial.String()) require.Equal(t, "pto (Handshake)", SendPTOHandshake.String()) require.Equal(t, "pto (Application Data)", SendPTOAppData.String()) require.Equal(t, "invalid send mode: 123", SendMode(123).String()) } quic-go-0.59.0/internal/ackhandler/sent_packet_handler.go000066400000000000000000001111411513066070600234100ustar00rootroot00000000000000package ackhandler import ( "errors" "fmt" "time" "github.com/quic-go/quic-go/internal/congestion" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/qlogwriter" ) const ( // Maximum reordering in time space before time based loss detection considers a packet lost. // Specified as an RTT multiplier. timeThreshold = 9.0 / 8 // Maximum reordering in packets before packet threshold loss detection considers a packet lost. 
packetThreshold = 3 // Before validating the client's address, the server won't send more than 3x bytes than it received. amplificationFactor = 3 // We use Retry packets to derive an RTT estimate. Make sure we don't set the RTT to a super low value yet. minRTTAfterRetry = 5 * time.Millisecond // The PTO duration uses exponential backoff, but is truncated to a maximum value, as allowed by RFC 8961, section 4.4. maxPTODuration = 60 * time.Second ) // Path probe packets are declared lost after this time. const pathProbePacketLossTimeout = time.Second type packetNumberSpace struct { history sentPacketHistory pns packetNumberGenerator lossTime monotime.Time lastAckElicitingPacketTime monotime.Time largestAcked protocol.PacketNumber largestSent protocol.PacketNumber } func newPacketNumberSpace(initialPN protocol.PacketNumber, isAppData bool) *packetNumberSpace { var pns packetNumberGenerator if isAppData { pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketInitialPeriod, protocol.SkipPacketMaxPeriod) } else { pns = newSequentialPacketNumberGenerator(initialPN) } return &packetNumberSpace{ history: *newSentPacketHistory(isAppData), pns: pns, largestSent: protocol.InvalidPacketNumber, largestAcked: protocol.InvalidPacketNumber, } } type alarmTimer struct { Time monotime.Time TimerType qlog.TimerType EncryptionLevel protocol.EncryptionLevel } type sentPacketHandler struct { initialPackets *packetNumberSpace handshakePackets *packetNumberSpace appDataPackets *packetNumberSpace lostPackets lostPacketTracker // only for application-data packet number space // send time of the largest acknowledged packet, across all packet number spaces largestAckedTime monotime.Time // Do we know that the peer completed address validation yet? // Always true for the server. peerCompletedAddressValidation bool bytesReceived protocol.ByteCount bytesSent protocol.ByteCount // Have we validated the peer's address yet? // Always true for the client. 
peerAddressValidated bool handshakeConfirmed bool ignorePacketsBelow func(protocol.PacketNumber) ackedPackets []packetWithPacketNumber // to avoid allocations in detectAndRemoveAckedPackets bytesInFlight protocol.ByteCount congestion congestion.SendAlgorithmWithDebugInfos rttStats *utils.RTTStats connStats *utils.ConnectionStats // The number of times a PTO has been sent without receiving an ack. ptoCount uint32 ptoMode SendMode // The number of PTO probe packets that should be sent. // Only applies to the application-data packet number space. numProbesToSend int // The alarm timeout alarm alarmTimer enableECN bool ecnTracker ecnHandler perspective protocol.Perspective qlogger qlogwriter.Recorder lastMetrics qlog.MetricsUpdated logger utils.Logger } var _ SentPacketHandler = &sentPacketHandler{} // clientAddressValidated indicates whether the address was validated beforehand by an address validation token. // If the address was validated, the amplification limit doesn't apply. It has no effect for a client. 
func NewSentPacketHandler(
	initialPN protocol.PacketNumber,
	initialMaxDatagramSize protocol.ByteCount,
	rttStats *utils.RTTStats,
	connStats *utils.ConnectionStats,
	clientAddressValidated bool,
	enableECN bool,
	ignorePacketsBelow func(protocol.PacketNumber),
	pers protocol.Perspective,
	qlogger qlogwriter.Recorder,
	logger utils.Logger,
) SentPacketHandler {
	// Note: this local variable shadows the congestion package for the rest of
	// the function; the package itself is only referenced on the right-hand side.
	congestion := congestion.NewCubicSender(
		congestion.DefaultClock{},
		rttStats,
		connStats,
		initialMaxDatagramSize,
		true, // use Reno
		qlogger,
	)

	h := &sentPacketHandler{
		// A server never awaits address validation by its peer, and a client
		// never applies the amplification limit.
		peerCompletedAddressValidation: pers == protocol.PerspectiveServer,
		peerAddressValidated:           pers == protocol.PerspectiveClient || clientAddressValidated,
		initialPackets:                 newPacketNumberSpace(initialPN, false),
		handshakePackets:               newPacketNumberSpace(0, false),
		appDataPackets:                 newPacketNumberSpace(0, true),
		// NOTE(review): 64 is presumably a size limit of the lost packet tracker — confirm in newLostPacketTracker.
		lostPackets:        *newLostPacketTracker(64),
		rttStats:           rttStats,
		connStats:          connStats,
		congestion:         congestion,
		ignorePacketsBelow: ignorePacketsBelow,
		perspective:        pers,
		qlogger:            qlogger,
		logger:             logger,
	}
	// The ECN tracker is only created when ECN is enabled; all users check it for nil.
	if enableECN {
		h.enableECN = true
		h.ecnTracker = newECNTracker(logger, qlogger)
	}
	return h
}

// removeFromBytesInFlight subtracts p's length from bytes_in_flight if p is still
// counted there, and marks p as no longer counted, so repeated calls for the same
// packet are harmless. It panics if bytes_in_flight would become negative, which
// indicates a bookkeeping bug.
func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) {
	if p.includedInBytesInFlight {
		if p.Length > h.bytesInFlight {
			panic("negative bytes_in_flight")
		}
		h.bytesInFlight -= p.Length
		p.includedInBytesInFlight = false
	}
}

// DropPackets discards the state kept for encLevel:
//   - Initial / Handshake: outstanding packets are removed from bytes_in_flight
//     and the packet number space is set to nil (a second call is a no-op).
//     Dropping Handshake also marks the handshake as confirmed.
//   - 0-RTT (only called when 0-RTT is rejected): the leading run of 0-RTT packets
//     is deleted from the application-data history and from bytes_in_flight.
//
// Any other level panics. Afterwards the PTO state is reset and the loss
// detection timer is re-armed.
func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel, now monotime.Time) {
	// The server won't await address validation after the handshake is confirmed.
	// This applies even if we didn't receive an ACK for a Handshake packet.
	if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake {
		h.peerCompletedAddressValidation = true
	}
	// remove outstanding packets from bytes_in_flight
	if encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake {
		pnSpace := h.getPacketNumberSpace(encLevel)
		// We might already have dropped this packet number space.
		if pnSpace == nil {
			return
		}
		for _, p := range pnSpace.history.Packets() {
			h.removeFromBytesInFlight(p)
		}
	}
	// drop the packet history
	//nolint:exhaustive // Not every packet number space can be dropped.
	switch encLevel {
	case protocol.EncryptionInitial:
		h.initialPackets = nil
	case protocol.EncryptionHandshake:
		// Dropping the handshake packet number space means that the handshake is confirmed,
		// see section 4.9.2 of RFC 9001.
		h.handshakeConfirmed = true
		h.handshakePackets = nil
	case protocol.Encryption0RTT:
		// This function is only called when 0-RTT is rejected,
		// and not when the client drops 0-RTT keys when the handshake completes.
		// When 0-RTT is rejected, all application data sent so far becomes invalid.
		// Delete the packets from the history and remove them from bytes_in_flight.
		for pn, p := range h.appDataPackets.history.Packets() {
			// 0-RTT packets precede all 1-RTT packets, so stop at the first non-0-RTT one.
			if p.EncryptionLevel != protocol.Encryption0RTT {
				break
			}
			h.removeFromBytesInFlight(p)
			// NOTE(review): the error returned by Remove is ignored here.
			h.appDataPackets.history.Remove(pn)
		}
	default:
		panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
	}
	if h.qlogger != nil && h.ptoCount != 0 {
		h.qlogger.RecordEvent(qlog.PTOCountUpdated{PTOCount: 0})
	}
	h.ptoCount = 0
	h.numProbesToSend = 0
	h.ptoMode = SendNone
	h.setLossDetectionTimer(now)
}

// ReceivedBytes accounts for n bytes received from the peer. If this lifts the
// amplification limit, the loss detection timer is re-armed (it is cancelled
// while the handler is amplification limited).
func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount, t monotime.Time) {
	h.connStats.BytesReceived.Add(uint64(n))
	wasAmplificationLimit := h.isAmplificationLimited()
	h.bytesReceived += n
	if wasAmplificationLimit && !h.isAmplificationLimited() {
		h.setLossDetectionTimer(t)
	}
}

// ReceivedPacket counts a received packet. For a server, receiving a Handshake
// packet validates the client's address, which lifts the amplification limit.
func (h *sentPacketHandler) ReceivedPacket(l protocol.EncryptionLevel, t monotime.Time) {
	h.connStats.PacketsReceived.Add(1)
	if h.perspective == protocol.PerspectiveServer && l == protocol.EncryptionHandshake && !h.peerAddressValidated {
		h.peerAddressValidated = true
		h.setLossDetectionTimer(t)
	}
}

// packetsInFlight returns the number of outstanding packets across all packet
// number spaces that have not been dropped yet. Used for the qlog metrics.
func (h *sentPacketHandler) packetsInFlight() int {
	packetsInFlight := h.appDataPackets.history.NumOutstanding()
	if h.handshakePackets != nil {
		packetsInFlight += h.handshakePackets.history.NumOutstanding()
	}
	if h.initialPackets != nil {
		packetsInFlight += h.initialPackets.history.NumOutstanding()
	}
	return packetsInFlight
}

// SentPacket records a packet that was just sent at encLevel.
// Path probe packets are stored separately in the history and bypass
// bytes_in_flight and the congestion controller. Ack-eliciting packets are added
// to bytes_in_flight, update the PTO base time of their packet number space and
// consume one pending probe packet.
func (h *sentPacketHandler) SentPacket(
	t monotime.Time,
	pn, largestAcked protocol.PacketNumber,
	streamFrames []StreamFrame,
	frames []Frame,
	encLevel protocol.EncryptionLevel,
	ecn protocol.ECN,
	size protocol.ByteCount,
	isPathMTUProbePacket bool,
	isPathProbePacket bool,
) {
	h.bytesSent += size
	h.connStats.BytesSent.Add(uint64(size))
	h.connStats.PacketsSent.Add(1)

	pnSpace := h.getPacketNumberSpace(encLevel)
	// Log the packet numbers that were skipped between the last sent packet and this one.
	if h.logger.Debug() && (pnSpace.history.HasOutstandingPackets() || pnSpace.history.HasOutstandingPathProbes()) {
		for p := max(0, pnSpace.largestSent+1); p < pn; p++ {
			h.logger.Debugf("Skipping packet number %d", p)
		}
	}

	pnSpace.largestSent = pn

	p := getPacket()
	p.SendTime = t
	p.EncryptionLevel = encLevel
	p.Length = size
	p.Frames = frames
	p.LargestAcked = largestAcked
	p.StreamFrames = streamFrames
	p.IsPathMTUProbePacket = isPathMTUProbePacket
	p.isPathProbePacket = isPathProbePacket
	isAckEliciting := p.IsAckEliciting()

	if isPathProbePacket {
		pnSpace.history.SentPathProbePacket(pn, p)
		h.setLossDetectionTimer(t)
		return
	}
	if isAckEliciting {
		pnSpace.lastAckElicitingPacketTime = t
		h.bytesInFlight += size
		p.includedInBytesInFlight = true
		if h.numProbesToSend > 0 {
			h.numProbesToSend--
		}
	}
	h.congestion.OnPacketSent(t, h.bytesInFlight, pn, size, isAckEliciting)

	// ECN is only tracked for 1-RTT packets.
	if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil {
		h.ecnTracker.SentPacket(pn, ecn)
	}

	pnSpace.history.SentPacket(pn, p)
	if !isAckEliciting {
		// Non-ack-eliciting packets only affect the timer while the peer might
		// still be awaiting address validation.
		if !h.peerCompletedAddressValidation {
			h.setLossDetectionTimer(t)
		}
		return
	}

	if h.qlogger != nil {
		h.qlogMetricsUpdated()
	}
	h.setLossDetectionTimer(t)
}

// qlogMetricsUpdated records a qlog MetricsUpdated event that contains only the
// metrics that changed since the previous event (tracked in h.lastMetrics).
// RTT metrics are only considered once an RTT measurement exists.
// Callers must ensure that h.qlogger is non-nil.
func (h *sentPacketHandler) qlogMetricsUpdated() {
	var metricsUpdatedEvent qlog.MetricsUpdated
	var updated bool
	if h.rttStats.HasMeasurement() {
		if h.lastMetrics.MinRTT != h.rttStats.MinRTT() {
			metricsUpdatedEvent.MinRTT = h.rttStats.MinRTT()
			h.lastMetrics.MinRTT = metricsUpdatedEvent.MinRTT
			updated = true
		}
		if h.lastMetrics.SmoothedRTT != h.rttStats.SmoothedRTT() {
			metricsUpdatedEvent.SmoothedRTT = h.rttStats.SmoothedRTT()
			h.lastMetrics.SmoothedRTT = metricsUpdatedEvent.SmoothedRTT
			updated = true
		}
		if h.lastMetrics.LatestRTT != h.rttStats.LatestRTT() {
			metricsUpdatedEvent.LatestRTT = h.rttStats.LatestRTT()
			h.lastMetrics.LatestRTT = metricsUpdatedEvent.LatestRTT
			updated = true
		}
		if h.lastMetrics.RTTVariance != h.rttStats.MeanDeviation() {
			metricsUpdatedEvent.RTTVariance = h.rttStats.MeanDeviation()
			h.lastMetrics.RTTVariance = metricsUpdatedEvent.RTTVariance
			updated = true
		}
	}
	if cwnd := h.congestion.GetCongestionWindow(); h.lastMetrics.CongestionWindow != int(cwnd) {
		metricsUpdatedEvent.CongestionWindow = int(cwnd)
		h.lastMetrics.CongestionWindow = metricsUpdatedEvent.CongestionWindow
		updated = true
	}
	if h.lastMetrics.BytesInFlight != int(h.bytesInFlight) {
		metricsUpdatedEvent.BytesInFlight = int(h.bytesInFlight)
		h.lastMetrics.BytesInFlight = metricsUpdatedEvent.BytesInFlight
		updated = true
	}
	packetsInFlight := h.packetsInFlight()
	if h.lastMetrics.PacketsInFlight != packetsInFlight {
		metricsUpdatedEvent.PacketsInFlight = packetsInFlight
		h.lastMetrics.PacketsInFlight = metricsUpdatedEvent.PacketsInFlight
		updated = true
	}
	if updated {
		h.qlogger.RecordEvent(metricsUpdatedEvent)
	}
}

// getPacketNumberSpace maps an encryption level to its packet number space.
// 0-RTT and 1-RTT share the application-data space. The Initial and Handshake
// spaces are nil once dropped. Unknown levels panic.
func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLevel) *packetNumberSpace {
	switch encLevel {
	case protocol.EncryptionInitial:
		return h.initialPackets
	case protocol.EncryptionHandshake:
		return h.handshakePackets
	case protocol.Encryption0RTT, protocol.Encryption1RTT:
		return h.appDataPackets
	default:
		panic("invalid packet number space")
	}
}

// ReceivedAck processes an ACK frame received at encLevel. Steps:
//   - reject ACKs for packets that were never sent (PROTOCOL_VIOLATION);
//   - remove the newly acked packets from the history;
//   - update the RTT estimate, the ECN state and the congestion controller;
//   - run loss detection, and spurious loss detection for 1-RTT;
//   - reset the PTO state and re-arm the loss detection timer.
//
// It reports whether any acknowledged packet was a 1-RTT packet.
func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime monotime.Time) (bool /* contained 1-RTT packet */, error) {
	pnSpace := h.getPacketNumberSpace(encLevel)

	largestAcked := ack.LargestAcked()
	if largestAcked > pnSpace.largestSent {
		return false, &qerr.TransportError{
			ErrorCode:    qerr.ProtocolViolation,
			ErrorMessage: "received ACK for an unsent packet",
		}
	}

	// Servers complete address validation when a protected packet is received.
	if h.perspective == protocol.PerspectiveClient && !h.peerCompletedAddressValidation &&
		(encLevel == protocol.EncryptionHandshake || encLevel == protocol.Encryption1RTT) {
		h.peerCompletedAddressValidation = true
		h.logger.Debugf("Peer doesn't await address validation any longer.")
		// Make sure that the timer is reset, even if this ACK doesn't acknowledge any (ack-eliciting) packets.
		h.setLossDetectionTimer(rcvTime)
	}

	// bytes_in_flight before any of the newly acked packets is removed; passed to the congestion controller.
	priorInFlight := h.bytesInFlight
	ackedPackets, hasAckEliciting, err := h.detectAndRemoveAckedPackets(ack, encLevel)
	if err != nil || len(ackedPackets) == 0 {
		return false, err
	}
	// update the RTT, if:
	// * the largest acked is newly acknowledged, AND
	// * at least one new ack-eliciting packet was acknowledged
	// (The length check below is always true here because of the early return above.)
	if len(ackedPackets) > 0 {
		// ackedPackets is sorted ascending, so the last element is the largest newly acked packet.
		if p := ackedPackets[len(ackedPackets)-1]; p.PacketNumber == ack.LargestAcked() && !p.isPathProbePacket && hasAckEliciting {
			// don't use the ack delay for Initial and Handshake packets
			var ackDelay time.Duration
			if encLevel == protocol.Encryption1RTT {
				ackDelay = min(ack.DelayTime, h.rttStats.MaxAckDelay())
			}
			// Only take an RTT sample if this packet wasn't sent before the previously used one.
			if h.largestAckedTime.IsZero() || !p.SendTime.Before(h.largestAckedTime) {
				h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay)
				if h.logger.Debug() {
					h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
				}
				h.largestAckedTime = p.SendTime
			}
			h.congestion.MaybeExitSlowStart()
		}
	}

	// Only inform the ECN tracker about new 1-RTT ACKs if the ACK increases the largest acked.
	if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil && largestAcked > pnSpace.largestAcked {
		congested := h.ecnTracker.HandleNewlyAcked(ackedPackets, int64(ack.ECT0), int64(ack.ECT1), int64(ack.ECNCE))
		if congested {
			h.congestion.OnCongestionEvent(largestAcked, 0, priorInFlight)
		}
	}

	pnSpace.largestAcked = max(pnSpace.largestAcked, largestAcked)

	h.detectLostPackets(rcvTime, encLevel)
	if encLevel == protocol.Encryption1RTT {
		h.detectLostPathProbes(rcvTime)
	}
	var acked1RTTPacket bool
	for _, p := range ackedPackets {
		if p.includedInBytesInFlight {
			h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime)
		}
		if p.EncryptionLevel == protocol.Encryption1RTT {
			acked1RTTPacket = true
		}
		h.removeFromBytesInFlight(p.packet)
		// Return the packet to the pool; path probe packets are not returned here.
		if !p.isPathProbePacket {
			putPacket(p.packet)
		}
	}

	// detect spurious losses for application data packets, if the ACK was not reordered
	if encLevel == protocol.Encryption1RTT && largestAcked == pnSpace.largestAcked {
		h.detectSpuriousLosses(
			ack,
			rcvTime.Add(-min(ack.DelayTime, h.rttStats.MaxAckDelay())),
		)
		// clean up lost packet history
		h.lostPackets.DeleteBefore(rcvTime.Add(-3 * h.rttStats.PTO(false)))
	}

	// After this point, we must not use ackedPackets any longer!
	// We've already returned the buffers.
	ackedPackets = nil //nolint:ineffassign // This is just to be on the safe side.
	clear(h.ackedPackets) // make sure the memory is released
	h.ackedPackets = h.ackedPackets[:0]

	// Reset the pto_count unless the client is unsure if the server has validated the client's address.
	if h.peerCompletedAddressValidation {
		if h.qlogger != nil && h.ptoCount != 0 {
			h.qlogger.RecordEvent(qlog.PTOCountUpdated{PTOCount: 0})
		}
		h.ptoCount = 0
	}
	h.numProbesToSend = 0

	if h.qlogger != nil {
		h.qlogMetricsUpdated()
	}

	h.setLossDetectionTimer(rcvTime)
	return acked1RTTPacket, nil
}

// detectSpuriousLosses finds packets that were previously declared lost
// (h.lostPackets, iterated in ascending packet number order) but are
// acknowledged by ack. For each one, it records a qlog SpuriousLoss event with
// the packet and time reordering, and removes it from h.lostPackets.
// ackTime is the receive time of the ACK minus the (capped) ACK delay.
func (h *sentPacketHandler) detectSpuriousLosses(ack *wire.AckFrame, ackTime monotime.Time) {
	// NOTE(review): the two maxima below are computed but not consumed by this function.
	var maxPacketReordering protocol.PacketNumber
	var maxTimeReordering time.Duration
	// Start at the last ACK range, i.e. the one covering the lowest packet numbers.
	ackRangeIdx := len(ack.AckRanges) - 1
	var spuriousLosses []protocol.PacketNumber
	for pn, sendTime := range h.lostPackets.All() {
		ackRange := ack.AckRanges[ackRangeIdx]
		for pn > ackRange.Largest {
			// this should never happen, since detectSpuriousLosses is only called for ACKs that increase the largest acked
			if ackRangeIdx == 0 {
				break
			}
			ackRangeIdx--
			ackRange = ack.AckRanges[ackRangeIdx]
		}
		if pn < ackRange.Smallest {
			continue
		}
		if pn <= ackRange.Largest {
			packetReordering := h.appDataPackets.history.Difference(ack.LargestAcked(), pn)
			timeReordering := ackTime.Sub(sendTime)
			maxPacketReordering = max(maxPacketReordering, packetReordering)
			maxTimeReordering = max(maxTimeReordering, timeReordering)

			if h.qlogger != nil {
				h.qlogger.RecordEvent(qlog.SpuriousLoss{
					EncryptionLevel:  protocol.Encryption1RTT,
					PacketNumber:     pn,
					PacketReordering: uint64(packetReordering),
					TimeReordering:   timeReordering,
				})
			}
			spuriousLosses = append(spuriousLosses, pn)
		}
	}
	// Delete after iterating, so that the tracker isn't modified during iteration.
	for _, pn := range spuriousLosses {
		h.lostPackets.Delete(pn)
	}
}

// Packets are returned in ascending packet number order.
func (h *sentPacketHandler) detectAndRemoveAckedPackets( ack *wire.AckFrame, encLevel protocol.EncryptionLevel, ) (_ []packetWithPacketNumber, hasAckEliciting bool, _ error) { if len(h.ackedPackets) > 0 { return nil, false, errors.New("ackhandler BUG: ackedPackets slice not empty") } pnSpace := h.getPacketNumberSpace(encLevel) if encLevel == protocol.Encryption1RTT { for p := range pnSpace.history.SkippedPackets() { if ack.AcksPacket(p) { return nil, false, &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: fmt.Sprintf("received an ACK for skipped packet number: %d (%s)", p, encLevel), } } } } var ackRangeIndex int lowestAcked := ack.LowestAcked() largestAcked := ack.LargestAcked() for pn, p := range pnSpace.history.Packets() { // ignore packets below the lowest acked if pn < lowestAcked { continue } if pn > largestAcked { break } if ack.HasMissingRanges() { ackRange := ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex] for pn > ackRange.Largest && ackRangeIndex < len(ack.AckRanges)-1 { ackRangeIndex++ ackRange = ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex] } if pn < ackRange.Smallest { // packet not contained in ACK range continue } if pn > ackRange.Largest { return nil, false, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", pn, ackRange.Smallest, ackRange.Largest) } } if p.isPathProbePacket { probePacket := pnSpace.history.RemovePathProbe(pn) // the probe packet might already have been declared lost if probePacket != nil { h.ackedPackets = append(h.ackedPackets, packetWithPacketNumber{PacketNumber: pn, packet: probePacket}) } continue } if p.IsAckEliciting() { hasAckEliciting = true } h.ackedPackets = append(h.ackedPackets, packetWithPacketNumber{PacketNumber: pn, packet: p}) } if h.logger.Debug() && len(h.ackedPackets) > 0 { pns := make([]protocol.PacketNumber, len(h.ackedPackets)) for i, p := range h.ackedPackets { pns[i] = p.PacketNumber } h.logger.Debugf("\tnewly acked packets (%d): %d", 
len(pns), pns) } for _, p := range h.ackedPackets { if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT && h.ignorePacketsBelow != nil { h.ignorePacketsBelow(p.LargestAcked + 1) } for _, f := range p.Frames { if f.Handler != nil { f.Handler.OnAcked(f.Frame) } } for _, f := range p.StreamFrames { if f.Handler != nil { f.Handler.OnAcked(f.Frame) } } if err := pnSpace.history.Remove(p.PacketNumber); err != nil { return nil, false, err } } // TODO: add support for the transport:packets_acked qlog event return h.ackedPackets, hasAckEliciting, nil } func (h *sentPacketHandler) getLossTimeAndSpace() (monotime.Time, protocol.EncryptionLevel) { var encLevel protocol.EncryptionLevel var lossTime monotime.Time if h.initialPackets != nil { lossTime = h.initialPackets.lossTime encLevel = protocol.EncryptionInitial } if h.handshakePackets != nil && (lossTime.IsZero() || (!h.handshakePackets.lossTime.IsZero() && h.handshakePackets.lossTime.Before(lossTime))) { lossTime = h.handshakePackets.lossTime encLevel = protocol.EncryptionHandshake } if lossTime.IsZero() || (!h.appDataPackets.lossTime.IsZero() && h.appDataPackets.lossTime.Before(lossTime)) { lossTime = h.appDataPackets.lossTime encLevel = protocol.Encryption1RTT } return lossTime, encLevel } func (h *sentPacketHandler) getScaledPTO(includeMaxAckDelay bool) time.Duration { pto := h.rttStats.PTO(includeMaxAckDelay) << h.ptoCount if pto > maxPTODuration || pto <= 0 { return maxPTODuration } return pto } // same logic as getLossTimeAndSpace, but for lastAckElicitingPacketTime instead of lossTime func (h *sentPacketHandler) getPTOTimeAndSpace(now monotime.Time) (pto monotime.Time, encLevel protocol.EncryptionLevel) { // We only send application data probe packets once the handshake is confirmed, // because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets. 
if !h.handshakeConfirmed && !h.hasOutstandingCryptoPackets() { if h.peerCompletedAddressValidation { return } t := now.Add(h.getScaledPTO(false)) if h.initialPackets != nil { return t, protocol.EncryptionInitial } return t, protocol.EncryptionHandshake } if h.initialPackets != nil && h.initialPackets.history.HasOutstandingPackets() && !h.initialPackets.lastAckElicitingPacketTime.IsZero() { encLevel = protocol.EncryptionInitial if t := h.initialPackets.lastAckElicitingPacketTime; !t.IsZero() { pto = t.Add(h.getScaledPTO(false)) } } if h.handshakePackets != nil && h.handshakePackets.history.HasOutstandingPackets() && !h.handshakePackets.lastAckElicitingPacketTime.IsZero() { t := h.handshakePackets.lastAckElicitingPacketTime.Add(h.getScaledPTO(false)) if pto.IsZero() || (!t.IsZero() && t.Before(pto)) { pto = t encLevel = protocol.EncryptionHandshake } } if h.handshakeConfirmed && h.appDataPackets.history.HasOutstandingPackets() && !h.appDataPackets.lastAckElicitingPacketTime.IsZero() { t := h.appDataPackets.lastAckElicitingPacketTime.Add(h.getScaledPTO(true)) if pto.IsZero() || (!t.IsZero() && t.Before(pto)) { pto = t encLevel = protocol.Encryption1RTT } } return pto, encLevel } func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool { if h.initialPackets != nil && h.initialPackets.history.HasOutstandingPackets() { return true } if h.handshakePackets != nil && h.handshakePackets.history.HasOutstandingPackets() { return true } return false } func (h *sentPacketHandler) setLossDetectionTimer(now monotime.Time) { oldAlarm := h.alarm // only needed in case tracing is enabled newAlarm := h.lossDetectionTime(now) h.alarm = newAlarm hasAlarm := !newAlarm.Time.IsZero() if !hasAlarm && !oldAlarm.Time.IsZero() { h.logger.Debugf("Canceling loss detection timer.") if h.qlogger != nil { h.qlogger.RecordEvent(qlog.LossTimerUpdated{ Type: qlog.LossTimerUpdateTypeCancelled, }) } } if h.qlogger != nil && hasAlarm && newAlarm != oldAlarm { 
h.qlogger.RecordEvent(qlog.LossTimerUpdated{ Type: qlog.LossTimerUpdateTypeSet, TimerType: newAlarm.TimerType, EncLevel: newAlarm.EncryptionLevel, Time: newAlarm.Time.ToTime(), }) } } func (h *sentPacketHandler) lossDetectionTime(now monotime.Time) alarmTimer { // cancel the alarm if no packets are outstanding if h.peerCompletedAddressValidation && !h.hasOutstandingCryptoPackets() && !h.appDataPackets.history.HasOutstandingPackets() && !h.appDataPackets.history.HasOutstandingPathProbes() { return alarmTimer{} } // cancel the alarm if amplification limited if h.isAmplificationLimited() { return alarmTimer{} } var pathProbeLossTime monotime.Time if h.appDataPackets.history.HasOutstandingPathProbes() { if _, p := h.appDataPackets.history.FirstOutstandingPathProbe(); p != nil { pathProbeLossTime = p.SendTime.Add(pathProbePacketLossTimeout) } } // early retransmit timer or time loss detection lossTime, encLevel := h.getLossTimeAndSpace() if !lossTime.IsZero() && (pathProbeLossTime.IsZero() || lossTime.Before(pathProbeLossTime)) { return alarmTimer{ Time: lossTime, TimerType: qlog.TimerTypeACK, EncryptionLevel: encLevel, } } ptoTime, encLevel := h.getPTOTimeAndSpace(now) if !ptoTime.IsZero() && (pathProbeLossTime.IsZero() || ptoTime.Before(pathProbeLossTime)) { return alarmTimer{ Time: ptoTime, TimerType: qlog.TimerTypePTO, EncryptionLevel: encLevel, } } if !pathProbeLossTime.IsZero() { return alarmTimer{ Time: pathProbeLossTime, TimerType: qlog.TimerTypePathProbe, EncryptionLevel: protocol.Encryption1RTT, } } return alarmTimer{} } func (h *sentPacketHandler) detectLostPathProbes(now monotime.Time) { if !h.appDataPackets.history.HasOutstandingPathProbes() { return } lossTime := now.Add(-pathProbePacketLossTimeout) // RemovePathProbe cannot be called while iterating. 
var lostPathProbes []packetWithPacketNumber for pn, p := range h.appDataPackets.history.PathProbes() { if !p.SendTime.After(lossTime) { lostPathProbes = append(lostPathProbes, packetWithPacketNumber{PacketNumber: pn, packet: p}) } } for _, p := range lostPathProbes { for _, f := range p.Frames { f.Handler.OnLost(f.Frame) } h.appDataPackets.history.RemovePathProbe(p.PacketNumber) } } func (h *sentPacketHandler) detectLostPackets(now monotime.Time, encLevel protocol.EncryptionLevel) { pnSpace := h.getPacketNumberSpace(encLevel) pnSpace.lossTime = 0 maxRTT := float64(max(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT())) lossDelay := time.Duration(timeThreshold * maxRTT) // Minimum time of granularity before packets are deemed lost. lossDelay = max(lossDelay, protocol.TimerGranularity) // Packets sent before this time are deemed lost. lostSendTime := now.Add(-lossDelay) priorInFlight := h.bytesInFlight for pn, p := range pnSpace.history.Packets() { if pn > pnSpace.largestAcked { break } var packetLost bool if !p.SendTime.After(lostSendTime) { packetLost = true if !p.isPathProbePacket && p.IsAckEliciting() { if h.logger.Debug() { h.logger.Debugf("\tlost packet %d (time threshold)", pn) } if h.qlogger != nil { h.qlogger.RecordEvent(qlog.PacketLost{ Header: qlog.PacketHeader{ PacketType: qlog.EncryptionLevelToPacketType(p.EncryptionLevel), PacketNumber: pn, }, Trigger: qlog.PacketLossTimeThreshold, }) } } } else if pnSpace.history.Difference(pnSpace.largestAcked, pn) >= packetThreshold { packetLost = true if !p.isPathProbePacket && p.IsAckEliciting() { if h.logger.Debug() { h.logger.Debugf("\tlost packet %d (reordering threshold)", pn) } if h.qlogger != nil { h.qlogger.RecordEvent(qlog.PacketLost{ Header: qlog.PacketHeader{ PacketType: qlog.EncryptionLevelToPacketType(p.EncryptionLevel), PacketNumber: pn, }, Trigger: qlog.PacketLossReorderingThreshold, }) } } } else if pnSpace.lossTime.IsZero() { // Note: This conditional is only entered once per call lossTime := 
p.SendTime.Add(lossDelay) if h.logger.Debug() { h.logger.Debugf("\tsetting loss timer for packet %d (%s) to %s (in %s)", pn, encLevel, lossDelay, lossTime) } pnSpace.lossTime = lossTime } if packetLost { if encLevel == protocol.Encryption0RTT || encLevel == protocol.Encryption1RTT { h.lostPackets.Add(pn, p.SendTime) } pnSpace.history.DeclareLost(pn) if !p.isPathProbePacket && p.IsAckEliciting() { // the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted h.removeFromBytesInFlight(p) h.queueFramesForRetransmission(p) if !p.IsPathMTUProbePacket { h.congestion.OnCongestionEvent(pn, p.Length, priorInFlight) } if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil { h.ecnTracker.LostPacket(pn) } } } } } func (h *sentPacketHandler) OnLossDetectionTimeout(now monotime.Time) error { defer h.setLossDetectionTimer(now) if h.handshakeConfirmed { h.detectLostPathProbes(now) } earliestLossTime, encLevel := h.getLossTimeAndSpace() if !earliestLossTime.IsZero() { if h.logger.Debug() { h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime) } if h.qlogger != nil { h.qlogger.RecordEvent(qlog.LossTimerUpdated{ Type: qlog.LossTimerUpdateTypeExpired, TimerType: qlog.TimerTypeACK, EncLevel: encLevel, }) } // Early retransmit or time loss detection h.detectLostPackets(now, encLevel) return nil } // PTO // When all outstanding are acknowledged, the alarm is canceled in setLossDetectionTimer. // However, there's no way to reset the timer in the connection. // When OnLossDetectionTimeout is called, we therefore need to make sure that there are // actually packets outstanding. 
if h.bytesInFlight == 0 && !h.peerCompletedAddressValidation { h.ptoCount++ h.numProbesToSend++ if h.initialPackets != nil { h.ptoMode = SendPTOInitial } else if h.handshakePackets != nil { h.ptoMode = SendPTOHandshake } else { return errors.New("sentPacketHandler BUG: PTO fired, but bytes_in_flight is 0 and Initial and Handshake already dropped") } return nil } ptoTime, encLevel := h.getPTOTimeAndSpace(now) if ptoTime.IsZero() { return nil } ps := h.getPacketNumberSpace(encLevel) if !ps.history.HasOutstandingPackets() && !ps.history.HasOutstandingPathProbes() && !h.peerCompletedAddressValidation { return nil } h.ptoCount++ if h.logger.Debug() { h.logger.Debugf("Loss detection alarm for %s fired in PTO mode. PTO count: %d", encLevel, h.ptoCount) } if h.qlogger != nil { h.qlogger.RecordEvent(qlog.LossTimerUpdated{ Type: qlog.LossTimerUpdateTypeExpired, TimerType: qlog.TimerTypePTO, EncLevel: encLevel, }) h.qlogger.RecordEvent(qlog.PTOCountUpdated{PTOCount: h.ptoCount}) } h.numProbesToSend += 2 //nolint:exhaustive // We never arm a PTO timer for 0-RTT packets. 
switch encLevel { case protocol.EncryptionInitial: h.ptoMode = SendPTOInitial case protocol.EncryptionHandshake: h.ptoMode = SendPTOHandshake case protocol.Encryption1RTT: // skip a packet number in order to elicit an immediate ACK pn := h.PopPacketNumber(protocol.Encryption1RTT) h.getPacketNumberSpace(protocol.Encryption1RTT).history.SkippedPacket(pn) h.ptoMode = SendPTOAppData default: return fmt.Errorf("PTO timer in unexpected encryption level: %s", encLevel) } return nil } func (h *sentPacketHandler) GetLossDetectionTimeout() monotime.Time { return h.alarm.Time } func (h *sentPacketHandler) ECNMode(isShortHeaderPacket bool) protocol.ECN { if !h.enableECN { return protocol.ECNUnsupported } if !isShortHeaderPacket { return protocol.ECNNon } return h.ecnTracker.Mode() } func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { pnSpace := h.getPacketNumberSpace(encLevel) pn := pnSpace.pns.Peek() // See section 17.1 of RFC 9000. return pn, protocol.PacketNumberLengthForHeader(pn, pnSpace.largestAcked) } func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) protocol.PacketNumber { pnSpace := h.getPacketNumberSpace(encLevel) skipped, pn := pnSpace.pns.Pop() if skipped { skippedPN := pn - 1 pnSpace.history.SkippedPacket(skippedPN) if h.logger.Debug() { h.logger.Debugf("Skipping packet number %d", skippedPN) } } return pn } func (h *sentPacketHandler) SendMode(now monotime.Time) SendMode { numTrackedPackets := h.appDataPackets.history.Len() if h.initialPackets != nil { numTrackedPackets += h.initialPackets.history.Len() } if h.handshakePackets != nil { numTrackedPackets += h.handshakePackets.history.Len() } if h.isAmplificationLimited() { h.logger.Debugf("Amplification window limited. Received %d bytes, already sent out %d bytes", h.bytesReceived, h.bytesSent) return SendNone } // Don't send any packets if we're keeping track of the maximum number of packets. 
// Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets, // we will stop sending out new data when reaching MaxOutstandingSentPackets, // but still allow sending of retransmissions and ACKs. if numTrackedPackets >= protocol.MaxTrackedSentPackets { if h.logger.Debug() { h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets) } return SendNone } if h.numProbesToSend > 0 { return h.ptoMode } // Only send ACKs if we're congestion limited. if !h.congestion.CanSend(h.bytesInFlight) { if h.logger.Debug() { h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, h.congestion.GetCongestionWindow()) } return SendAck } if numTrackedPackets >= protocol.MaxOutstandingSentPackets { if h.logger.Debug() { h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets) } return SendAck } if !h.congestion.HasPacingBudget(now) { return SendPacingLimited } return SendAny } func (h *sentPacketHandler) TimeUntilSend() monotime.Time { return h.congestion.TimeUntilSend(h.bytesInFlight) } func (h *sentPacketHandler) SetMaxDatagramSize(s protocol.ByteCount) { h.congestion.SetMaxDatagramSize(s) } func (h *sentPacketHandler) isAmplificationLimited() bool { if h.peerAddressValidated { return false } return h.bytesSent >= amplificationFactor*h.bytesReceived } func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) bool { pnSpace := h.getPacketNumberSpace(encLevel) pn, p := pnSpace.history.FirstOutstanding() if p == nil { return false } // TODO: don't declare the packet lost here. // Keep track of acknowledged frames instead. // Call DeclareLost before queueFramesForRetransmission, which clears the packet's frames. 
pnSpace.history.DeclareLost(pn) h.removeFromBytesInFlight(p) h.queueFramesForRetransmission(p) return true } func (h *sentPacketHandler) queueFramesForRetransmission(p *packet) { if len(p.Frames) == 0 && len(p.StreamFrames) == 0 { panic("no frames") } for _, f := range p.Frames { if f.Handler != nil { f.Handler.OnLost(f.Frame) } } for _, f := range p.StreamFrames { if f.Handler != nil { f.Handler.OnLost(f.Frame) } } p.StreamFrames = nil p.Frames = nil } func (h *sentPacketHandler) ResetForRetry(now monotime.Time) { h.bytesInFlight = 0 var firstPacketSendTime monotime.Time for _, p := range h.initialPackets.history.Packets() { if firstPacketSendTime.IsZero() { firstPacketSendTime = p.SendTime } if p.IsAckEliciting() { h.queueFramesForRetransmission(p) } } // All application data packets sent at this point are 0-RTT packets. // In the case of a Retry, we can assume that the server dropped all of them. for _, p := range h.appDataPackets.history.Packets() { if p.IsAckEliciting() { h.queueFramesForRetransmission(p) } } // Only use the Retry to estimate the RTT if we didn't send any retransmission for the Initial. // Otherwise, we don't know which Initial the Retry was sent in response to. if h.ptoCount == 0 { // Don't set the RTT to a value lower than 5ms here. 
h.rttStats.UpdateRTT(max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0) if h.logger.Debug() { h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) } if h.qlogger != nil { h.qlogMetricsUpdated() } } h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Peek(), false) h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Peek(), true) oldAlarm := h.alarm h.alarm = alarmTimer{} if h.qlogger != nil { h.qlogger.RecordEvent(qlog.PTOCountUpdated{PTOCount: 0}) if !oldAlarm.Time.IsZero() { h.qlogger.RecordEvent(qlog.LossTimerUpdated{ Type: qlog.LossTimerUpdateTypeCancelled, }) } } h.ptoCount = 0 } func (h *sentPacketHandler) MigratedPath(now monotime.Time, initialMaxDatagramSize protocol.ByteCount) { h.rttStats.ResetForPathMigration() for pn, p := range h.appDataPackets.history.Packets() { h.appDataPackets.history.DeclareLost(pn) if !p.isPathProbePacket { h.removeFromBytesInFlight(p) if p.IsAckEliciting() { h.queueFramesForRetransmission(p) } } } for pn := range h.appDataPackets.history.PathProbes() { h.appDataPackets.history.RemovePathProbe(pn) } h.congestion = congestion.NewCubicSender( congestion.DefaultClock{}, h.rttStats, h.connStats, initialMaxDatagramSize, true, // use Reno h.qlogger, ) h.setLossDetectionTimer(now) } quic-go-0.59.0/internal/ackhandler/sent_packet_handler_test.go000066400000000000000000001723021513066070600244550ustar00rootroot00000000000000package ackhandler import ( "encoding/binary" "fmt" "math/rand/v2" "slices" "testing" "time" "github.com/quic-go/quic-go/internal/mocks" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/quic-go/testutils/events" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) type 
customFrameHandler struct { onLost, onAcked func(wire.Frame) } func (h *customFrameHandler) OnLost(f wire.Frame) { if h.onLost != nil { h.onLost(f) } } func (h *customFrameHandler) OnAcked(f wire.Frame) { if h.onAcked != nil { h.onAcked(f) } } type packetTracker struct { Acked []protocol.PacketNumber Lost []protocol.PacketNumber } func (t *packetTracker) Reset() { t.Acked = nil t.Lost = nil } func (t *packetTracker) NewPingFrame(pn protocol.PacketNumber) Frame { return Frame{ Frame: &wire.PingFrame{}, Handler: &customFrameHandler{ onAcked: func(wire.Frame) { t.Acked = append(t.Acked, pn) }, onLost: func(wire.Frame) { t.Lost = append(t.Lost, pn) }, }, } } func (h *sentPacketHandler) getBytesInFlight() protocol.ByteCount { return h.bytesInFlight } func ackRanges(pns ...protocol.PacketNumber) []wire.AckRange { return appendAckRanges(nil, pns...) } func appendAckRanges(ranges []wire.AckRange, pns ...protocol.PacketNumber) []wire.AckRange { if len(pns) == 0 { return ranges } slices.Sort(pns) slices.Reverse(pns) start := pns[0] for i := 1; i < len(pns); i++ { if pns[i-1]-pns[i] > 1 { ranges = append(ranges, wire.AckRange{Smallest: pns[i-1], Largest: start}) start = pns[i] } } return append(ranges, wire.AckRange{Smallest: pns[len(pns)-1], Largest: start}) } func TestAckRanges(t *testing.T) { require.Equal(t, []wire.AckRange{{Smallest: 1, Largest: 1}}, ackRanges(1)) require.Equal(t, []wire.AckRange{{Smallest: 1, Largest: 2}}, ackRanges(1, 2)) require.Equal(t, []wire.AckRange{{Smallest: 1, Largest: 3}}, ackRanges(1, 2, 3)) require.Equal(t, []wire.AckRange{{Smallest: 1, Largest: 3}}, ackRanges(3, 2, 1)) require.Equal(t, []wire.AckRange{{Smallest: 1, Largest: 3}}, ackRanges(1, 3, 2)) require.Equal(t, []wire.AckRange{{Smallest: 3, Largest: 3}, {Smallest: 1, Largest: 1}}, ackRanges(1, 3)) require.Equal(t, []wire.AckRange{{Smallest: 3, Largest: 4}, {Smallest: 1, Largest: 1}}, ackRanges(1, 3, 4)) require.Equal(t, []wire.AckRange{{Smallest: 5, Largest: 6}, {Smallest: 0, Largest: 
2}}, ackRanges(0, 1, 2, 5, 6)) } func TestSentPacketHandlerSendAndAcknowledge(t *testing.T) { t.Run("Initial", func(t *testing.T) { testSentPacketHandlerSendAndAcknowledge(t, protocol.EncryptionInitial) }) t.Run("Handshake", func(t *testing.T) { testSentPacketHandlerSendAndAcknowledge(t, protocol.EncryptionHandshake) }) t.Run("1-RTT", func(t *testing.T) { testSentPacketHandlerSendAndAcknowledge(t, protocol.Encryption1RTT) }) } func testSentPacketHandlerSendAndAcknowledge(t *testing.T, encLevel protocol.EncryptionLevel) { sph := NewSentPacketHandler( 0, 1200, utils.NewRTTStats(), &utils.ConnectionStats{}, false, false, nil, protocol.PerspectiveClient, nil, utils.DefaultLogger, ) var packets packetTracker var pns []protocol.PacketNumber now := monotime.Now() for i := range 10 { e := encLevel // also send some 0-RTT packets to make sure they're acknowledged in the same packet number space if encLevel == protocol.Encryption1RTT && i < 5 { e = protocol.Encryption0RTT } pn := sph.PopPacketNumber(e) sph.SentPacket(now, pn, protocol.InvalidPacketNumber, nil, []Frame{packets.NewPingFrame(pn)}, e, protocol.ECNNon, 1200, false, false) pns = append(pns, pn) } _, err := sph.ReceivedAck( &wire.AckFrame{AckRanges: ackRanges(pns[0], pns[1], pns[2], pns[3], pns[4], pns[7], pns[8], pns[9])}, encLevel, monotime.Now(), ) require.NoError(t, err) require.Equal(t, []protocol.PacketNumber{pns[0], pns[1], pns[2], pns[3], pns[4], pns[7], pns[8], pns[9]}, packets.Acked) // ACKs that don't acknowledge new packets are ok _, err = sph.ReceivedAck( &wire.AckFrame{AckRanges: ackRanges(pns[1], pns[2], pns[3])}, encLevel, monotime.Now(), ) require.NoError(t, err) require.Equal(t, []protocol.PacketNumber{pns[0], pns[1], pns[2], pns[3], pns[4], pns[7], pns[8], pns[9]}, packets.Acked) // ACKs that don't acknowledge packets that we didn't send are not ok _, err = sph.ReceivedAck( &wire.AckFrame{AckRanges: ackRanges(pns[7], pns[8], pns[9], pns[9]+1)}, encLevel, monotime.Now(), ) require.ErrorIs(t, err, 
&qerr.TransportError{ErrorCode: qerr.ProtocolViolation}) require.ErrorContains(t, err, "received ACK for an unsent packet") } func TestSentPacketHandlerAcknowledgeSkippedPacket(t *testing.T) { sph := NewSentPacketHandler( 0, 1200, utils.NewRTTStats(), &utils.ConnectionStats{}, false, false, nil, protocol.PerspectiveClient, nil, utils.DefaultLogger, ) now := monotime.Now() lastPN := protocol.InvalidPacketNumber skippedPN := protocol.InvalidPacketNumber for { pn, _ := sph.PeekPacketNumber(protocol.Encryption1RTT) require.Equal(t, pn, sph.PopPacketNumber(protocol.Encryption1RTT)) if pn > lastPN+1 { skippedPN = pn - 1 } if pn >= 1e6 { t.Fatal("expected a skipped packet number") } sph.SentPacket(now, pn, protocol.InvalidPacketNumber, nil, []Frame{{Frame: &wire.PingFrame{}}}, protocol.Encryption1RTT, protocol.ECNNon, 1200, false, false) lastPN = pn if skippedPN != protocol.InvalidPacketNumber { break } } _, err := sph.ReceivedAck(&wire.AckFrame{ AckRanges: []wire.AckRange{{Smallest: 0, Largest: lastPN}}, }, protocol.Encryption1RTT, monotime.Now()) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation}) require.ErrorContains(t, err, fmt.Sprintf("received an ACK for skipped packet number: %d (1-RTT)", skippedPN)) } func TestSentPacketHandlerRTTAckEliciting(t *testing.T) { var eventRecorder events.Recorder rttStats := utils.NewRTTStats() sph := NewSentPacketHandler( 0, 1200, rttStats, &utils.ConnectionStats{}, false, false, nil, protocol.PerspectiveClient, &eventRecorder, utils.DefaultLogger, ) getPacketsInFlight := func() int { evs := eventRecorder.Events(qlog.MetricsUpdated{}) return evs[len(evs)-1].(qlog.MetricsUpdated).PacketsInFlight } getBytesInFlight := func() int { evs := eventRecorder.Events(qlog.MetricsUpdated{}) return evs[len(evs)-1].(qlog.MetricsUpdated).BytesInFlight } sendPacket := func(t *testing.T, ti monotime.Time, size protocol.ByteCount, ackEliciting bool) protocol.PacketNumber { t.Helper() pn := 
sph.PopPacketNumber(protocol.Encryption1RTT) var frames []Frame if ackEliciting { frames = []Frame{{Frame: &wire.PingFrame{}}} } sph.SentPacket(ti, pn, protocol.InvalidPacketNumber, nil, frames, protocol.Encryption1RTT, protocol.ECNNon, size, false, false) return pn } ackPackets := func(t *testing.T, ti monotime.Time, pns ...protocol.PacketNumber) { t.Helper() _, err := sph.ReceivedAck(&wire.AckFrame{AckRanges: ackRanges(pns...)}, protocol.Encryption1RTT, ti) require.NoError(t, err) } now := monotime.Now() pn1 := sendPacket(t, now, 1200, true) require.Equal(t, 1, getPacketsInFlight()) require.Equal(t, 1200, getBytesInFlight()) pn2 := sendPacket(t, now, 1100, false) // Sending a non-ack-eliciting packet doesn't change bytes or packets in flight. // Non-ack-eliciting packets are not included in congestion control. require.Equal(t, 1, getPacketsInFlight()) require.Equal(t, 1200, getBytesInFlight()) pn3 := sendPacket(t, now, 1000, true) require.Equal(t, 2, getPacketsInFlight()) require.Equal(t, 2200, getBytesInFlight()) // the RTT is recorded, since the largest acknowledged packet is ack-eliciting now = now.Add(200 * time.Millisecond) ackPackets(t, now, pn1, pn2, pn3) require.Equal(t, 200*time.Millisecond, rttStats.LatestRTT()) require.Zero(t, getPacketsInFlight()) require.Zero(t, getBytesInFlight()) pn4 := sendPacket(t, now, 1200, false) // non-ack-eliciting packets don't trigger metrics updates require.Zero(t, getPacketsInFlight()) require.Zero(t, getBytesInFlight()) pn5 := sendPacket(t, now, 500, false) require.Zero(t, getPacketsInFlight()) require.Zero(t, getBytesInFlight()) now = now.Add(500 * time.Millisecond) // only non-ack-eliciting packets are newly acknowledged, so the RTT is not updated ackPackets(t, now, pn2, pn3, pn4, pn5) require.Equal(t, 200*time.Millisecond, rttStats.LatestRTT()) pn6 := sendPacket(t, now, 1400, true) require.Equal(t, 1, getPacketsInFlight()) require.Equal(t, 1400, getBytesInFlight()) pn7 := sendPacket(t, now, 1100, false) // 
non-ack-eliciting packet doesn't change metrics require.Equal(t, 1, getPacketsInFlight()) require.Equal(t, 1400, getBytesInFlight()) now = now.Add(800 * time.Millisecond) // largest acknowledged packet is not ack-eliciting, but one new ack-eliciting // packet was acknowledged, so the RTT is updated ackPackets(t, now, pn6, pn7) require.Equal(t, 800*time.Millisecond, rttStats.LatestRTT()) require.Zero(t, getPacketsInFlight()) require.Zero(t, getBytesInFlight()) } func TestSentPacketHandlerRTTAcrossPacketNumberSpaces(t *testing.T) { rttStats := utils.NewRTTStats() sph := NewSentPacketHandler( 0, 1200, rttStats, &utils.ConnectionStats{}, false, false, nil, protocol.PerspectiveClient, nil, utils.DefaultLogger, ) sendPacket := func(t *testing.T, ti monotime.Time, encLevel protocol.EncryptionLevel) protocol.PacketNumber { t.Helper() pn := sph.PopPacketNumber(encLevel) sph.SentPacket(ti, pn, protocol.InvalidPacketNumber, nil, []Frame{{Frame: &wire.PingFrame{}}}, encLevel, protocol.ECNNon, 1200, false, false) return pn } ackPackets := func(t *testing.T, ti monotime.Time, encLevel protocol.EncryptionLevel, pns ...protocol.PacketNumber) { t.Helper() _, err := sph.ReceivedAck(&wire.AckFrame{AckRanges: ackRanges(pns...)}, encLevel, ti) require.NoError(t, err) } now := monotime.Now() initial1 := sendPacket(t, now, protocol.EncryptionInitial) handshake1 := sendPacket(t, now.Add(time.Second), protocol.EncryptionHandshake) initial2 := sendPacket(t, now.Add(2*time.Second), protocol.EncryptionInitial) handshake2 := sendPacket(t, now.Add(2*time.Second), protocol.EncryptionHandshake) ackPackets(t, now.Add(3*time.Second), protocol.EncryptionInitial, initial1, initial2) require.Equal(t, time.Second, rttStats.LatestRTT()) // No RTT measurement, since the second initial packet was sent after the first handshake packet. 
ackPackets(t, now.Add(4*time.Second), protocol.EncryptionHandshake, handshake1) require.Equal(t, time.Second, rttStats.LatestRTT()) // This causes an RTT measurement, since the second handshake packet was sent last. ackPackets(t, now.Add(5*time.Second), protocol.EncryptionHandshake, handshake1, handshake2) require.Equal(t, 3*time.Second, rttStats.LatestRTT()) } func TestSentPacketHandlerRTTAckDelays(t *testing.T) { t.Run("Initial", func(t *testing.T) { testSentPacketHandlerRTTAckDelays(t, protocol.EncryptionInitial, false) }) t.Run("Handshake", func(t *testing.T) { testSentPacketHandlerRTTAckDelays(t, protocol.EncryptionHandshake, false) }) t.Run("1-RTT", func(t *testing.T) { testSentPacketHandlerRTTAckDelays(t, protocol.Encryption1RTT, true) }) } func testSentPacketHandlerRTTAckDelays(t *testing.T, encLevel protocol.EncryptionLevel, usesAckDelay bool) { expectedRTTStats := utils.NewRTTStats() expectedRTTStats.SetMaxAckDelay(time.Second) rttStats := utils.NewRTTStats() rttStats.SetMaxAckDelay(time.Second) sph := NewSentPacketHandler( 0, 1200, rttStats, &utils.ConnectionStats{}, false, false, nil, protocol.PerspectiveClient, nil, utils.DefaultLogger, ) sendPacket := func(t *testing.T, ti monotime.Time) protocol.PacketNumber { t.Helper() pn := sph.PopPacketNumber(encLevel) sph.SentPacket(ti, pn, protocol.InvalidPacketNumber, nil, []Frame{{Frame: &wire.PingFrame{}}}, encLevel, protocol.ECNNon, 1200, false, false) return pn } ackPacket := func(pn protocol.PacketNumber, ti monotime.Time, d time.Duration) { t.Helper() _, err := sph.ReceivedAck(&wire.AckFrame{DelayTime: d, AckRanges: ackRanges(pn)}, encLevel, ti) require.NoError(t, err) } var packets []protocol.PacketNumber now := monotime.Now() // send some packets and receive ACKs with 0 ack delay for range 5 { packets = append(packets, sendPacket(t, now)) } for i := range 5 { expectedRTTStats.UpdateRTT(time.Duration(i+1)*time.Second, 0) now = now.Add(time.Second) ackPacket(packets[i], now, 0) require.Equal(t, 
expectedRTTStats.SmoothedRTT(), rttStats.SmoothedRTT()) require.Equal(t, time.Second, rttStats.MinRTT()) require.Equal(t, time.Duration(i+1)*time.Second, rttStats.LatestRTT()) } packets = packets[:0] // send some more packets and receive ACKs with non-zero ack delay for range 5 { packets = append(packets, sendPacket(t, now)) } expectedRTTStatsNoAckDelay := expectedRTTStats.Clone() for i := range 5 { const ackDelay = 500 * time.Millisecond expectedRTTStats.UpdateRTT(time.Duration(i+1)*time.Second, ackDelay) expectedRTTStatsNoAckDelay.UpdateRTT(time.Duration(i+1)*time.Second, 0) now = now.Add(time.Second) ackPacket(packets[i], now, ackDelay) if usesAckDelay { require.Equal(t, expectedRTTStats.SmoothedRTT(), rttStats.SmoothedRTT()) } else { require.Equal(t, expectedRTTStatsNoAckDelay.SmoothedRTT(), rttStats.SmoothedRTT()) } } packets = packets[:0] // make sure that taking ack delay into account actually changes the RTT, // otherwise the test is not meaningful require.NotEqual(t, expectedRTTStats.SmoothedRTT(), expectedRTTStatsNoAckDelay.SmoothedRTT()) // Send two more packets, and acknowledge them in opposite order. // This tests that the RTT is updated even if the ACK doesn't increase the largest acked. packets = append(packets, sendPacket(t, now)) packets = append(packets, sendPacket(t, now)) ackPacket(packets[1], now.Add(time.Second), 0) rtt := rttStats.SmoothedRTT() ackPacket(packets[0], now.Add(10*time.Second), 0) require.NotEqual(t, rtt, rttStats.SmoothedRTT()) // Send one more packet, and send where the largest acked is acknowledged twice. 
pn := sendPacket(t, now) ackPacket(pn, now.Add(time.Second), 0) rtt = rttStats.SmoothedRTT() ackPacket(pn, now.Add(10*time.Second), 0) require.Equal(t, rtt, rttStats.SmoothedRTT()) } func TestSentPacketHandlerAmplificationLimitServer(t *testing.T) { t.Run("address validated", func(t *testing.T) { testSentPacketHandlerAmplificationLimitServer(t, true) }) t.Run("address not validated", func(t *testing.T) { testSentPacketHandlerAmplificationLimitServer(t, false) }) } func testSentPacketHandlerAmplificationLimitServer(t *testing.T, addressValidated bool) { sph := NewSentPacketHandler( 0, 1200, utils.NewRTTStats(), &utils.ConnectionStats{}, addressValidated, false, nil, protocol.PerspectiveServer, nil, utils.DefaultLogger, ) if addressValidated { require.Equal(t, SendAny, sph.SendMode(monotime.Now())) return } // no data received yet, so we can't send any packet yet require.Equal(t, SendNone, sph.SendMode(monotime.Now())) require.Zero(t, sph.GetLossDetectionTimeout()) // Receive 1000 bytes from the client. // As long as we haven't sent out 3x the amount of bytes received, we can send out new packets, // even if we go above the 3x limit by sending the last packet. 
sph.ReceivedBytes(1000, monotime.Now()) for i := 0; i < 4; i++ { require.Equal(t, SendAny, sph.SendMode(monotime.Now())) pn := sph.PopPacketNumber(protocol.EncryptionInitial) sph.SentPacket(monotime.Now(), pn, protocol.InvalidPacketNumber, nil, []Frame{{Frame: &wire.PingFrame{}}}, protocol.EncryptionInitial, protocol.ECNNon, 999, false, false) if i != 3 { require.NotZero(t, sph.GetLossDetectionTimeout()) } } require.Equal(t, SendNone, sph.SendMode(monotime.Now())) // no need to set a loss detection timer, as we're blocked by the amplification limit require.Zero(t, sph.GetLossDetectionTimeout()) // receiving more data allows us to send out more packets sph.ReceivedBytes(1000, monotime.Now()) require.NotZero(t, sph.GetLossDetectionTimeout()) for i := 0; i < 3; i++ { require.Equal(t, SendAny, sph.SendMode(monotime.Now())) pn := sph.PopPacketNumber(protocol.EncryptionInitial) sph.SentPacket(monotime.Now(), pn, protocol.InvalidPacketNumber, nil, []Frame{{Frame: &wire.PingFrame{}}}, protocol.EncryptionInitial, protocol.ECNNon, 1000, false, false) } require.Equal(t, SendNone, sph.SendMode(monotime.Now())) require.Zero(t, sph.GetLossDetectionTimeout()) // receiving an Initial packet doesn't validate the client's address sph.ReceivedPacket(protocol.EncryptionInitial, monotime.Now()) require.Equal(t, SendNone, sph.SendMode(monotime.Now())) require.Zero(t, sph.GetLossDetectionTimeout()) // receiving a Handshake packet validates the client's address sph.ReceivedPacket(protocol.EncryptionHandshake, monotime.Now()) require.Equal(t, SendAny, sph.SendMode(monotime.Now())) require.NotZero(t, sph.GetLossDetectionTimeout()) } func TestSentPacketHandlerAmplificationLimitClient(t *testing.T) { t.Run("handshake ACK", func(t *testing.T) { testSentPacketHandlerAmplificationLimitClient(t, false) }) t.Run("drop Handshake without ACK", func(t *testing.T) { testSentPacketHandlerAmplificationLimitClient(t, true) }) } func testSentPacketHandlerAmplificationLimitClient(t *testing.T, 
dropHandshake bool) { sph := NewSentPacketHandler( 0, 1200, utils.NewRTTStats(), &utils.ConnectionStats{}, true, false, nil, protocol.PerspectiveClient, nil, utils.DefaultLogger, ) require.Equal(t, SendAny, sph.SendMode(monotime.Now())) pn := sph.PopPacketNumber(protocol.EncryptionInitial) sph.SentPacket(monotime.Now(), pn, protocol.InvalidPacketNumber, nil, []Frame{{Frame: &wire.PingFrame{}}}, protocol.EncryptionInitial, protocol.ECNNon, 999, false, false) // it's not surprising that the loss detection timer is set, as this packet might be lost... require.NotZero(t, sph.GetLossDetectionTimeout()) // ... but it's still set after receiving an ACK for this packet, // since we might need to unblock the server's amplification limit _, err := sph.ReceivedAck(&wire.AckFrame{AckRanges: ackRanges(pn)}, protocol.EncryptionInitial, monotime.Now()) require.NoError(t, err) require.NotZero(t, sph.GetLossDetectionTimeout()) require.Equal(t, SendAny, sph.SendMode(monotime.Now())) // when the timer expires, we should send a PTO packet sph.OnLossDetectionTimeout(monotime.Now()) require.Equal(t, SendPTOInitial, sph.SendMode(monotime.Now())) require.NotZero(t, sph.GetLossDetectionTimeout()) if dropHandshake { // dropping the handshake packet number space completes the handshake, // even if no ACK for a handshake packet was received sph.DropPackets(protocol.EncryptionHandshake, monotime.Now()) require.Zero(t, sph.GetLossDetectionTimeout()) return } // once the Initial packet number space is dropped, we need to send a Handshake PTO packet, // even if we haven't sent any packet in the Handshake packet number space yet sph.DropPackets(protocol.EncryptionInitial, monotime.Now()) require.NotZero(t, sph.GetLossDetectionTimeout()) sph.OnLossDetectionTimeout(monotime.Now()) require.Equal(t, SendPTOHandshake, sph.SendMode(monotime.Now())) // receiving an ACK for a handshake packet shows that the server completed address validation pn = sph.PopPacketNumber(protocol.EncryptionHandshake) 
sph.SentPacket(monotime.Now(), pn, protocol.InvalidPacketNumber, nil, []Frame{{Frame: &wire.PingFrame{}}}, protocol.EncryptionHandshake, protocol.ECNNon, 999, false, false) require.NotZero(t, sph.GetLossDetectionTimeout()) _, err = sph.ReceivedAck(&wire.AckFrame{AckRanges: ackRanges(pn)}, protocol.EncryptionHandshake, monotime.Now()) require.NoError(t, err) require.Zero(t, sph.GetLossDetectionTimeout()) } func TestSentPacketHandlerDelayBasedLossDetection(t *testing.T) { rttStats := utils.NewRTTStats() sph := NewSentPacketHandler( 0, 1200, rttStats, &utils.ConnectionStats{}, true, false, nil, protocol.PerspectiveServer, nil, utils.DefaultLogger, ) var packets packetTracker sendPacket := func(t *testing.T, ti monotime.Time, isPathMTUProbePacket bool) protocol.PacketNumber { t.Helper() pn := sph.PopPacketNumber(protocol.EncryptionInitial) sph.SentPacket(ti, pn, protocol.InvalidPacketNumber, nil, []Frame{packets.NewPingFrame(pn)}, protocol.EncryptionInitial, protocol.ECNNon, 1000, isPathMTUProbePacket, false) return pn } const rtt = time.Second now := monotime.Now() t1 := now.Add(-rtt) t2 := now.Add(-10 * time.Millisecond) // Send 3 packets pn1 := sendPacket(t, t1, false) pn2 := sendPacket(t, t2, false) // Also send a Path MTU probe packet. // We expect the same loss recovery logic to be applied to it. pn3 := sendPacket(t, t2, true) pn4 := sendPacket(t, now, false) _, err := sph.ReceivedAck( &wire.AckFrame{AckRanges: ackRanges(pn4)}, protocol.EncryptionInitial, now.Add(time.Second), ) require.NoError(t, err) // make sure that the RTT is actually 1s require.Equal(t, rtt, rttStats.SmoothedRTT()) require.Equal(t, []protocol.PacketNumber{pn4}, packets.Acked) // only the first packet was lost require.Equal(t, []protocol.PacketNumber{pn1}, packets.Lost) // ... 
but we armed a timer to declare packet 2 lost after 9/8 RTTs require.Equal(t, t2.Add(time.Second*9/8), sph.GetLossDetectionTimeout()) sph.OnLossDetectionTimeout(sph.GetLossDetectionTimeout().Add(-time.Microsecond)) require.Len(t, packets.Lost, 1) sph.OnLossDetectionTimeout(sph.GetLossDetectionTimeout()) require.Equal(t, []protocol.PacketNumber{pn1, pn2, pn3}, packets.Lost) }

// TestSentPacketHandlerPacketBasedLossDetection checks that packets are declared
// lost once enough later packets have been acknowledged. All five packets are
// sent at the same instant, and both ACKs arrive one second later.
func TestSentPacketHandlerPacketBasedLossDetection(t *testing.T) {
	rttStats := utils.NewRTTStats()
	sph := NewSentPacketHandler(
		0,
		1200,
		rttStats,
		&utils.ConnectionStats{},
		true,
		false,
		nil,
		protocol.PerspectiveServer,
		nil,
		utils.DefaultLogger,
	)
	var packets packetTracker
	now := monotime.Now()
	var pns []protocol.PacketNumber
	for range 5 {
		pn := sph.PopPacketNumber(protocol.EncryptionInitial)
		sph.SentPacket(now, pn, protocol.InvalidPacketNumber, nil, []Frame{packets.NewPingFrame(pn)}, protocol.EncryptionInitial, protocol.ECNNon, 1000, false, false)
		pns = append(pns, pn)
	}

	// Acknowledging pns[3] declares pns[0] lost; pns[1] and pns[2] are not (yet) lost.
	_, err := sph.ReceivedAck(
		&wire.AckFrame{AckRanges: ackRanges(pns[3])},
		protocol.EncryptionInitial,
		now.Add(time.Second),
	)
	require.NoError(t, err)
	require.Equal(t, []protocol.PacketNumber{pns[3]}, packets.Acked)
	require.Equal(t, []protocol.PacketNumber{pns[0]}, packets.Lost)

	// Acknowledging pns[4] additionally declares pns[1] lost.
	_, err = sph.ReceivedAck(
		&wire.AckFrame{AckRanges: ackRanges(pns[4])},
		protocol.EncryptionInitial,
		now.Add(time.Second),
	)
	require.NoError(t, err)
	require.Equal(t, []protocol.PacketNumber{pns[3], pns[4]}, packets.Acked)
	require.Equal(t, []protocol.PacketNumber{pns[0], pns[1]}, packets.Lost)
}

// TestSentPacketHandlerPTO runs the PTO test in the Initial, Handshake and
// 1-RTT packet number spaces, each paired with the send mode expected once the
// PTO timer has fired in that space.
func TestSentPacketHandlerPTO(t *testing.T) {
	t.Run("Initial", func(t *testing.T) {
		testSentPacketHandlerPTO(t, protocol.EncryptionInitial, SendPTOInitial)
	})
	t.Run("Handshake", func(t *testing.T) {
		testSentPacketHandlerPTO(t, protocol.EncryptionHandshake, SendPTOHandshake)
	})
	t.Run("1-RTT", func(t *testing.T) {
		testSentPacketHandlerPTO(t, protocol.Encryption1RTT, SendPTOAppData)
	})
}

// testSentPacketHandlerPTO exercises the PTO timer within a single packet number
// space (encLevel). It covers arming, expiry, exponential backoff, probe packet
// accounting (ptoMode is the expected send mode while probes are due) and
// cancellation of the timer once everything is acknowledged. It also checks the
// qlog events emitted along the way.
func testSentPacketHandlerPTO(t *testing.T, encLevel protocol.EncryptionLevel, ptoMode SendMode) {
	var packets packetTracker
	var eventRecorder events.Recorder
	rttStats := utils.NewRTTStats()
	rttStats.SetMaxAckDelay(25 * time.Millisecond)
	rttStats.UpdateRTT(500*time.Millisecond, 0)
	rttStats.UpdateRTT(1000*time.Millisecond, 0)
	rttStats.UpdateRTT(1500*time.Millisecond, 0)
	sph := NewSentPacketHandler(
		0,
		1200,
		rttStats,
		&utils.ConnectionStats{},
		true,
		false,
		nil,
		protocol.PerspectiveServer,
		&eventRecorder,
		utils.DefaultLogger,
	)
	// In the application-data packet number space, the PTO is only set once the
	// handshake is confirmed. Dropping the Initial and Handshake packet number
	// spaces gets us there.
	if encLevel == protocol.Encryption1RTT {
		sph.DropPackets(protocol.EncryptionInitial, monotime.Now())
		sph.DropPackets(protocol.EncryptionHandshake, monotime.Now())
	}

	// sendPacket sends one packet at time ti.
	// For an ack-eliciting packet, it asserts that the PTO timer was (re)armed to
	// ti + (PTO << ptoCount), i.e. including the exponential backoff for the
	// current PTO count. For a non-ack-eliciting packet, it asserts that the timer
	// was not touched.
	sendPacket := func(t *testing.T, ti monotime.Time, ackEliciting bool, ptoCount uint) protocol.PacketNumber {
		t.Helper()
		pn := sph.PopPacketNumber(encLevel)
		if ackEliciting {
			sph.SentPacket(ti, pn, protocol.InvalidPacketNumber, nil, []Frame{packets.NewPingFrame(pn)}, encLevel, protocol.ECNNon, 1000, false, false)
			require.Equal(t,
				[]qlogwriter.Event{
					qlog.LossTimerUpdated{
						Type:      qlog.LossTimerUpdateTypeSet,
						TimerType: qlog.TimerTypePTO,
						EncLevel:  encLevel,
						Time:      ti.ToTime().Add(rttStats.PTO(encLevel == protocol.Encryption1RTT) << ptoCount),
					},
				},
				eventRecorder.Events(qlog.LossTimerUpdated{}),
			)
			eventRecorder.Clear()
		} else {
			// NOTE(review): this passes true for the argument that
			// TestSentPacketHandlerCongestion uses to mark a Path MTU probe packet.
			// Confirm that this is intended for a plain non-ack-eliciting packet.
			sph.SentPacket(ti, pn, protocol.InvalidPacketNumber, nil, nil, encLevel, protocol.ECNNon, 1000, true, false)
			require.Empty(t, eventRecorder.Events(qlog.LossTimerUpdated{}))
		}
		return pn
	}

	now := monotime.Now()
	sendTimes := []monotime.Time{
		now,
		now.Add(100 * time.Millisecond),
		now.Add(200 * time.Millisecond),
		now.Add(300 * time.Millisecond),
	}
	var pns []protocol.PacketNumber
	// send packet 0, 1, 2 (ack-eliciting) and 3 (not ack-eliciting)
	for i := range 3 {
		pns = append(pns, sendPacket(t, sendTimes[i], true, 0))
	}
	pns = append(pns, sendPacket(t, sendTimes[3], false, 0))

	// The PTO includes the max_ack_delay only for the application-data packet number space.
	// Make sure that the value is actually different, so this test is meaningful.
	require.NotEqual(t, rttStats.PTO(true), rttStats.PTO(false))

	timeout := sph.GetLossDetectionTimeout()
	// the PTO is based on the *last* ack-eliciting packet
	require.Equal(t, sendTimes[2].Add(rttStats.PTO(encLevel == protocol.Encryption1RTT)), timeout)

	// The first expiry bumps the PTO count to 1 and re-arms the timer with a doubled PTO.
	eventRecorder.Clear()
	sph.OnLossDetectionTimeout(timeout)
	require.Equal(t,
		[]qlogwriter.Event{
			qlog.LossTimerUpdated{
				Type:      qlog.LossTimerUpdateTypeExpired,
				TimerType: qlog.TimerTypePTO,
				EncLevel:  encLevel,
			},
			qlog.PTOCountUpdated{PTOCount: 1},
			qlog.LossTimerUpdated{
				Type:      qlog.LossTimerUpdateTypeSet,
				TimerType: qlog.TimerTypePTO,
				EncLevel:  encLevel,
				Time:      sendTimes[2].Add(2 * rttStats.PTO(encLevel == protocol.Encryption1RTT)).ToTime(),
			},
		},
		eventRecorder.Events(qlog.PTOCountUpdated{}, qlog.LossTimerUpdated{}),
	)
	// PTO timer expiration doesn't declare packets lost
	require.Empty(t, packets.Lost)

	now = timeout
	require.Equal(t, ptoMode, sph.SendMode(now))
	// Queue probe packets: each successful call queues one outstanding ack-eliciting packet.
	require.True(t, sph.QueueProbePacket(encLevel))
	require.True(t, sph.QueueProbePacket(encLevel))
	require.True(t, sph.QueueProbePacket(encLevel))
	// There are only three ack-eliciting packets that could be queued
	// (packet 3 was not ack-eliciting).
	require.False(t, sph.QueueProbePacket(encLevel))
	// Queueing probe packets currently works by declaring them lost.
	// We shouldn't do this, but this is how the code is currently written.
	require.Equal(t, pns[:3], packets.Lost)
	packets.Lost = packets.Lost[:0]
	eventRecorder.Clear()

	// send packet 4 and 6 as probe packets
	// 5 doesn't count, since it's not an ack-eliciting packet
	sendTimes = append(sendTimes, now.Add(100*time.Millisecond))
	sendTimes = append(sendTimes, now.Add(200*time.Millisecond))
	sendTimes = append(sendTimes, now.Add(300*time.Millisecond))
	require.Equal(t, ptoMode, sph.SendMode(sendTimes[4])) // first probe packet
	pns = append(pns, sendPacket(t, sendTimes[4], true, 1))
	require.Equal(t, ptoMode, sph.SendMode(sendTimes[5])) // next probe packet
	pns = append(pns, sendPacket(t, sendTimes[5], false, 1))
	require.Equal(t, ptoMode, sph.SendMode(sendTimes[6])) // non-ack-eliciting packet didn't count as a probe packet
	pns = append(pns, sendPacket(t, sendTimes[6], true, 1))
	require.Equal(t, SendAny, sph.SendMode(sendTimes[6])) // enough probe packets sent

	timeout = sph.GetLossDetectionTimeout()
	// exponential backoff
	require.Equal(t, sendTimes[6].Add(2*rttStats.PTO(encLevel == protocol.Encryption1RTT)), timeout)
	now = timeout
	// The second expiry bumps the PTO count to 2 and quadruples the PTO.
	sph.OnLossDetectionTimeout(timeout)
	require.Equal(t,
		[]qlogwriter.Event{
			qlog.LossTimerUpdated{
				Type:      qlog.LossTimerUpdateTypeExpired,
				TimerType: qlog.TimerTypePTO,
				EncLevel:  encLevel,
			},
			qlog.PTOCountUpdated{PTOCount: 2},
			qlog.LossTimerUpdated{
				Type:      qlog.LossTimerUpdateTypeSet,
				TimerType: qlog.TimerTypePTO,
				EncLevel:  encLevel,
				Time:      sendTimes[6].Add(4 * rttStats.PTO(encLevel == protocol.Encryption1RTT)).ToTime(),
			},
		},
		eventRecorder.Events(qlog.LossTimerUpdated{}, qlog.PTOCountUpdated{}),
	)
	eventRecorder.Clear()
	// PTO timer expiration doesn't declare packets lost
	require.Empty(t, packets.Lost)

	// send packet 7, 8 as probe packets
	sendTimes = append(sendTimes, now.Add(100*time.Millisecond))
	sendTimes = append(sendTimes, now.Add(200*time.Millisecond))
	require.Equal(t, ptoMode, sph.SendMode(sendTimes[7])) // first probe packet
	pns = append(pns, sendPacket(t, sendTimes[7], true, 2))
	require.Equal(t, ptoMode, sph.SendMode(sendTimes[8])) // next probe packet
	pns = append(pns, sendPacket(t, sendTimes[8], true, 2))
	require.Equal(t, SendAny, sph.SendMode(sendTimes[8])) // enough probe packets sent

	timeout = sph.GetLossDetectionTimeout()
	// exponential backoff, again
	require.Equal(t, sendTimes[8].Add(4*rttStats.PTO(encLevel == protocol.Encryption1RTT)), timeout)

	eventRecorder.Clear()
	// Receive an ACK for packet 7.
	// This now declares packets lost, and leads to arming of a timer for packet 8.
	_, err := sph.ReceivedAck(
		&wire.AckFrame{AckRanges: ackRanges(pns[7])},
		encLevel,
		sendTimes[7].Add(time.Microsecond),
	)
	require.NoError(t, err)
	require.Equal(t, []protocol.PacketNumber{pns[7]}, packets.Acked)
	require.Equal(t, []protocol.PacketNumber{pns[4], pns[6]}, packets.Lost)
	require.Len(t, eventRecorder.Events(qlog.PacketLost{}), 2)
	// the ACK resets the PTO count
	require.Equal(t,
		[]qlogwriter.Event{
			qlog.PTOCountUpdated{PTOCount: 0},
		},
		eventRecorder.Events(qlog.PTOCountUpdated{})[:1],
	)
	require.Equal(t,
		[]qlogwriter.Event{
			qlog.LossTimerUpdated{
				Type:      qlog.LossTimerUpdateTypeSet,
				TimerType: qlog.TimerTypePTO,
				EncLevel:  encLevel,
				Time:      sendTimes[8].Add(rttStats.PTO(encLevel == protocol.Encryption1RTT)).ToTime(),
			},
		},
		eventRecorder.Events(qlog.LossTimerUpdated{}),
	)
	require.Contains(t, packets.Acked, pns[7])
	// The PTO timer is now set for the last remaining packet (8),
	// with no exponential backoff.
	require.Equal(t, sendTimes[8].Add(rttStats.PTO(encLevel == protocol.Encryption1RTT)), sph.GetLossDetectionTimeout())

	// Acknowledge the last packet (8).
	// This should cancel the loss detection timer since there are no more outstanding packets.
	eventRecorder.Clear()
	_, err = sph.ReceivedAck(
		&wire.AckFrame{AckRanges: ackRanges(pns[8])},
		encLevel,
		sendTimes[8].Add(time.Second),
	)
	require.NoError(t, err)
	require.Contains(t, packets.Acked, pns[8])
	// The loss detection timer should be cancelled since there are no more outstanding packets.
	require.True(t, sph.GetLossDetectionTimeout().IsZero())
	require.Equal(t,
		[]qlogwriter.Event{
			qlog.LossTimerUpdated{Type: qlog.LossTimerUpdateTypeCancelled},
		},
		eventRecorder.Events(qlog.LossTimerUpdated{}),
	)
}

// TestSentPacketHandlerPacketNumberSpacesPTO checks the PTO timer across the
// Initial, Handshake and 1-RTT packet number spaces. The timer follows the
// earliest PTO time of the Initial and Handshake spaces, the PTO count (and
// hence the backoff) is shared between them, and 1-RTT packets only drive the
// timer once the handshake is confirmed.
func TestSentPacketHandlerPacketNumberSpacesPTO(t *testing.T) {
	rttStats := utils.NewRTTStats()
	const rtt = time.Second
	rttStats.UpdateRTT(rtt, 0)
	sph := NewSentPacketHandler(
		0,
		1200,
		rttStats,
		&utils.ConnectionStats{},
		true,
		false,
		nil,
		protocol.PerspectiveServer,
		nil,
		utils.DefaultLogger,
	)

	// sendPacket sends a 1000-byte, ack-eliciting PING packet at encLevel.
	sendPacket := func(t *testing.T, ti monotime.Time, encLevel protocol.EncryptionLevel) protocol.PacketNumber {
		t.Helper()
		pn := sph.PopPacketNumber(encLevel)
		sph.SentPacket(ti, pn, protocol.InvalidPacketNumber, nil, []Frame{{Frame: &wire.PingFrame{}}}, encLevel, protocol.ECNNon, 1000, false, false)
		return pn
	}

	var initialPNs, handshakePNs [4]protocol.PacketNumber
	var initialTimes, handshakeTimes [4]monotime.Time
	// Alternate Initial and Handshake packets, 100ms apart.
	now := monotime.Now()
	initialPNs[0] = sendPacket(t, now, protocol.EncryptionInitial)
	initialTimes[0] = now
	now = now.Add(100 * time.Millisecond)
	handshakePNs[0] = sendPacket(t, now, protocol.EncryptionHandshake)
	handshakeTimes[0] = now
	now = now.Add(100 * time.Millisecond)
	initialPNs[1] = sendPacket(t, now, protocol.EncryptionInitial)
	initialTimes[1] = now
	now = now.Add(100 * time.Millisecond)
	handshakePNs[1] = sendPacket(t, now, protocol.EncryptionHandshake)
	handshakeTimes[1] = now
	require.Equal(t, protocol.ByteCount(4000), sph.(*sentPacketHandler).getBytesInFlight())

	// the PTO is the earliest time of the PTO times for both packet number spaces,
	// i.e. the 2nd Initial packet sent
	timeout := sph.GetLossDetectionTimeout()
	require.Equal(t, initialTimes[1].Add(rttStats.PTO(false)), timeout)
	sph.OnLossDetectionTimeout(timeout)
	require.Equal(t, SendPTOInitial, sph.SendMode(timeout))

	// send a PTO probe packet (Initial)
	now = timeout.Add(100 * time.Millisecond)
	initialPNs[2] = sendPacket(t, now, protocol.EncryptionInitial)
	initialTimes[2] = now

	// the earliest PTO time is now the 2nd Handshake packet
	timeout = sph.GetLossDetectionTimeout()
	// pto_count is a global property, so there's now an exponential backoff
	require.Equal(t, handshakeTimes[1].Add(2*rttStats.PTO(false)), timeout)
	sph.OnLossDetectionTimeout(timeout)
	require.Equal(t, SendPTOHandshake, sph.SendMode(timeout))

	// send a PTO probe packet (Handshake)
	now = timeout.Add(100 * time.Millisecond)
	handshakePNs[2] = sendPacket(t, now, protocol.EncryptionHandshake)
	handshakeTimes[2] = now

	// the earliest PTO time is now the 3rd Initial packet
	timeout = sph.GetLossDetectionTimeout()
	require.Equal(t, initialTimes[2].Add(4*rttStats.PTO(false)), timeout)
	sph.OnLossDetectionTimeout(timeout)
	require.Equal(t, SendPTOInitial, sph.SendMode(timeout))

	// Drop the Initial packet number space.
	// Its three packets (3000 bytes) are removed from bytes in flight.
	now = timeout.Add(100 * time.Millisecond)
	require.Equal(t, protocol.ByteCount(6000), sph.(*sentPacketHandler).getBytesInFlight())
	sph.DropPackets(protocol.EncryptionInitial, now)
	require.Equal(t, protocol.ByteCount(3000), sph.(*sentPacketHandler).getBytesInFlight())

	// Since the Initial packets are gone:
	// * the earliest PTO time is now based on the 3rd Handshake packet
	// * the PTO count is reset to 0
	timeout = sph.GetLossDetectionTimeout()
	require.Equal(t, handshakeTimes[2].Add(rttStats.PTO(false)), timeout)

	// send a 1-RTT packet
	now = timeout.Add(100 * time.Millisecond)
	sendTime := now
	sendPacket(t, now, protocol.Encryption1RTT)

	// until handshake confirmation, the PTO timer is based on the Handshake packet number space
	require.Equal(t, timeout, sph.GetLossDetectionTimeout())
	sph.OnLossDetectionTimeout(timeout)
	require.Equal(t, SendPTOHandshake, sph.SendMode(now))

	// Drop Handshake packet number space.
	// This confirms the handshake, and the PTO timer is now based on the 1-RTT packet number space.
	sph.DropPackets(protocol.EncryptionHandshake, now)
	require.Equal(t, sendTime.Add(rttStats.PTO(false)), sph.GetLossDetectionTimeout())
}

// TestSentPacketHandler0RTT checks, on the client side, that the PTO timer is
// driven by the Handshake packet number space rather than by 0-RTT packets. It
// also checks that dropping 0-RTT packets removes them from bytes in flight
// without declaring them lost.
func TestSentPacketHandler0RTT(t *testing.T) {
	sph := NewSentPacketHandler(
		0,
		1200,
		utils.NewRTTStats(),
		&utils.ConnectionStats{},
		true,
		false,
		nil,
		protocol.PerspectiveClient,
		nil,
		utils.DefaultLogger,
	)
	var appDataPackets packetTracker

	// sendPacket sends a 1000-byte PING packet at encLevel.
	// Only 0-RTT and 1-RTT packets are tracked in appDataPackets.
	sendPacket := func(t *testing.T, ti monotime.Time, encLevel protocol.EncryptionLevel) protocol.PacketNumber {
		t.Helper()
		pn := sph.PopPacketNumber(encLevel)
		var frames []Frame
		if encLevel == protocol.Encryption0RTT || encLevel == protocol.Encryption1RTT {
			frames = []Frame{appDataPackets.NewPingFrame(pn)}
		} else {
			frames = []Frame{{Frame: &wire.PingFrame{}}}
		}
		sph.SentPacket(ti, pn, protocol.InvalidPacketNumber, nil, frames, encLevel, protocol.ECNNon, 1000, false, false)
		return pn
	}

	now := monotime.Now()
	sendPacket(t, now, protocol.Encryption0RTT)
	sendPacket(t, now.Add(100*time.Millisecond), protocol.EncryptionHandshake)
	sendPacket(t, now.Add(200*time.Millisecond), protocol.Encryption0RTT)
	sendPacket(t, now.Add(300*time.Millisecond), protocol.Encryption1RTT)
	sendPacket(t, now.Add(400*time.Millisecond), protocol.Encryption1RTT)
	require.Equal(t, protocol.ByteCount(5000), sph.(*sentPacketHandler).getBytesInFlight())

	// The PTO timer is based on the Handshake packet number space, not the 0-RTT packets
	timeout := sph.GetLossDetectionTimeout()
	require.NotZero(t, timeout)
	sph.OnLossDetectionTimeout(timeout)
	require.Equal(t, SendPTOHandshake, sph.SendMode(timeout))

	// Dropping 0-RTT removes only the two 0-RTT packets (2000 bytes) from bytes in flight.
	now = timeout.Add(100 * time.Millisecond)
	sph.DropPackets(protocol.Encryption0RTT, now)
	require.Equal(t, protocol.ByteCount(3000), sph.(*sentPacketHandler).getBytesInFlight())
	// 0-RTT are discarded, not lost
	require.Empty(t, appDataPackets.Lost)
}

// TestSentPacketHandlerCongestion checks the interaction with a mocked
// congestion controller. It covers the send mode when pacing- or
// congestion-limited, congestion events and ACK notifications, the exclusion of
// Path MTU probe packets from congestion events, and bytes-in-flight accounting.
func TestSentPacketHandlerCongestion(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	cong := mocks.NewMockSendAlgorithmWithDebugInfos(mockCtrl)
	rttStats := utils.NewRTTStats()
	sph := NewSentPacketHandler(
		0,
		1200,
		rttStats,
		&utils.ConnectionStats{},
		true,
		false,
		nil,
		protocol.PerspectiveServer,
		nil,
		utils.DefaultLogger,
	)
	sph.(*sentPacketHandler).congestion = cong

	var packets packetTracker
	// Send the first 5 packets: not congestion-limited, not pacing-limited.
	// The 2nd packet is a Path MTU Probe packet.
	now := monotime.Now()
	var bytesInFlight protocol.ByteCount
	var pns []protocol.PacketNumber
	var sendTimes []monotime.Time
	for i := range 5 {
		gomock.InOrder(
			cong.EXPECT().CanSend(bytesInFlight).Return(true),
			cong.EXPECT().HasPacingBudget(now).Return(true),
		)
		require.Equal(t, SendAny, sph.SendMode(now))
		pn := sph.PopPacketNumber(protocol.EncryptionInitial)
		bytesInFlight += 1000
		cong.EXPECT().OnPacketSent(now, bytesInFlight, pn, protocol.ByteCount(1000), true)
		sph.SentPacket(now, pn, protocol.InvalidPacketNumber, nil, []Frame{packets.NewPingFrame(pn)}, protocol.EncryptionInitial, protocol.ECNNon, 1000, i == 1, false)
		pns = append(pns, pn)
		sendTimes = append(sendTimes, now)
		now = now.Add(100 * time.Millisecond)
	}

	// try to send another packet: not congestion-limited, but pacing-limited
	now = now.Add(100 * time.Millisecond)
	gomock.InOrder(
		cong.EXPECT().CanSend(bytesInFlight).Return(true),
		cong.EXPECT().HasPacingBudget(now).Return(false),
	)
	require.Equal(t, SendPacingLimited, sph.SendMode(now))
	// the connection would call TimeUntilSend, to find out when a new packet can be sent again
	pacingDeadline := now.Add(500 * time.Millisecond)
	cong.EXPECT().TimeUntilSend(bytesInFlight).Return(pacingDeadline)
	require.Equal(t, pacingDeadline, sph.TimeUntilSend())

	// try to send another packet, but now we're congestion limited
	now = now.Add(100 * time.Millisecond)
	cong.EXPECT().CanSend(bytesInFlight).Return(false)
	require.Equal(t, SendAck, sph.SendMode(now)) // ACKs are allowed even if congestion limited

	// Receive an ACK for the 3rd and 4th packet (pns[2] and pns[3]), which declares
	// the 1st and 2nd packet lost.
	// However, since the 2nd packet was a Path MTU probe packet, it won't get reported
	// to the congestion controller.
	ackTime := sendTimes[3].Add(time.Second)
	gomock.InOrder(
		cong.EXPECT().MaybeExitSlowStart(),
		cong.EXPECT().OnCongestionEvent(pns[0], protocol.ByteCount(1000), protocol.ByteCount(5000)),
		cong.EXPECT().OnPacketAcked(pns[2], protocol.ByteCount(1000), protocol.ByteCount(5000), ackTime),
		cong.EXPECT().OnPacketAcked(pns[3], protocol.ByteCount(1000), protocol.ByteCount(5000), ackTime),
	)
	_, err := sph.ReceivedAck(&wire.AckFrame{AckRanges: ackRanges(pns[2], pns[3])}, protocol.EncryptionInitial, ackTime)
	require.NoError(t, err)
	require.Equal(t, []protocol.PacketNumber{pns[2], pns[3]}, packets.Acked)
	require.Equal(t, []protocol.PacketNumber{pns[0], pns[1]}, packets.Lost)

	// Now receive a (delayed) ACK for the 1st packet.
	// Since this packet was already lost, we don't expect any calls to the congestion controller.
	_, err = sph.ReceivedAck(&wire.AckFrame{AckRanges: ackRanges(pns[0])}, protocol.EncryptionInitial, ackTime)
	require.NoError(t, err)

	// We should now have a PTO timer armed for the 5th packet (pns[4]),
	// the only one still outstanding.
	timeout := sph.GetLossDetectionTimeout()
	require.NotZero(t, timeout)
	sph.OnLossDetectionTimeout(timeout)
	require.Equal(t, SendPTOInitial, sph.SendMode(timeout))

	// Send another packet to check that bytes_in_flight was correctly adjusted:
	// pns[4] (1000 bytes) plus this packet (1000 bytes).
	now = timeout.Add(100 * time.Millisecond)
	pn := sph.PopPacketNumber(protocol.EncryptionInitial)
	cong.EXPECT().OnPacketSent(now, protocol.ByteCount(2000), pn, protocol.ByteCount(1000), true)
	sph.SentPacket(now, pn, protocol.InvalidPacketNumber, nil, []Frame{packets.NewPingFrame(pn)}, protocol.EncryptionInitial, protocol.ECNNon, 1000, false, false)
}

// TestSentPacketHandlerRetry runs the Retry test with an RTT sample above and
// below minRTTAfterRetry.
func TestSentPacketHandlerRetry(t *testing.T) {
	t.Run("long RTT measurement", func(t *testing.T) {
		testSentPacketHandlerRetry(t, time.Second, time.Second)
	})

	// The estimated RTT should be at least minRTTAfterRetry, even if the RTT measurement is very short.
	t.Run("short RTT measurement", func(t *testing.T) {
		testSentPacketHandlerRetry(t, minRTTAfterRetry/3, minRTTAfterRetry)
	})
}

// testSentPacketHandlerRetry checks ResetForRetry, where rtt is the time between
// the first packet and the Retry and expectedRTT is the resulting smoothed RTT.
// It checks that:
//   - all outstanding Initial and 0-RTT packets are declared lost
//   - the loss detection timer is cancelled and bytes in flight drop to zero
//   - packet numbers keep increasing
func testSentPacketHandlerRetry(t *testing.T, rtt, expectedRTT time.Duration) {
	var initialPackets, appDataPackets packetTracker
	rttStats := utils.NewRTTStats()
	sph := NewSentPacketHandler(
		0,
		1200,
		rttStats,
		&utils.ConnectionStats{},
		true,
		false,
		nil,
		protocol.PerspectiveClient,
		nil,
		utils.DefaultLogger,
	)

	start := monotime.Now()
	now := start
	var initialPNs, appDataPNs []protocol.PacketNumber
	// send 2 initial and 2 0-RTT packets
	for range 2 {
		pn := sph.PopPacketNumber(protocol.EncryptionInitial)
		initialPNs = append(initialPNs, pn)
		sph.SentPacket(now, pn, protocol.InvalidPacketNumber, nil, []Frame{initialPackets.NewPingFrame(pn)}, protocol.EncryptionInitial, protocol.ECNNon, 1000, false, false)
		now = now.Add(100 * time.Millisecond)

		pn = sph.PopPacketNumber(protocol.Encryption0RTT)
		appDataPNs = append(appDataPNs, pn)
		sph.SentPacket(now, pn, protocol.InvalidPacketNumber, nil, []Frame{appDataPackets.NewPingFrame(pn)}, protocol.Encryption0RTT, protocol.ECNNon, 1000, false, false)
		now = now.Add(100 * time.Millisecond)
	}
	require.Equal(t, protocol.ByteCount(4000), sph.(*sentPacketHandler).getBytesInFlight())
	require.NotZero(t, sph.GetLossDetectionTimeout())

	sph.ResetForRetry(start.Add(rtt))
	// receiving a Retry cancels all timers
	require.Zero(t, sph.GetLossDetectionTimeout())
	// all packets sent so far are declared lost
	require.Equal(t, []protocol.PacketNumber{initialPNs[0], initialPNs[1]}, initialPackets.Lost)
	require.Equal(t, []protocol.PacketNumber{appDataPNs[0], appDataPNs[1]}, appDataPackets.Lost)
	require.False(t, sph.QueueProbePacket(protocol.EncryptionInitial))
	require.False(t, sph.QueueProbePacket(protocol.Encryption0RTT))
	// the RTT measurement is taken from the first packet sent
	require.Equal(t, expectedRTT, rttStats.SmoothedRTT())
	require.Zero(t, sph.(*sentPacketHandler).getBytesInFlight())
	// packet numbers continue increasing
	initialPN, _ := sph.PeekPacketNumber(protocol.EncryptionInitial)
	require.Greater(t, initialPN, initialPNs[1])
	appDataPN, _ := sph.PeekPacketNumber(protocol.Encryption0RTT)
	require.Greater(t, appDataPN, appDataPNs[1])
}

// TestSentPacketHandlerRetryAfterPTO checks that, once the PTO timer has fired
// (and a retransmission was sent), ResetForRetry still declares all packets
// lost but does not take an RTT sample.
func TestSentPacketHandlerRetryAfterPTO(t *testing.T) {
	rttStats := utils.NewRTTStats()
	sph := NewSentPacketHandler(
		0,
		1200,
		rttStats,
		&utils.ConnectionStats{},
		true,
		false,
		nil,
		protocol.PerspectiveClient,
		nil,
		utils.DefaultLogger,
	)
	var packets packetTracker

	start := monotime.Now()
	now := start
	pn1 := sph.PopPacketNumber(protocol.EncryptionInitial)
	sph.SentPacket(now, pn1, protocol.InvalidPacketNumber, nil, []Frame{packets.NewPingFrame(pn1)}, protocol.EncryptionInitial, protocol.ECNNon, 1000, false, false)

	timeout := sph.GetLossDetectionTimeout()
	require.NotZero(t, timeout)
	sph.OnLossDetectionTimeout(timeout)
	require.Equal(t, SendPTOInitial, sph.SendMode(timeout))
	require.True(t, sph.QueueProbePacket(protocol.EncryptionInitial))

	// send a retransmission for the first packet
	now = timeout.Add(100 * time.Millisecond)
	pn2 := sph.PopPacketNumber(protocol.EncryptionInitial)
	sph.SentPacket(now, pn2, protocol.InvalidPacketNumber, nil, []Frame{packets.NewPingFrame(pn2)}, protocol.EncryptionInitial, protocol.ECNNon, 900, false, false)

	const rtt = time.Second
	sph.ResetForRetry(now.Add(rtt))
	require.Equal(t, []protocol.PacketNumber{pn1, pn2}, packets.Lost)
	// no RTT measurement is taken, since the PTO timer fired
	require.Equal(t, utils.DefaultInitialRTT, rttStats.SmoothedRTT())
}

// TestSentPacketHandlerECN checks the interaction with a mocked ECN tracker.
// It checks that:
//   - only 1-RTT packets are reported to the tracker
//   - lost packets are reported
//   - ECN counts are only passed on when the largest acknowledged packet increases
//   - congestion signalled by the tracker is forwarded to the congestion controller
func TestSentPacketHandlerECN(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	cong := mocks.NewMockSendAlgorithmWithDebugInfos(mockCtrl)
	cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
	cong.EXPECT().OnPacketAcked(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
	cong.EXPECT().MaybeExitSlowStart().AnyTimes()
	ecnHandler := NewMockECNHandler(mockCtrl)
	sph := NewSentPacketHandler(
		0,
		1200,
		utils.NewRTTStats(),
		&utils.ConnectionStats{},
		true,
		false,
		nil,
		protocol.PerspectiveClient,
		nil,
		utils.DefaultLogger,
	)
	sph.(*sentPacketHandler).ecnTracker = ecnHandler
	sph.(*sentPacketHandler).congestion = cong

	// ECN marks on non-1-RTT packets are ignored
	// (no ecnHandler.SentPacket expectation is set for these).
	sph.SentPacket(monotime.Now(), sph.PopPacketNumber(protocol.EncryptionInitial), protocol.InvalidPacketNumber, nil, nil, protocol.EncryptionInitial, protocol.ECT1, 1200, false, false)
	sph.SentPacket(monotime.Now(), sph.PopPacketNumber(protocol.EncryptionHandshake), protocol.InvalidPacketNumber, nil, nil, protocol.EncryptionHandshake, protocol.ECT0, 1200, false, false)
	sph.SentPacket(monotime.Now(), sph.PopPacketNumber(protocol.Encryption0RTT), protocol.InvalidPacketNumber, nil, nil, protocol.Encryption0RTT, protocol.ECNCE, 1200, false, false)

	var packets packetTracker
	// sendPacket sends a 1-RTT packet with the given ECN marking and expects it
	// to be reported to the ECN tracker.
	sendPacket := func(t *testing.T, ti monotime.Time, ecn protocol.ECN) protocol.PacketNumber {
		t.Helper()
		pn := sph.PopPacketNumber(protocol.Encryption1RTT)
		ecnHandler.EXPECT().SentPacket(pn, ecn)
		sph.SentPacket(ti, pn, protocol.InvalidPacketNumber, nil, []Frame{packets.NewPingFrame(pn)}, protocol.Encryption1RTT, ecn, 1200, false, false)
		return pn
	}

	pns := make([]protocol.PacketNumber, 4)
	now := monotime.Now()
	pns[0] = sendPacket(t, now, protocol.ECT1)
	now = now.Add(time.Second)
	pns[1] = sendPacket(t, now, protocol.ECT0)
	pns[2] = sendPacket(t, now, protocol.ECT0)
	pns[3] = sendPacket(t, now, protocol.ECT0)

	// Receive an ACK with a short RTT, such that the first packet is lost.
	cong.EXPECT().OnCongestionEvent(gomock.Any(), gomock.Any(), gomock.Any())
	ecnHandler.EXPECT().LostPacket(pns[0])
	ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), int64(10), int64(11), int64(12)).DoAndReturn(func(packets []packetWithPacketNumber, _, _, _ int64) bool {
		require.Len(t, packets, 2)
		require.Equal(t, pns[2], packets[0].PacketNumber)
		require.Equal(t, pns[3], packets[1].PacketNumber)
		return false
	})
	_, err := sph.ReceivedAck(
		&wire.AckFrame{
			AckRanges: ackRanges(pns[2], pns[3]),
			ECT0:      10,
			ECT1:      11,
			ECNCE:     12,
		},
		protocol.Encryption1RTT,
		now.Add(100*time.Millisecond),
	)
	require.NoError(t, err)
	require.Equal(t, []protocol.PacketNumber{pns[0]}, packets.Lost)

	// The second packet is still outstanding.
	// Receive a (delayed) ACK for it.
	// Since the new ECN counts were already reported, ECN marks on this ACK frame are ignored.
	now = now.Add(100 * time.Millisecond)
	_, err = sph.ReceivedAck(&wire.AckFrame{AckRanges: ackRanges(pns[1])}, protocol.Encryption1RTT, now)
	require.NoError(t, err)

	// Send two more packets, and receive an ACK for the second one.
	pns = pns[:2]
	pns[0] = sendPacket(t, now, protocol.ECT1)
	pns[1] = sendPacket(t, now, protocol.ECT1)
	ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(packets []packetWithPacketNumber, _, _, _ int64) bool {
			require.Len(t, packets, 1)
			require.Equal(t, pns[1], packets[0].PacketNumber)
			return false
		},
	)
	now = now.Add(100 * time.Millisecond)
	_, err = sph.ReceivedAck(&wire.AckFrame{AckRanges: ackRanges(pns[1])}, protocol.Encryption1RTT, now)
	require.NoError(t, err)
	// Receiving an ACK that covers both packets doesn't cause the ECN marks to be reported,
	// since the largest acked didn't increase.
	now = now.Add(100 * time.Millisecond)
	_, err = sph.ReceivedAck(&wire.AckFrame{AckRanges: ackRanges(pns[0], pns[1])}, protocol.Encryption1RTT, now)
	require.NoError(t, err)

	// Send another packet, and have the ECN handler report congestion.
	// This needs to be reported to the congestion controller.
	pns = pns[:1]
	now = now.Add(time.Second)
	pns[0] = sendPacket(t, now, protocol.ECT1)
	gomock.InOrder(
		ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(true),
		cong.EXPECT().OnCongestionEvent(pns[0], protocol.ByteCount(0), gomock.Any()),
	)
	_, err = sph.ReceivedAck(&wire.AckFrame{AckRanges: ackRanges(pns[0])}, protocol.Encryption1RTT, now.Add(100*time.Millisecond))
	require.NoError(t, err)
}

// TestSentPacketHandlerPathProbe checks that a path probe packet is declared
// lost only after pathProbePacketLossTimeout, not by the regular loss
// detection. It also checks that MigratedPath resets bytes in flight and the
// RTT estimate, and declares the remaining outstanding packets lost.
func TestSentPacketHandlerPathProbe(t *testing.T) {
	const rtt = 10 * time.Millisecond // RTT of the original path
	rttStats := utils.NewRTTStats()
	rttStats.UpdateRTT(rtt, 0)
	sph := NewSentPacketHandler(
		0,
		1200,
		rttStats,
		&utils.ConnectionStats{},
		true,
		false,
		nil,
		protocol.PerspectiveClient,
		nil,
		utils.DefaultLogger,
	)
	sph.DropPackets(protocol.EncryptionInitial, monotime.Now())
	sph.DropPackets(protocol.EncryptionHandshake, monotime.Now())

	var packets packetTracker
	// sendPacket sends a 1200-byte 1-RTT PING packet, optionally marked as a path probe packet.
	sendPacket := func(t *testing.T, ti monotime.Time, isPathProbe bool) protocol.PacketNumber {
		t.Helper()
		pn := sph.PopPacketNumber(protocol.Encryption1RTT)
		sph.SentPacket(ti, pn, protocol.InvalidPacketNumber, nil, []Frame{packets.NewPingFrame(pn)}, protocol.Encryption1RTT, protocol.ECNNon, 1200, false, isPathProbe)
		return pn
	}

	// send 5 packets: 2 non-probe packets, 1 probe packet, 2 non-probe packets
	now := monotime.Now()
	var pns [5]protocol.PacketNumber
	pns[0] = sendPacket(t, now, false)
	now = now.Add(rtt)
	pns[1] = sendPacket(t, now, false)
	pns[2] = sendPacket(t, now, true)
	pathProbeTimeout := now.Add(pathProbePacketLossTimeout)
	now = now.Add(rtt)
	pns[3] = sendPacket(t, now, false)
	now = now.Add(rtt)
	pns[4] = sendPacket(t, now, false)
	require.Less(t, sph.GetLossDetectionTimeout(), pathProbeTimeout)

	now = now.Add(100 * time.Millisecond)
	// make sure that this ACK doesn't declare the path probe packet lost
	require.Greater(t, pathProbeTimeout, now)
	_, err := sph.ReceivedAck(
		&wire.AckFrame{AckRanges: ackRanges(pns[0], pns[3], pns[4])},
		protocol.Encryption1RTT,
		now,
	)
	require.NoError(t, err)
	require.Equal(t, []protocol.PacketNumber{pns[0], pns[3], pns[4]}, packets.Acked)
	// despite having been sent at the same time, the probe packet was not lost
	require.Equal(t, []protocol.PacketNumber{pns[1]}, packets.Lost)

	// the timeout is now based on the probe packet
	timeout := sph.GetLossDetectionTimeout()
	require.Equal(t, pathProbeTimeout, timeout)
	// the still-outstanding probe packet (pns[2]) doesn't count towards bytes in flight
	require.Zero(t, sph.(*sentPacketHandler).getBytesInFlight())
	pn1 := sendPacket(t, now, false)
	pn2 := sendPacket(t, now, false)
	require.Equal(t, protocol.ByteCount(2400), sph.(*sentPacketHandler).getBytesInFlight())

	// send one more non-probe packet
	pn := sendPacket(t, now, false)
	// the timeout is now based on this packet
	require.Less(t, sph.GetLossDetectionTimeout(), pathProbeTimeout)
	_, err = sph.ReceivedAck(
		&wire.AckFrame{AckRanges: ackRanges(pns[2], pn)},
		protocol.Encryption1RTT,
		now,
	)
	require.NoError(t, err)

	// Migrating the path resets bytes in flight and the RTT estimate,
	// and declares pn1 and pn2 lost.
	packets.Lost = packets.Lost[:0]
	sph.MigratedPath(now, 1200)
	require.Zero(t, sph.(*sentPacketHandler).getBytesInFlight())
	require.Equal(t, utils.DefaultInitialRTT, rttStats.SmoothedRTT())
	require.Equal(t, []protocol.PacketNumber{pn1, pn2}, packets.Lost)
}

// TestSentPacketHandlerPathProbeAckAndLoss checks that a path probe packet that
// was already declared lost (after pathProbePacketLossTimeout) isn't reported as
// acknowledged by a late ACK. It also checks that the timer then moves on to the
// next outstanding probe packet.
func TestSentPacketHandlerPathProbeAckAndLoss(t *testing.T) {
	const rtt = 10 * time.Millisecond // RTT of the original path
	rttStats := utils.NewRTTStats()
	rttStats.UpdateRTT(rtt, 0)
	sph := NewSentPacketHandler(
		0,
		1200,
		rttStats,
		&utils.ConnectionStats{},
		true,
		false,
		nil,
		protocol.PerspectiveClient,
		nil,
		utils.DefaultLogger,
	)
	sph.DropPackets(protocol.EncryptionInitial, monotime.Now())
	sph.DropPackets(protocol.EncryptionHandshake, monotime.Now())

	var packets packetTracker
	// sendPacket sends a 1200-byte 1-RTT PING packet, optionally marked as a path probe packet.
	sendPacket := func(t *testing.T, ti monotime.Time, isPathProbe bool) protocol.PacketNumber {
		t.Helper()
		pn := sph.PopPacketNumber(protocol.Encryption1RTT)
		sph.SentPacket(ti, pn, protocol.InvalidPacketNumber, nil, []Frame{packets.NewPingFrame(pn)}, protocol.Encryption1RTT, protocol.ECNNon, 1200, false, isPathProbe)
		return pn
	}

	// send three path probe packets, 100ms apart
	now := monotime.Now()
	pn1 := sendPacket(t, now, true)
	t1 := now
	now = now.Add(100 * time.Millisecond)
	_ = sendPacket(t, now, true)
	t2 := now
	now = now.Add(100 * time.Millisecond)
	pn3 := sendPacket(t, now, true)

	// the first probe packet is declared lost when its path probe timeout fires
	now = now.Add(100 * time.Millisecond)
	require.Equal(t, t1.Add(pathProbePacketLossTimeout), sph.GetLossDetectionTimeout())
	require.NoError(t, sph.OnLossDetectionTimeout(sph.GetLossDetectionTimeout()))
	require.Equal(t, []protocol.PacketNumber{pn1}, packets.Lost)
	packets.Lost = packets.Lost[:0]

	// Receive a delayed ACK for the path probe packet.
	// pn1 was already declared lost, so only pn3 is reported as acknowledged.
	_, err := sph.ReceivedAck(
		&wire.AckFrame{AckRanges: ackRanges(pn1, pn3)},
		protocol.Encryption1RTT,
		now,
	)
	require.NoError(t, err)
	require.Equal(t, []protocol.PacketNumber{pn3}, packets.Acked)
	require.Empty(t, packets.Lost)
	// the timer is now based on the remaining probe packet (sent at t2)
	require.Equal(t, t2.Add(pathProbePacketLossTimeout), sph.GetLossDetectionTimeout())
}

// The packet tracking logic is pretty complex.
// We test it with a randomized approach, to make sure that it doesn't panic under any circumstances.
func TestSentPacketHandlerRandomized(t *testing.T) { seed := uint64(time.Now().UnixNano()) for i := range 5 { t.Run(fmt.Sprintf("run %d (seed %d)", i+1, seed), func(t *testing.T) { testSentPacketHandlerRandomized(t, seed) }) seed++ } } func testSentPacketHandlerRandomized(t *testing.T, seed uint64) { var b [32]byte binary.BigEndian.PutUint64(b[:], seed) r := rand.New(rand.NewChaCha8(b)) rttStats := utils.NewRTTStats() rtt := []time.Duration{10 * time.Millisecond, 100 * time.Millisecond, 1000 * time.Millisecond}[r.IntN(3)] t.Logf("rtt: %dms", rtt.Milliseconds()) rttStats.UpdateRTT(rtt, 0) // RTT of the original path randDuration := func(min, max time.Duration) time.Duration { return time.Duration(rand.Int64N(int64(max-min))) + min } sph := NewSentPacketHandler( 0, 1200, rttStats, &utils.ConnectionStats{}, true, false, nil, protocol.PerspectiveClient, nil, utils.DefaultLogger, ) sph.DropPackets(protocol.EncryptionInitial, monotime.Now()) sph.DropPackets(protocol.EncryptionHandshake, monotime.Now()) var packets packetTracker sendPacket := func(t *testing.T, ti monotime.Time, isPathProbe bool) protocol.PacketNumber { t.Helper() pn := sph.PopPacketNumber(protocol.Encryption1RTT) sph.SentPacket(ti, pn, protocol.InvalidPacketNumber, nil, []Frame{packets.NewPingFrame(pn)}, protocol.Encryption1RTT, protocol.ECNNon, 1200, false, isPathProbe) return pn } now := monotime.Now() start := now var pns []protocol.PacketNumber for range 4 { isProbe := r.Int()%2 == 0 pn := sendPacket(t, now, isProbe) t.Logf("t=%dms: sending packet %d (probe packet: %t)", now.Sub(start).Milliseconds(), pn, isProbe) pns = append(pns, pn) now = now.Add(randDuration(0, 500*time.Millisecond)) if r.Int()%3 == 0 { sph.OnLossDetectionTimeout(now) t.Logf("t=%dms: loss detection timeout (lost: %v)", now.Sub(start).Milliseconds(), packets.Lost) packets.Reset() now = now.Add(randDuration(0, 500*time.Millisecond)) } if r.Int()%3 == 0 { // acknowledge up to 2 random packet numbers from the pns slice var ackPns 
[]protocol.PacketNumber if len(pns) > 0 { numToAck := min(1+r.IntN(2), len(pns)) for range numToAck { ackPns = append(ackPns, pns[r.IntN(len(pns))]) } } if len(ackPns) > 1 { slices.Sort(ackPns) ackPns = slices.Compact(ackPns) } sph.ReceivedAck(&wire.AckFrame{AckRanges: ackRanges(ackPns...)}, protocol.Encryption1RTT, now) t.Logf("t=%dms: received ACK for packets %v (acked: %v, lost: %v)", now.Sub(start).Milliseconds(), ackPns, packets.Acked, packets.Lost) packets.Reset() now = now.Add(randDuration(0, 500*time.Millisecond)) } if r.Int()%10 == 0 { sph.MigratedPath(now, 1200) now = now.Add(randDuration(0, 500*time.Millisecond)) } } t.Logf("t=%dms: loss detection timeout (lost: %v)", now.Sub(start).Milliseconds(), packets.Lost) sph.OnLossDetectionTimeout(now) } func TestSentPacketHandlerSpuriousLoss(t *testing.T) { const rtt = time.Second var eventRecorder events.Recorder sph := NewSentPacketHandler( 0, 1200, utils.NewRTTStats(), &utils.ConnectionStats{}, true, false, nil, protocol.PerspectiveClient, &eventRecorder, utils.DefaultLogger, ) var packets packetTracker sendPacket := func(t *testing.T, ti monotime.Time) protocol.PacketNumber { t.Helper() pn := sph.PopPacketNumber(protocol.Encryption1RTT) sph.SentPacket(ti, pn, protocol.InvalidPacketNumber, nil, []Frame{packets.NewPingFrame(pn)}, protocol.Encryption1RTT, protocol.ECNNon, 1000, false, false) return pn } start := monotime.Now() now := start var pns []protocol.PacketNumber for range 20 { pns = append(pns, sendPacket(t, now)) now = now.Add(10 * time.Millisecond) } now = start.Add(rtt) _, err := sph.ReceivedAck( &wire.AckFrame{AckRanges: ackRanges(pns[0], pns[6])}, protocol.Encryption1RTT, now, ) require.NoError(t, err) require.Equal(t, []protocol.PacketNumber{pns[0], pns[6]}, packets.Acked) // pns[4] and pns[5] are not yet declared lost require.Equal(t, []protocol.PacketNumber{pns[1], pns[2], pns[3]}, packets.Lost) packets.Reset() eventRecorder.Clear() const secondAckDelay = 50 * time.Millisecond now = 
now.Add(secondAckDelay) _, err = sph.ReceivedAck( &wire.AckFrame{AckRanges: ackRanges(pns[0], pns[1], pns[2], pns[3], pns[4], pns[5], pns[6], pns[12], pns[16])}, protocol.Encryption1RTT, now, ) require.NoError(t, err) require.Equal(t, []protocol.PacketNumber{pns[4], pns[5], pns[12], pns[16]}, packets.Acked) require.Equal(t, []protocol.PacketNumber{pns[7], pns[8], pns[9], pns[10], pns[11], pns[13]}, packets.Lost) require.Equal(t, []qlogwriter.Event{ qlog.SpuriousLoss{ EncryptionLevel: protocol.Encryption1RTT, PacketNumber: pns[1], PacketReordering: 16 - 1, TimeReordering: rtt + secondAckDelay - 10*time.Millisecond, }, qlog.SpuriousLoss{ EncryptionLevel: protocol.Encryption1RTT, PacketNumber: pns[2], PacketReordering: 16 - 2, TimeReordering: rtt + secondAckDelay - 20*time.Millisecond, }, qlog.SpuriousLoss{ EncryptionLevel: protocol.Encryption1RTT, PacketNumber: pns[3], PacketReordering: 16 - 3, TimeReordering: rtt + secondAckDelay - 30*time.Millisecond, }, }, eventRecorder.Events(qlog.SpuriousLoss{}), ) eventRecorder.Clear() now = now.Add(secondAckDelay) _, err = sph.ReceivedAck( &wire.AckFrame{AckRanges: ackRanges(pns[0], pns[1], pns[2], pns[3], pns[4], pns[5], pns[6], pns[7], pns[8], pns[9], pns[10], pns[16], pns[17], pns[18])}, protocol.Encryption1RTT, now, ) require.NoError(t, err) require.Equal(t, []protocol.PacketNumber{pns[4], pns[5], pns[12], pns[16], pns[17], pns[18]}, packets.Acked) require.Equal(t, []protocol.PacketNumber{pns[7], pns[8], pns[9], pns[10], pns[11], pns[13], pns[14], pns[15]}, packets.Lost) require.Equal(t, []qlogwriter.Event{ qlog.SpuriousLoss{ EncryptionLevel: protocol.Encryption1RTT, PacketNumber: pns[7], PacketReordering: 18 - 7, TimeReordering: rtt + 2*secondAckDelay - 70*time.Millisecond, }, qlog.SpuriousLoss{ EncryptionLevel: protocol.Encryption1RTT, PacketNumber: pns[8], PacketReordering: 18 - 8, TimeReordering: rtt + 2*secondAckDelay - 80*time.Millisecond, }, qlog.SpuriousLoss{ EncryptionLevel: protocol.Encryption1RTT, PacketNumber: 
pns[9], PacketReordering: 18 - 9, TimeReordering: rtt + 2*secondAckDelay - 90*time.Millisecond, }, qlog.SpuriousLoss{ EncryptionLevel: protocol.Encryption1RTT, PacketNumber: pns[10], PacketReordering: 18 - 10, TimeReordering: rtt + 2*secondAckDelay - 100*time.Millisecond, }, }, eventRecorder.Events(qlog.SpuriousLoss{}), ) } func BenchmarkSendAndAcknowledge(b *testing.B) { b.Run("ack every: 2, in flight: 0", func(b *testing.B) { benchmarkSendAndAcknowledge(b, 2, 0) }) b.Run("ack every: 10, in flight: 100", func(b *testing.B) { benchmarkSendAndAcknowledge(b, 10, 100) }) b.Run("ack every: 100, in flight: 1000", func(b *testing.B) { benchmarkSendAndAcknowledge(b, 100, 1000) }) } func benchmarkSendAndAcknowledge(b *testing.B, ackEvery, inFlight int) { b.ReportAllocs() rttStats := utils.NewRTTStats() sph := NewSentPacketHandler( 0, 1200, rttStats, &utils.ConnectionStats{}, true, false, nil, protocol.PerspectiveClient, nil, utils.DefaultLogger, ) now := monotime.Now() sph.DropPackets(protocol.EncryptionInitial, now) sph.DropPackets(protocol.EncryptionHandshake, now) streamFrames := []StreamFrame{{Frame: &wire.StreamFrame{}}} pns := make([]protocol.PacketNumber, 0, ackEvery+inFlight) var counter int ranges := make([]wire.AckRange, 0, ackEvery) for b.Loop() { counter++ pn := sph.PopPacketNumber(protocol.Encryption1RTT) sph.SentPacket( now, pn, protocol.InvalidPacketNumber, streamFrames, nil, protocol.Encryption1RTT, protocol.ECNNon, 1200, false, false, ) now = now.Add(time.Millisecond) pns = append(pns, pn) if counter > inFlight && counter%ackEvery == 0 { sph.ReceivedAck( &wire.AckFrame{AckRanges: appendAckRanges(ranges, pns[:ackEvery]...)}, protocol.Encryption1RTT, now, ) pns = append(pns[:0], pns[ackEvery:]...) 
ranges = ranges[:0] } } } quic-go-0.59.0/internal/ackhandler/sent_packet_history.go000066400000000000000000000160431513066070600235010ustar00rootroot00000000000000package ackhandler import ( "fmt" "iter" "slices" "github.com/quic-go/quic-go/internal/protocol" ) const maxSkippedPackets = 4 type sentPacketHistory struct { packets []*packet pathProbePackets []packetWithPacketNumber skippedPackets []protocol.PacketNumber numOutstanding int firstPacketNumber protocol.PacketNumber highestPacketNumber protocol.PacketNumber } func newSentPacketHistory(isAppData bool) *sentPacketHistory { h := &sentPacketHistory{ highestPacketNumber: protocol.InvalidPacketNumber, firstPacketNumber: protocol.InvalidPacketNumber, } if isAppData { h.packets = make([]*packet, 0, 32) h.skippedPackets = make([]protocol.PacketNumber, 0, maxSkippedPackets) } else { h.packets = make([]*packet, 0, 6) } return h } func (h *sentPacketHistory) checkSequentialPacketNumberUse(pn protocol.PacketNumber) { if h.highestPacketNumber != protocol.InvalidPacketNumber { if pn != h.highestPacketNumber+1 { panic("non-sequential packet number use") } } h.highestPacketNumber = pn if len(h.packets) == 0 { h.firstPacketNumber = pn } } func (h *sentPacketHistory) SkippedPacket(pn protocol.PacketNumber) { h.checkSequentialPacketNumberUse(pn) if len(h.packets) > 0 { h.packets = append(h.packets, nil) } if len(h.skippedPackets) == maxSkippedPackets { h.skippedPackets = slices.Delete(h.skippedPackets, 0, 1) } h.skippedPackets = append(h.skippedPackets, pn) } func (h *sentPacketHistory) SentPacket(pn protocol.PacketNumber, p *packet) { h.checkSequentialPacketNumberUse(pn) h.packets = append(h.packets, p) if p.Outstanding() { h.numOutstanding++ } } func (h *sentPacketHistory) SentPathProbePacket(pn protocol.PacketNumber, p *packet) { h.checkSequentialPacketNumberUse(pn) h.packets = append(h.packets, &packet{isPathProbePacket: true}) h.pathProbePackets = append(h.pathProbePackets, packetWithPacketNumber{PacketNumber: pn, 
packet: p}) } func (h *sentPacketHistory) Packets() iter.Seq2[protocol.PacketNumber, *packet] { return func(yield func(protocol.PacketNumber, *packet) bool) { // h.firstPacketNumber might be updated in the yield function, // so we need to save it here. firstPacketNumber := h.firstPacketNumber for i, p := range h.packets { if p == nil { continue } if !yield(firstPacketNumber+protocol.PacketNumber(i), p) { return } } } } func (h *sentPacketHistory) PathProbes() iter.Seq2[protocol.PacketNumber, *packet] { return func(yield func(protocol.PacketNumber, *packet) bool) { for _, p := range h.pathProbePackets { if !yield(p.PacketNumber, p.packet) { return } } } } // FirstOutstanding returns the first outstanding packet. func (h *sentPacketHistory) FirstOutstanding() (protocol.PacketNumber, *packet) { if !h.HasOutstandingPackets() { return protocol.InvalidPacketNumber, nil } for i, p := range h.packets { if p != nil && p.Outstanding() { return h.firstPacketNumber + protocol.PacketNumber(i), p } } return protocol.InvalidPacketNumber, nil } // FirstOutstandingPathProbe returns the first outstanding path probe packet func (h *sentPacketHistory) FirstOutstandingPathProbe() (protocol.PacketNumber, *packet) { if len(h.pathProbePackets) == 0 { return protocol.InvalidPacketNumber, nil } return h.pathProbePackets[0].PacketNumber, h.pathProbePackets[0].packet } func (h *sentPacketHistory) SkippedPackets() iter.Seq[protocol.PacketNumber] { return func(yield func(protocol.PacketNumber) bool) { for _, p := range h.skippedPackets { if !yield(p) { return } } } } func (h *sentPacketHistory) Len() int { return len(h.packets) } func (h *sentPacketHistory) NumOutstanding() int { return h.numOutstanding } // Remove removes a packet from the sent packet history. // It must not be used for skipped packet numbers. 
func (h *sentPacketHistory) Remove(pn protocol.PacketNumber) error { idx, ok := h.getIndex(pn) if !ok { return fmt.Errorf("packet %d not found in sent packet history", pn) } p := h.packets[idx] if p.Outstanding() { h.numOutstanding-- if h.numOutstanding < 0 { panic("negative number of outstanding packets") } } h.packets[idx] = nil // clean up all skipped packets directly before this packet number var hasPacketBefore bool for idx > 0 { idx-- if h.packets[idx] != nil { hasPacketBefore = true break } } if !hasPacketBefore { h.cleanupStart() } if len(h.packets) > 0 && h.packets[0] == nil { panic("cleanup failed") } return nil } // RemovePathProbe removes a path probe packet. // It scales O(N), but that's ok, since we don't expect to send many path probe packets. // It is not valid to call this function in IteratePathProbes. func (h *sentPacketHistory) RemovePathProbe(pn protocol.PacketNumber) *packet { var packetToDelete *packet idx := -1 for i, p := range h.pathProbePackets { if p.PacketNumber == pn { packetToDelete = p.packet idx = i break } } if idx != -1 { // don't use slices.Delete, because it zeros the deleted element copy(h.pathProbePackets[idx:], h.pathProbePackets[idx+1:]) h.pathProbePackets = h.pathProbePackets[:len(h.pathProbePackets)-1] } return packetToDelete } // getIndex gets the index of packet p in the packets slice. 
func (h *sentPacketHistory) getIndex(p protocol.PacketNumber) (int, bool) { if len(h.packets) == 0 { return 0, false } if p < h.firstPacketNumber { return 0, false } index := int(p - h.firstPacketNumber) if index > len(h.packets)-1 { return 0, false } return index, true } func (h *sentPacketHistory) HasOutstandingPackets() bool { return h.numOutstanding > 0 } func (h *sentPacketHistory) HasOutstandingPathProbes() bool { return len(h.pathProbePackets) > 0 } // delete all nil entries at the beginning of the packets slice func (h *sentPacketHistory) cleanupStart() { for i, p := range h.packets { if p != nil { h.packets = h.packets[i:] h.firstPacketNumber += protocol.PacketNumber(i) return } } h.packets = h.packets[:0] h.firstPacketNumber = protocol.InvalidPacketNumber } func (h *sentPacketHistory) LowestPacketNumber() protocol.PacketNumber { if len(h.packets) == 0 { return protocol.InvalidPacketNumber } return h.firstPacketNumber } func (h *sentPacketHistory) DeclareLost(pn protocol.PacketNumber) { idx, ok := h.getIndex(pn) if !ok { return } p := h.packets[idx] if p.Outstanding() { h.numOutstanding-- if h.numOutstanding < 0 { panic("negative number of outstanding packets") } } h.packets[idx] = nil if idx == 0 { h.cleanupStart() } } // Difference returns the difference between two packet numbers a and b (a - b), // taking into account any skipped packet numbers between them. // // Note that old skipped packets are garbage collected at some point, // so this function is not guaranteed to return the correct result after a while. 
func (h *sentPacketHistory) Difference(a, b protocol.PacketNumber) protocol.PacketNumber { diff := a - b if len(h.skippedPackets) == 0 { return diff } if a < h.skippedPackets[0] || b > h.skippedPackets[len(h.skippedPackets)-1] { return diff } for _, p := range h.skippedPackets { if p > b && p < a { diff-- } } return diff } quic-go-0.59.0/internal/ackhandler/sent_packet_history_test.go000066400000000000000000000270741513066070600245460ustar00rootroot00000000000000package ackhandler import ( "slices" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/require" ) func ackElicitingPacket() *packet { return &packet{StreamFrames: []StreamFrame{{Frame: &wire.StreamFrame{StreamID: 1}}}} } func (h *sentPacketHistory) getPacketNumbers() []protocol.PacketNumber { pns := make([]protocol.PacketNumber, 0, len(h.packets)) for pn := range h.Packets() { pns = append(pns, pn) } return pns } func TestSentPacketHistoryPacketTracking(t *testing.T) { t.Run("first packet ack-eliciting", func(t *testing.T) { testSentPacketHistoryPacketTracking(t, true) }) t.Run("first packet non-ack-eliciting", func(t *testing.T) { testSentPacketHistoryPacketTracking(t, false) }) } func testSentPacketHistoryPacketTracking(t *testing.T, firstPacketAckEliciting bool) { hist := newSentPacketHistory(true) require.False(t, hist.HasOutstandingPackets()) if firstPacketAckEliciting { hist.SentPacket(0, ackElicitingPacket()) require.True(t, hist.HasOutstandingPackets()) } else { hist.SentPacket(0, &packet{}) require.False(t, hist.HasOutstandingPackets()) } hist.SentPacket(1, ackElicitingPacket()) hist.SentPacket(2, ackElicitingPacket()) require.Equal(t, []protocol.PacketNumber{0, 1, 2}, hist.getPacketNumbers()) require.Empty(t, slices.Collect(hist.SkippedPackets())) require.Equal(t, 3, hist.Len()) if firstPacketAckEliciting { require.Equal(t, 3, hist.NumOutstanding()) } else { require.Equal(t, 2, hist.NumOutstanding()) } // 
non-ack-eliciting packets are saved, but don't count as outstanding hist.SentPacket(3, &packet{}) hist.SentPacket(4, ackElicitingPacket()) hist.SentPacket(5, &packet{}) hist.SentPacket(6, ackElicitingPacket()) require.Equal(t, []protocol.PacketNumber{0, 1, 2, 3, 4, 5, 6}, hist.getPacketNumbers()) if firstPacketAckEliciting { require.Equal(t, 5, hist.NumOutstanding()) } else { require.Equal(t, 4, hist.NumOutstanding()) } // handle skipped packet numbers hist.SkippedPacket(7) hist.SentPacket(8, ackElicitingPacket()) hist.SentPacket(9, &packet{}) hist.SkippedPacket(10) hist.SentPacket(11, ackElicitingPacket()) require.Equal(t, []protocol.PacketNumber{0, 1, 2, 3, 4, 5, 6, 8, 9, 11}, hist.getPacketNumbers()) require.Equal(t, []protocol.PacketNumber{7, 10}, slices.Collect(hist.SkippedPackets())) require.Equal(t, 12, hist.Len()) if firstPacketAckEliciting { require.Equal(t, 7, hist.NumOutstanding()) } else { require.Equal(t, 6, hist.NumOutstanding()) } } func TestSentPacketHistoryNonSequentialPacketNumberUse(t *testing.T) { hist := newSentPacketHistory(true) hist.SentPacket(100, ackElicitingPacket()) require.Panics(t, func() { hist.SentPacket(102, ackElicitingPacket()) }) } func TestSentPacketHistoryRemovePackets(t *testing.T) { hist := newSentPacketHistory(true) hist.SentPacket(0, ackElicitingPacket()) hist.SentPacket(1, ackElicitingPacket()) hist.SkippedPacket(2) hist.SkippedPacket(3) hist.SentPacket(4, ackElicitingPacket()) hist.SkippedPacket(5) hist.SentPacket(6, ackElicitingPacket()) require.Equal(t, []protocol.PacketNumber{0, 1, 4, 6}, hist.getPacketNumbers()) require.Equal(t, []protocol.PacketNumber{2, 3, 5}, slices.Collect(hist.SkippedPackets())) require.NoError(t, hist.Remove(0)) require.Equal(t, []protocol.PacketNumber{2, 3, 5}, slices.Collect(hist.SkippedPackets())) require.NoError(t, hist.Remove(1)) require.Equal(t, []protocol.PacketNumber{4, 6}, hist.getPacketNumbers()) // skipped packets should be preserved require.Equal(t, []protocol.PacketNumber{2, 3, 5}, 
slices.Collect(hist.SkippedPackets())) // add one more packet hist.SentPacket(7, ackElicitingPacket()) require.Equal(t, []protocol.PacketNumber{4, 6, 7}, hist.getPacketNumbers()) // remove last packet and add another require.NoError(t, hist.Remove(7)) hist.SentPacket(8, ackElicitingPacket()) require.Equal(t, []protocol.PacketNumber{4, 6, 8}, hist.getPacketNumbers()) // try to remove non-existent packet err := hist.Remove(9) require.Error(t, err) require.EqualError(t, err, "packet 9 not found in sent packet history") // only the last 4 skipped packets should be preserved hist.SkippedPacket(9) hist.SkippedPacket(10) hist.SentPacket(11, ackElicitingPacket()) hist.SkippedPacket(12) require.Equal(t, []protocol.PacketNumber{5, 9, 10, 12}, slices.Collect(hist.SkippedPackets())) // Remove all packets require.NoError(t, hist.Remove(4)) require.NoError(t, hist.Remove(6)) require.NoError(t, hist.Remove(8)) require.NoError(t, hist.Remove(11)) require.Empty(t, hist.getPacketNumbers()) require.Len(t, slices.Collect(hist.SkippedPackets()), 4) require.False(t, hist.HasOutstandingPackets()) } func TestSentPacketHistoryFirstOutstandingPacket(t *testing.T) { hist := newSentPacketHistory(true) pn, p := hist.FirstOutstanding() require.Equal(t, protocol.InvalidPacketNumber, pn) require.Nil(t, p) hist.SentPacket(2, ackElicitingPacket()) hist.SentPacket(3, ackElicitingPacket()) pn, p = hist.FirstOutstanding() require.Equal(t, protocol.PacketNumber(2), pn) require.NotNil(t, p) // remove the first packet hist.Remove(2) pn, p = hist.FirstOutstanding() require.Equal(t, protocol.PacketNumber(3), pn) require.NotNil(t, p) // Path MTU packets are not regarded as outstanding hist = newSentPacketHistory(true) hist.SentPacket(2, ackElicitingPacket()) hist.SkippedPacket(3) p = ackElicitingPacket() p.IsPathMTUProbePacket = true hist.SentPacket(4, p) pn, p = hist.FirstOutstanding() require.NotNil(t, p) require.Equal(t, protocol.PacketNumber(2), pn) } func TestSentPacketHistoryIterating(t *testing.T) { 
hist := newSentPacketHistory(true) hist.SkippedPacket(0) hist.SentPacket(1, ackElicitingPacket()) hist.SentPacket(2, ackElicitingPacket()) hist.SentPacket(3, ackElicitingPacket()) hist.SkippedPacket(4) hist.SkippedPacket(5) hist.SentPacket(6, ackElicitingPacket()) require.Equal(t, []protocol.PacketNumber{0, 4, 5}, slices.Collect(hist.SkippedPackets())) require.NoError(t, hist.Remove(3)) var packets []protocol.PacketNumber for pn, p := range hist.Packets() { require.NotNil(t, p) packets = append(packets, pn) } require.Equal(t, []protocol.PacketNumber{1, 2, 6}, packets) require.Equal(t, []protocol.PacketNumber{0, 4, 5}, slices.Collect(hist.SkippedPackets())) } func TestSentPacketHistoryDeleteWhileIterating(t *testing.T) { hist := newSentPacketHistory(true) hist.SentPacket(0, ackElicitingPacket()) hist.SentPacket(1, ackElicitingPacket()) hist.SkippedPacket(2) hist.SentPacket(3, ackElicitingPacket()) hist.SkippedPacket(4) hist.SentPacket(5, ackElicitingPacket()) var iterations []protocol.PacketNumber for pn := range hist.Packets() { iterations = append(iterations, pn) switch pn { case 0: require.NoError(t, hist.Remove(0)) case 3: require.NoError(t, hist.Remove(3)) } } require.Equal(t, []protocol.PacketNumber{0, 1, 3, 5}, iterations) require.Equal(t, []protocol.PacketNumber{1, 5}, hist.getPacketNumbers()) require.Equal(t, []protocol.PacketNumber{2, 4}, slices.Collect(hist.SkippedPackets())) } func TestSentPacketHistoryPathProbes(t *testing.T) { hist := newSentPacketHistory(true) hist.SentPacket(0, ackElicitingPacket()) hist.SentPacket(1, ackElicitingPacket()) hist.SentPathProbePacket(2, ackElicitingPacket()) hist.SentPacket(3, ackElicitingPacket()) hist.SentPacket(4, ackElicitingPacket()) hist.SentPathProbePacket(5, ackElicitingPacket()) getPacketsInHistory := func(t *testing.T) []protocol.PacketNumber { t.Helper() var pns []protocol.PacketNumber for pn, p := range hist.Packets() { pns = append(pns, pn) switch pn { case 2, 5: require.True(t, p.isPathProbePacket) 
default: require.False(t, p.isPathProbePacket) } } return pns } getPacketsInPathProbeHistory := func(t *testing.T) []protocol.PacketNumber { t.Helper() var pns []protocol.PacketNumber for pn := range hist.PathProbes() { pns = append(pns, pn) } return pns } require.Equal(t, []protocol.PacketNumber{0, 1, 2, 3, 4, 5}, getPacketsInHistory(t)) require.Equal(t, []protocol.PacketNumber{2, 5}, getPacketsInPathProbeHistory(t)) // Removing packets from the regular packet history might happen before the path probe // is declared lost, as the original path might have a smaller RTT than the path timeout. // Therefore, the path probe packet is not removed from the path probe history. require.NoError(t, hist.Remove(0)) require.NoError(t, hist.Remove(1)) require.NoError(t, hist.Remove(2)) require.NoError(t, hist.Remove(3)) require.Equal(t, []protocol.PacketNumber{4, 5}, getPacketsInHistory(t)) require.Equal(t, []protocol.PacketNumber{2, 5}, getPacketsInPathProbeHistory(t)) require.True(t, hist.HasOutstandingPackets()) require.True(t, hist.HasOutstandingPathProbes()) pn, p := hist.FirstOutstanding() require.Equal(t, protocol.PacketNumber(4), pn) require.NotNil(t, p) pn, p = hist.FirstOutstandingPathProbe() require.NotNil(t, p) require.Equal(t, protocol.PacketNumber(2), pn) hist.RemovePathProbe(2) require.Equal(t, []protocol.PacketNumber{4, 5}, getPacketsInHistory(t)) require.Equal(t, []protocol.PacketNumber{5}, getPacketsInPathProbeHistory(t)) require.True(t, hist.HasOutstandingPathProbes()) pn, p = hist.FirstOutstandingPathProbe() require.NotNil(t, p) require.Equal(t, protocol.PacketNumber(5), pn) hist.RemovePathProbe(5) require.Equal(t, []protocol.PacketNumber{4, 5}, getPacketsInHistory(t)) require.Empty(t, getPacketsInPathProbeHistory(t)) require.True(t, hist.HasOutstandingPackets()) require.False(t, hist.HasOutstandingPathProbes()) pn, p = hist.FirstOutstandingPathProbe() require.Equal(t, protocol.InvalidPacketNumber, pn) require.Nil(t, p) require.NoError(t, hist.Remove(4)) 
require.NoError(t, hist.Remove(5)) require.Empty(t, getPacketsInHistory(t)) require.False(t, hist.HasOutstandingPackets()) pn, p = hist.FirstOutstanding() require.Equal(t, protocol.InvalidPacketNumber, pn) require.Nil(t, p) // path probe packets are considered outstanding hist.SentPathProbePacket(6, ackElicitingPacket()) require.False(t, hist.HasOutstandingPackets()) require.True(t, hist.HasOutstandingPathProbes()) pn, p = hist.FirstOutstandingPathProbe() require.NotNil(t, p) require.Equal(t, protocol.PacketNumber(6), pn) hist.RemovePathProbe(6) require.False(t, hist.HasOutstandingPackets()) pn, p = hist.FirstOutstanding() require.Equal(t, protocol.InvalidPacketNumber, pn) require.Nil(t, p) require.False(t, hist.HasOutstandingPathProbes()) pn, p = hist.FirstOutstandingPathProbe() require.Equal(t, protocol.InvalidPacketNumber, pn) require.Nil(t, p) } func TestSentPacketHistoryDifference(t *testing.T) { hist := newSentPacketHistory(true) hist.SentPacket(0, &packet{}) hist.SentPacket(1, ackElicitingPacket()) hist.SentPacket(2, ackElicitingPacket()) hist.SentPacket(3, ackElicitingPacket()) hist.SkippedPacket(4) hist.SkippedPacket(5) hist.SentPacket(6, ackElicitingPacket()) hist.SentPacket(7, &packet{}) hist.SkippedPacket(8) hist.SentPacket(9, ackElicitingPacket()) require.Zero(t, hist.Difference(1, 1)) require.Zero(t, hist.Difference(2, 2)) require.Zero(t, hist.Difference(7, 7)) require.Equal(t, protocol.PacketNumber(1), hist.Difference(2, 1)) require.Equal(t, protocol.PacketNumber(2), hist.Difference(3, 1)) require.Equal(t, protocol.PacketNumber(3), hist.Difference(4, 1)) require.Equal(t, protocol.PacketNumber(3), hist.Difference(6, 1)) // 4 and 5 were skipped require.Equal(t, protocol.PacketNumber(4), hist.Difference(7, 1)) // 4 and 5 were skipped require.Equal(t, protocol.PacketNumber(3), hist.Difference(7, 2)) // 4 and 5 were skipped require.Equal(t, protocol.PacketNumber(5), hist.Difference(9, 1)) // 4, 5 and 8 were skipped } 
quic-go-0.59.0/internal/congestion/000077500000000000000000000000001513066070600171415ustar00rootroot00000000000000quic-go-0.59.0/internal/congestion/bandwidth.go000066400000000000000000000010511513066070600214310ustar00rootroot00000000000000package congestion import ( "time" "github.com/quic-go/quic-go/internal/protocol" ) // Bandwidth of a connection type Bandwidth uint64 const ( // BitsPerSecond is 1 bit per second BitsPerSecond Bandwidth = 1 // BytesPerSecond is 1 byte per second BytesPerSecond = 8 * BitsPerSecond ) // BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta func BandwidthFromDelta(bytes protocol.ByteCount, delta time.Duration) Bandwidth { return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond } quic-go-0.59.0/internal/congestion/bandwidth_test.go000066400000000000000000000003311513066070600224700ustar00rootroot00000000000000package congestion import ( "testing" "time" "github.com/stretchr/testify/require" ) func TestBandwidthFromDelta(t *testing.T) { require.Equal(t, 1000*BytesPerSecond, BandwidthFromDelta(1, time.Millisecond)) } quic-go-0.59.0/internal/congestion/clock.go000066400000000000000000000006071513066070600205660ustar00rootroot00000000000000package congestion import ( "github.com/quic-go/quic-go/internal/monotime" ) // A Clock returns the current time type Clock interface { Now() monotime.Time } // DefaultClock implements the Clock interface using the Go stdlib clock. 
type DefaultClock struct{} var _ Clock = DefaultClock{} // Now gets the current time func (DefaultClock) Now() monotime.Time { return monotime.Now() } quic-go-0.59.0/internal/congestion/cubic.go000066400000000000000000000204531513066070600205610ustar00rootroot00000000000000package congestion import ( "math" "time" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" ) // This cubic implementation is based on the one found in Chromiums's QUIC // implementation, in the files net/quic/congestion_control/cubic.{hh,cc}. // Constants based on TCP defaults. // The following constants are in 2^10 fractions of a second instead of ms to // allow a 10 shift right to divide. // 1024*1024^3 (first 1024 is from 0.100^3) // where 0.100 is 100 ms which is the scaling round trip time. const ( cubeScale = 40 cubeCongestionWindowScale = 410 cubeFactor = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize // TODO: when re-enabling cubic, make sure to use the actual packet size here maxDatagramSize = protocol.ByteCount(protocol.InitialPacketSize) ) const defaultNumConnections = 1 // Default Cubic backoff factor const beta float32 = 0.7 // Additional backoff factor when loss occurs in the concave part of the Cubic // curve. This additional backoff factor is expected to give up bandwidth to // new concurrent flows and speed up convergence. const betaLastMax float32 = 0.85 // Cubic implements the cubic algorithm from TCP type Cubic struct { clock Clock // Number of connections to simulate. numConnections int // Time when this cycle started, after last loss event. epoch monotime.Time // Max congestion window used just before last loss event. // Note: to improve fairness to other streams an additional back off is // applied to this value if the new value is below our latest value. lastMaxCongestionWindow protocol.ByteCount // Number of acked bytes since the cycle started (epoch). 
ackedBytesCount protocol.ByteCount // TCP Reno equivalent congestion window in packets. estimatedTCPcongestionWindow protocol.ByteCount // Origin point of cubic function. originPointCongestionWindow protocol.ByteCount // Time to origin point of cubic function in 2^10 fractions of a second. timeToOriginPoint uint32 // Last congestion window in packets computed by cubic function. lastTargetCongestionWindow protocol.ByteCount } // NewCubic returns a new Cubic instance func NewCubic(clock Clock) *Cubic { c := &Cubic{ clock: clock, numConnections: defaultNumConnections, } c.Reset() return c } // Reset is called after a timeout to reset the cubic state func (c *Cubic) Reset() { c.epoch = 0 c.lastMaxCongestionWindow = 0 c.ackedBytesCount = 0 c.estimatedTCPcongestionWindow = 0 c.originPointCongestionWindow = 0 c.timeToOriginPoint = 0 c.lastTargetCongestionWindow = 0 } func (c *Cubic) alpha() float32 { // TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that // beta here is a cwnd multiplier, and is equal to 1-beta from the paper. // We derive the equivalent alpha for an N-connection emulation as: b := c.beta() return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b) } func (c *Cubic) beta() float32 { // kNConnectionBeta is the backoff factor after loss for our N-connection // emulation, which emulates the effective backoff of an ensemble of N // TCP-Reno connections on a single loss event. The effective multiplier is // computed as: return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections) } func (c *Cubic) betaLastMax() float32 { // betaLastMax is the additional backoff factor after loss for our // N-connection emulation, which emulates the additional backoff of // an ensemble of N TCP-Reno connections on a single loss event. 
The // effective multiplier is computed as: return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections) } // OnApplicationLimited is called on ack arrival when sender is unable to use // the available congestion window. Resets Cubic state during quiescence. func (c *Cubic) OnApplicationLimited() { // When sender is not using the available congestion window, the window does // not grow. But to be RTT-independent, Cubic assumes that the sender has been // using the entire window during the time since the beginning of the current // "epoch" (the end of the last loss recovery period). Since // application-limited periods break this assumption, we reset the epoch when // in such a period. This reset effectively freezes congestion window growth // through application-limited periods and allows Cubic growth to continue // when the entire window is being used. c.epoch = 0 } // CongestionWindowAfterPacketLoss computes a new congestion window to use after // a loss event. Returns the new congestion window in packets. The new // congestion window is a multiplicative decrease of our current window. func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount { if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow { // We never reached the old max, so assume we are competing with another // flow. Use our extra back off factor to allow the other flow to go up. c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow)) } else { c.lastMaxCongestionWindow = currentCongestionWindow } c.epoch = 0 // Reset time. return protocol.ByteCount(float32(currentCongestionWindow) * c.beta()) } // CongestionWindowAfterAck computes a new congestion window to use after a received ACK. // Returns the new congestion window in packets. The new congestion window // follows a cubic function that depends on the time passed since last // packet loss. 
func (c *Cubic) CongestionWindowAfterAck(
	ackedBytes protocol.ByteCount,
	currentCongestionWindow protocol.ByteCount,
	delayMin time.Duration,
	eventTime monotime.Time,
) protocol.ByteCount {
	c.ackedBytesCount += ackedBytes

	if c.epoch.IsZero() {
		// First ACK after a loss event.
		c.epoch = eventTime            // Start of epoch.
		c.ackedBytesCount = ackedBytes // Reset count.
		// Reset estimatedTCPcongestionWindow to be in sync with cubic.
		c.estimatedTCPcongestionWindow = currentCongestionWindow
		if c.lastMaxCongestionWindow <= currentCongestionWindow {
			// Already at or above the previous max: the curve starts here.
			c.timeToOriginPoint = 0
			c.originPointCongestionWindow = currentCongestionWindow
		} else {
			// Below the previous max: the curve's origin (inflection point) is
			// the previous max, reached after cbrt(cubeFactor * deficit).
			c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
			c.originPointCongestionWindow = c.lastMaxCongestionWindow
		}
	}

	// Change the time unit from microseconds to 2^10 fractions per second. Take
	// the round trip time in account. This is done to allow us to use shift as a
	// divide operator.
	elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)

	// Right-shifts of negative, signed numbers have implementation-dependent
	// behavior, so force the offset to be positive, as is done in the kernel.
	offset := int64(c.timeToOriginPoint) - elapsedTime
	if offset < 0 {
		offset = -offset
	}

	// Fixed-point evaluation of the cubic term |t - K|^3, scaled to bytes.
	deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale
	var targetCongestionWindow protocol.ByteCount
	if elapsedTime > int64(c.timeToOriginPoint) {
		targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
	} else {
		targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
	}
	// Limit the CWND increase to half the acked bytes.
	targetCongestionWindow = min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)

	// Increase the window by approximately Alpha * 1 MSS of bytes every
	// time we ack an estimated tcp window of bytes.  For small
	// congestion windows (less than 25), the formula below will
	// increase slightly slower than linearly per estimated tcp window
	// of bytes.
	c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow))
	c.ackedBytesCount = 0

	// We have a new cubic congestion window.
	c.lastTargetCongestionWindow = targetCongestionWindow

	// Compute target congestion_window based on cubic target and estimated TCP
	// congestion_window, use highest (fastest).
	if targetCongestionWindow < c.estimatedTCPcongestionWindow {
		targetCongestionWindow = c.estimatedTCPcongestionWindow
	}
	return targetCongestionWindow
}

// SetNumConnections sets the number of emulated connections, which feeds
// into alpha(), beta() and betaLastMax().
func (c *Cubic) SetNumConnections(n int) {
	c.numConnections = n
}
quic-go-0.59.0/internal/congestion/cubic_sender.go000066400000000000000000000240451513066070600221220ustar00rootroot00000000000000package congestion

import (
	"fmt"

	"github.com/quic-go/quic-go/internal/monotime"
	"github.com/quic-go/quic-go/internal/protocol"
	"github.com/quic-go/quic-go/internal/utils"
	"github.com/quic-go/quic-go/qlog"
	"github.com/quic-go/quic-go/qlogwriter"
)

const (
	// initialMaxDatagramSize is protocol.InitialPacketSize expressed as a
	// ByteCount: the datagram size assumed before SetMaxDatagramSize raises it.
	// Used in QUIC for congestion window computations in bytes.
	initialMaxDatagramSize = protocol.ByteCount(protocol.InitialPacketSize)
	// Number of datagrams of headroom below the cwnd that still counts as
	// "cwnd limited" (see isCwndLimited).
	maxBurstPackets = 3
	renoBeta        = 0.7 // Reno backoff factor.
	// Floor for the congestion window, in datagrams.
	minCongestionWindowPackets = 2
	// Initial congestion window, in datagrams (see NewCubicSender).
	initialCongestionWindow = 32
)

type cubicSender struct {
	hybridSlowStart HybridSlowStart
	rttStats        *utils.RTTStats
	connStats       *utils.ConnectionStats
	cubic           *Cubic
	pacer           *pacer
	clock           Clock

	// If true, use Reno congestion avoidance instead of CUBIC.
	reno bool

	// Track the largest packet that has been sent.
	largestSentPacketNumber protocol.PacketNumber

	// Track the largest packet that has been acked.
	largestAckedPacketNumber protocol.PacketNumber

	// Track the largest packet number outstanding when a CWND cutback occurs.
	largestSentAtLastCutback protocol.PacketNumber

	// Whether the last loss event caused us to exit slowstart.
	// Used for stats collection of slowstartPacketsLost
	lastCutbackExitedSlowstart bool

	// Congestion window in bytes.
	congestionWindow protocol.ByteCount

	// Slow start congestion window in bytes, aka ssthresh.
	slowStartThreshold protocol.ByteCount

	// ACK counter for the Reno implementation.
	numAckedPackets uint64

	// Values restored by OnConnectionMigration.
	initialCongestionWindow    protocol.ByteCount
	initialMaxCongestionWindow protocol.ByteCount

	// Current maximum datagram size; only ever increases (see SetMaxDatagramSize).
	maxDatagramSize protocol.ByteCount

	// Last congestion state reported to qlogger.
	lastState qlog.CongestionState
	qlogger   qlogwriter.Recorder
}

var (
	_ SendAlgorithm               = &cubicSender{}
	_ SendAlgorithmWithDebugInfos = &cubicSender{}
)

// NewCubicSender makes a new cubic sender.
// If reno is true, Reno congestion avoidance is used instead of CUBIC.
// The initial congestion window is initialCongestionWindow datagrams of
// initialMaxDatagramSize bytes. The initial max window is
// protocol.MaxCongestionWindowPackets such datagrams.
// qlogger may be nil.
func NewCubicSender(
	clock Clock,
	rttStats *utils.RTTStats,
	connStats *utils.ConnectionStats,
	initialMaxDatagramSize protocol.ByteCount,
	reno bool,
	qlogger qlogwriter.Recorder,
) *cubicSender {
	return newCubicSender(
		clock,
		rttStats,
		connStats,
		reno,
		initialMaxDatagramSize,
		initialCongestionWindow*initialMaxDatagramSize,
		protocol.MaxCongestionWindowPackets*initialMaxDatagramSize,
		qlogger,
	)
}

// newCubicSender builds a sender with explicit initial window sizes (in bytes).
// The sender starts in slow start, with an unbounded ssthresh
// (protocol.MaxByteCount). If a qlogger is given, the initial slow-start
// state is recorded.
func newCubicSender(
	clock Clock,
	rttStats *utils.RTTStats,
	connStats *utils.ConnectionStats,
	reno bool,
	initialMaxDatagramSize,
	initialCongestionWindow,
	initialMaxCongestionWindow protocol.ByteCount,
	qlogger qlogwriter.Recorder,
) *cubicSender {
	c := &cubicSender{
		rttStats:                   rttStats,
		connStats:                  connStats,
		largestSentPacketNumber:    protocol.InvalidPacketNumber,
		largestAckedPacketNumber:   protocol.InvalidPacketNumber,
		largestSentAtLastCutback:   protocol.InvalidPacketNumber,
		initialCongestionWindow:    initialCongestionWindow,
		initialMaxCongestionWindow: initialMaxCongestionWindow,
		congestionWindow:           initialCongestionWindow,
		slowStartThreshold:         protocol.MaxByteCount,
		cubic:                      NewCubic(clock),
		clock:                      clock,
		reno:                       reno,
		qlogger:                    qlogger,
		maxDatagramSize:            initialMaxDatagramSize,
	}
	c.pacer = newPacer(c.BandwidthEstimate)
	if c.qlogger != nil {
		c.lastState = qlog.CongestionStateSlowStart
		c.qlogger.RecordEvent(qlog.CongestionStateUpdated{
			State: qlog.CongestionStateSlowStart,
		})
	}
	return c
}

// TimeUntilSend returns when the next packet should be sent.
func (c *cubicSender) TimeUntilSend(_ protocol.ByteCount) monotime.Time {
	return c.pacer.TimeUntilSend()
}

// HasPacingBudget reports whether the pacer allows sending at least one
// full-size datagram at time now.
func (c *cubicSender) HasPacingBudget(now monotime.Time) bool {
	return c.pacer.Budget(now) >= c.maxDatagramSize
}

// maxCongestionWindow is the upper bound for the congestion window, in bytes.
func (c *cubicSender) maxCongestionWindow() protocol.ByteCount {
	return c.maxDatagramSize * protocol.MaxCongestionWindowPackets
}

// minCongestionWindow is the lower bound for the congestion window, in bytes.
func (c *cubicSender) minCongestionWindow() protocol.ByteCount {
	return c.maxDatagramSize * minCongestionWindowPackets
}

// OnPacketSent informs the pacer of every sent packet. Only retransmittable
// packets update the largest sent packet number and hybrid slow start.
func (c *cubicSender) OnPacketSent(
	sentTime monotime.Time,
	_ protocol.ByteCount,
	packetNumber protocol.PacketNumber,
	bytes protocol.ByteCount,
	isRetransmittable bool,
) {
	c.pacer.SentPacket(sentTime, bytes)
	if !isRetransmittable {
		return
	}
	c.largestSentPacketNumber = packetNumber
	c.hybridSlowStart.OnPacketSent(packetNumber)
}

// CanSend reports whether bytesInFlight is below the congestion window.
func (c *cubicSender) CanSend(bytesInFlight protocol.ByteCount) bool {
	return bytesInFlight < c.GetCongestionWindow()
}

// InRecovery reports whether the largest acked packet was sent no later than
// the largest packet outstanding at the last cutback. This is never true
// before the first ACK.
func (c *cubicSender) InRecovery() bool {
	return c.largestAckedPacketNumber != protocol.InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback
}

// InSlowStart reports whether the congestion window is below ssthresh.
func (c *cubicSender) InSlowStart() bool {
	return c.GetCongestionWindow() < c.slowStartThreshold
}

// GetCongestionWindow returns the current congestion window in bytes.
func (c *cubicSender) GetCongestionWindow() protocol.ByteCount {
	return c.congestionWindow
}

// MaybeExitSlowStart asks the hybrid slow start detector whether to leave slow
// start, based on the latest and minimum RTT. If so, it sets ssthresh to the
// current congestion window.
func (c *cubicSender) MaybeExitSlowStart() {
	if c.InSlowStart() &&
		c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) {
		// exit slow start
		c.slowStartThreshold = c.congestionWindow
		c.maybeQlogStateChange(qlog.CongestionStateCongestionAvoidance)
	}
}

// OnPacketAcked records the largest acked packet number and, unless the
// sender is in recovery, possibly grows the congestion window.
func (c *cubicSender) OnPacketAcked(
	ackedPacketNumber protocol.PacketNumber,
	ackedBytes protocol.ByteCount,
	priorInFlight protocol.ByteCount,
	eventTime monotime.Time,
) {
	c.largestAckedPacketNumber = max(ackedPacketNumber, c.largestAckedPacketNumber)
	if c.InRecovery() {
		return
	}
	c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
	if c.InSlowStart() {
		c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
	}
}

// OnCongestionEvent records a lost packet in the connection stats. For the
// first loss of a loss event, it also cuts the congestion window (Reno or
// CUBIC backoff, floored at the minimum window) and sets ssthresh to the
// new window.
func (c *cubicSender) OnCongestionEvent(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) {
	c.connStats.PacketsLost.Add(1)
	c.connStats.BytesLost.Add(uint64(lostBytes))

	// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
	// already sent should be treated as a single loss event, since it's expected.
	if packetNumber <= c.largestSentAtLastCutback {
		return
	}
	c.lastCutbackExitedSlowstart = c.InSlowStart()
	c.maybeQlogStateChange(qlog.CongestionStateRecovery)

	if c.reno {
		c.congestionWindow = protocol.ByteCount(float64(c.congestionWindow) * renoBeta)
	} else {
		c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
	}
	if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd {
		c.congestionWindow = minCwnd
	}
	c.slowStartThreshold = c.congestionWindow
	c.largestSentAtLastCutback = c.largestSentPacketNumber
	// reset packet count from congestion avoidance mode. We start
	// counting again when we're out of recovery.
	c.numAckedPackets = 0
}

// maybeIncreaseCwnd is called when we receive an ack. Normal TCP tracks how many packets one ack
// represents, but quic has a separate ack for each packet.
func (c *cubicSender) maybeIncreaseCwnd(
	_ protocol.PacketNumber,
	ackedBytes protocol.ByteCount,
	priorInFlight protocol.ByteCount,
	eventTime monotime.Time,
) {
	// Do not increase the congestion window unless the sender is close to using
	// the current window.
	if !c.isCwndLimited(priorInFlight) {
		c.cubic.OnApplicationLimited()
		c.maybeQlogStateChange(qlog.CongestionStateApplicationLimited)
		return
	}
	if c.congestionWindow >= c.maxCongestionWindow() {
		return
	}
	if c.InSlowStart() {
		// TCP slow start, exponential growth, increase by one datagram for each ACK.
		c.congestionWindow += c.maxDatagramSize
		c.maybeQlogStateChange(qlog.CongestionStateSlowStart)
		return
	}
	// Congestion avoidance
	c.maybeQlogStateChange(qlog.CongestionStateCongestionAvoidance)
	if c.reno {
		// Classic Reno congestion avoidance.
		c.numAckedPackets++
		if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) {
			c.congestionWindow += c.maxDatagramSize
			c.numAckedPackets = 0
		}
	} else {
		c.congestionWindow = min(
			c.maxCongestionWindow(),
			c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime),
		)
	}
}

// isCwndLimited reports whether bytesInFlight is close enough to the
// congestion window to justify growing it. That is the case when the window
// is full, when in slow start with more than half the window in flight, or
// when at most maxBurstPackets datagrams of room are left.
func (c *cubicSender) isCwndLimited(bytesInFlight protocol.ByteCount) bool {
	congestionWindow := c.GetCongestionWindow()
	if bytesInFlight >= congestionWindow {
		return true
	}
	availableBytes := congestionWindow - bytesInFlight
	slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
	return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize
}

// BandwidthEstimate returns the current bandwidth estimate: one congestion
// window per smoothed RTT.
func (c *cubicSender) BandwidthEstimate() Bandwidth {
	srtt := c.rttStats.SmoothedRTT()
	if srtt == 0 {
		// This should never happen, but if it does, avoid division by zero.
		srtt = protocol.TimerGranularity
	}
	return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
}

// OnRetransmissionTimeout is called on a retransmission timeout. It always
// ends the current recovery period. If packets were retransmitted, it also:
//   - restarts slow start,
//   - sets ssthresh to half the current window,
//   - drops the window to the minimum.
func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
	c.largestSentAtLastCutback = protocol.InvalidPacketNumber
	if !packetsRetransmitted {
		return
	}
	c.hybridSlowStart.Restart()
	c.cubic.Reset()
	c.slowStartThreshold = c.congestionWindow / 2
	c.congestionWindow = c.minCongestionWindow()
}

// OnConnectionMigration is called when the connection is migrated. It resets
// the sender to its initial state: packet number tracking, slow start, cubic
// state, and the initial congestion window and ssthresh.
func (c *cubicSender) OnConnectionMigration() { c.hybridSlowStart.Restart() c.largestSentPacketNumber = protocol.InvalidPacketNumber c.largestAckedPacketNumber = protocol.InvalidPacketNumber c.largestSentAtLastCutback = protocol.InvalidPacketNumber c.lastCutbackExitedSlowstart = false c.cubic.Reset() c.numAckedPackets = 0 c.congestionWindow = c.initialCongestionWindow c.slowStartThreshold = c.initialMaxCongestionWindow } func (c *cubicSender) maybeQlogStateChange(new qlog.CongestionState) { if c.qlogger == nil || new == c.lastState { return } c.qlogger.RecordEvent(qlog.CongestionStateUpdated{State: new}) c.lastState = new } func (c *cubicSender) SetMaxDatagramSize(s protocol.ByteCount) { if s < c.maxDatagramSize { panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s)) } cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow() c.maxDatagramSize = s if cwndIsMinCwnd { c.congestionWindow = c.minCongestionWindow() } c.pacer.SetMaxDatagramSize(s) } quic-go-0.59.0/internal/congestion/cubic_sender_test.go000066400000000000000000000501701513066070600231570ustar00rootroot00000000000000package congestion import ( "fmt" "testing" "time" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/stretchr/testify/require" ) const ( initialCongestionWindowPackets = 10 defaultWindowTCP = protocol.ByteCount(initialCongestionWindowPackets) * maxDatagramSize ) type mockClock monotime.Time func (c *mockClock) Now() monotime.Time { return monotime.Time(*c) } func (c *mockClock) Advance(d time.Duration) { *c = mockClock(monotime.Time(*c).Add(d)) } const MaxCongestionWindow = 200 * maxDatagramSize type testCubicSender struct { sender *cubicSender clock *mockClock rttStats *utils.RTTStats bytesInFlight protocol.ByteCount packetNumber protocol.PacketNumber ackedPacketNumber protocol.PacketNumber } func newTestCubicSender(cubic bool) 
*testCubicSender { var clock mockClock rttStats := utils.RTTStats{} return &testCubicSender{ clock: &clock, rttStats: &rttStats, packetNumber: 1, sender: newCubicSender( &clock, &rttStats, &utils.ConnectionStats{}, !cubic, protocol.InitialPacketSize, initialCongestionWindowPackets*maxDatagramSize, MaxCongestionWindow, nil, ), } } func (s *testCubicSender) SendAvailableSendWindowLen(packetLength protocol.ByteCount) int { var packetsSent int for s.sender.CanSend(s.bytesInFlight) { s.sender.OnPacketSent(s.clock.Now(), s.bytesInFlight, s.packetNumber, packetLength, true) s.packetNumber++ packetsSent++ s.bytesInFlight += packetLength } return packetsSent } func (s *testCubicSender) AckNPackets(n int) { s.rttStats.UpdateRTT(60*time.Millisecond, 0) s.sender.MaybeExitSlowStart() for range n { s.ackedPacketNumber++ s.sender.OnPacketAcked(s.ackedPacketNumber, maxDatagramSize, s.bytesInFlight, s.clock.Now()) } s.bytesInFlight -= protocol.ByteCount(n) * maxDatagramSize s.clock.Advance(time.Millisecond) } func (s *testCubicSender) LoseNPacketsLen(n int, packetLength protocol.ByteCount) { for range n { s.ackedPacketNumber++ s.sender.OnCongestionEvent(s.ackedPacketNumber, packetLength, s.bytesInFlight) } s.bytesInFlight -= protocol.ByteCount(n) * packetLength } func (s *testCubicSender) LosePacket(number protocol.PacketNumber) { s.sender.OnCongestionEvent(number, maxDatagramSize, s.bytesInFlight) s.bytesInFlight -= maxDatagramSize } func (s *testCubicSender) SendAvailableSendWindow() int { return s.SendAvailableSendWindowLen(maxDatagramSize) } func (s *testCubicSender) LoseNPackets(n int) { s.LoseNPacketsLen(n, maxDatagramSize) } func TestCubicSenderStartup(t *testing.T) { sender := newTestCubicSender(false) // At startup make sure we are at the default. require.Equal(t, defaultWindowTCP, sender.sender.GetCongestionWindow()) // Make sure we can send. 
require.Zero(t, sender.sender.TimeUntilSend(0)) require.True(t, sender.sender.CanSend(sender.bytesInFlight)) // And that window is un-affected. require.Equal(t, defaultWindowTCP, sender.sender.GetCongestionWindow()) // Fill the send window with data, then verify that we can't send. sender.SendAvailableSendWindow() require.False(t, sender.sender.CanSend(sender.bytesInFlight)) } func TestCubicSenderPacing(t *testing.T) { sender := newTestCubicSender(false) // Set up RTT and advance clock sender.rttStats.UpdateRTT(10*time.Millisecond, 0) sender.clock.Advance(time.Hour) // Fill the send window with data, then verify that we can't send. sender.SendAvailableSendWindow() sender.AckNPackets(1) // Check that we can't send immediately due to pacing delay := sender.sender.TimeUntilSend(sender.bytesInFlight) require.NotZero(t, delay) require.Less(t, delay.Sub(monotime.Time(*sender.clock)), time.Hour) } func TestCubicSenderApplicationLimitedSlowStart(t *testing.T) { sender := newTestCubicSender(false) // At startup make sure we can send. require.True(t, sender.sender.CanSend(0)) require.Zero(t, sender.sender.TimeUntilSend(0)) // Send exactly 10 packets and ensure the CWND ends at 14 packets. const numberOfAcks = 5 sender.SendAvailableSendWindow() for range numberOfAcks { sender.AckNPackets(2) } bytesToSend := sender.sender.GetCongestionWindow() // It's expected 2 acks will arrive when the bytes_in_flight are greater than // half the CWND. require.Equal(t, defaultWindowTCP+maxDatagramSize*2*2, bytesToSend) } func TestCubicSenderExponentialSlowStart(t *testing.T) { sender := newTestCubicSender(false) // At startup make sure we can send. require.True(t, sender.sender.CanSend(0)) require.Zero(t, sender.sender.TimeUntilSend(0)) const numberOfAcks = 20 for range numberOfAcks { // Send our full send window. 
sender.SendAvailableSendWindow() sender.AckNPackets(2) } cwnd := sender.sender.GetCongestionWindow() require.Equal(t, defaultWindowTCP+maxDatagramSize*2*numberOfAcks, cwnd) require.Equal(t, BandwidthFromDelta(cwnd, sender.rttStats.SmoothedRTT()), sender.sender.BandwidthEstimate()) } func TestCubicSenderSlowStartPacketLoss(t *testing.T) { sender := newTestCubicSender(false) const numberOfAcks = 10 for range numberOfAcks { // Send our full send window. sender.SendAvailableSendWindow() sender.AckNPackets(2) } sender.SendAvailableSendWindow() expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) // Lose a packet to exit slow start. sender.LoseNPackets(1) packetsInRecoveryWindow := expectedSendWindow / maxDatagramSize // We should now have fallen out of slow start with a reduced window. expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) // Recovery phase. We need to ack every packet in the recovery window before // we exit recovery. numberOfPacketsInWindow := expectedSendWindow / maxDatagramSize sender.AckNPackets(int(packetsInRecoveryWindow)) sender.SendAvailableSendWindow() require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) // We need to ack an entire window before we increase CWND by 1. fmt.Println(numberOfPacketsInWindow) sender.AckNPackets(int(numberOfPacketsInWindow) - 2) sender.SendAvailableSendWindow() fmt.Println(sender.clock.Now()) require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) // Next ack should increase cwnd by 1. sender.AckNPackets(1) expectedSendWindow += maxDatagramSize require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) // Now RTO and ensure slow start gets reset. 
require.True(t, sender.sender.hybridSlowStart.Started()) sender.sender.OnRetransmissionTimeout(true) require.False(t, sender.sender.hybridSlowStart.Started()) } func TestCubicSenderSlowStartPacketLossPRR(t *testing.T) { sender := newTestCubicSender(false) // Test based on the first example in RFC6937. // Ack 10 packets in 5 acks to raise the CWND to 20, as in the example. const numberOfAcks = 5 for range numberOfAcks { // Send our full send window. sender.SendAvailableSendWindow() sender.AckNPackets(2) } sender.SendAvailableSendWindow() expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) sender.LoseNPackets(1) // We should now have fallen out of slow start with a reduced window. sendWindowBeforeLoss := expectedSendWindow expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) // Testing TCP proportional rate reduction. // We should send packets paced over the received acks for the remaining // outstanding packets. The number of packets before we exit recovery is the // original CWND minus the packet that has been lost and the one which // triggered the loss. remainingPacketsInRecovery := sendWindowBeforeLoss/maxDatagramSize - 2 for i := protocol.ByteCount(0); i < remainingPacketsInRecovery; i++ { sender.AckNPackets(1) sender.SendAvailableSendWindow() require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) } // We need to ack another window before we increase CWND by 1. 
numberOfPacketsInWindow := expectedSendWindow / maxDatagramSize for range numberOfPacketsInWindow { sender.AckNPackets(1) require.Equal(t, 1, sender.SendAvailableSendWindow()) require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) } sender.AckNPackets(1) expectedSendWindow += maxDatagramSize require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) } func TestCubicSenderSlowStartBurstPacketLossPRR(t *testing.T) { sender := newTestCubicSender(false) // Test based on the second example in RFC6937, though we also implement // forward acknowledgements, so the first two incoming acks will trigger // PRR immediately. // Ack 20 packets in 10 acks to raise the CWND to 30. const numberOfAcks = 10 for range numberOfAcks { // Send our full send window. sender.SendAvailableSendWindow() sender.AckNPackets(2) } sender.SendAvailableSendWindow() expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) // Lose one more than the congestion window reduction, so that after loss, // bytes_in_flight is lesser than the congestion window. sendWindowAfterLoss := protocol.ByteCount(renoBeta * float32(expectedSendWindow)) numPacketsToLose := (expectedSendWindow-sendWindowAfterLoss)/maxDatagramSize + 1 sender.LoseNPackets(int(numPacketsToLose)) // Immediately after the loss, ensure at least one packet can be sent. // Losses without subsequent acks can occur with timer based loss detection. require.True(t, sender.sender.CanSend(sender.bytesInFlight)) sender.AckNPackets(1) // We should now have fallen out of slow start with a reduced window. expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) // Only 2 packets should be allowed to be sent, per PRR-SSRB require.Equal(t, 2, sender.SendAvailableSendWindow()) // Ack the next packet, which triggers another loss. 
sender.LoseNPackets(1) sender.AckNPackets(1) // Send 2 packets to simulate PRR-SSRB. require.Equal(t, 2, sender.SendAvailableSendWindow()) // Ack the next packet, which triggers another loss. sender.LoseNPackets(1) sender.AckNPackets(1) // Send 2 packets to simulate PRR-SSRB. require.Equal(t, 2, sender.SendAvailableSendWindow()) // Exit recovery and return to sending at the new rate. for range numberOfAcks { sender.AckNPackets(1) require.Equal(t, 1, sender.SendAvailableSendWindow()) } } func TestCubicSenderRTOCongestionWindow(t *testing.T) { sender := newTestCubicSender(false) require.Equal(t, defaultWindowTCP, sender.sender.GetCongestionWindow()) require.Equal(t, protocol.MaxByteCount, sender.sender.slowStartThreshold) // Expect the window to decrease to the minimum once the RTO fires // and slow start threshold to be set to 1/2 of the CWND. sender.sender.OnRetransmissionTimeout(true) require.Equal(t, 2*maxDatagramSize, sender.sender.GetCongestionWindow()) require.Equal(t, 5*maxDatagramSize, sender.sender.slowStartThreshold) } func TestCubicSenderTCPCubicResetEpochOnQuiescence(t *testing.T) { sender := newTestCubicSender(true) const maxCongestionWindow = 50 const maxCongestionWindowBytes = maxCongestionWindow * maxDatagramSize numSent := sender.SendAvailableSendWindow() // Make sure we fall out of slow start. savedCwnd := sender.sender.GetCongestionWindow() sender.LoseNPackets(1) require.Greater(t, savedCwnd, sender.sender.GetCongestionWindow()) // Ack the rest of the outstanding packets to get out of recovery. for i := 1; i < numSent; i++ { sender.AckNPackets(1) } require.Zero(t, sender.bytesInFlight) // Send a new window of data and ack all; cubic growth should occur. 
savedCwnd = sender.sender.GetCongestionWindow() numSent = sender.SendAvailableSendWindow() for range numSent { sender.AckNPackets(1) } require.Less(t, savedCwnd, sender.sender.GetCongestionWindow()) require.Greater(t, maxCongestionWindowBytes, sender.sender.GetCongestionWindow()) require.Zero(t, sender.bytesInFlight) // Quiescent time of 100 seconds sender.clock.Advance(100 * time.Second) // Send new window of data and ack one packet. Cubic epoch should have // been reset; ensure cwnd increase is not dramatic. savedCwnd = sender.sender.GetCongestionWindow() sender.SendAvailableSendWindow() sender.AckNPackets(1) require.InDelta(t, float64(savedCwnd), float64(sender.sender.GetCongestionWindow()), float64(maxDatagramSize)) require.Greater(t, maxCongestionWindowBytes, sender.sender.GetCongestionWindow()) } func TestCubicSenderMultipleLossesInOneWindow(t *testing.T) { sender := newTestCubicSender(false) sender.SendAvailableSendWindow() initialWindow := sender.sender.GetCongestionWindow() sender.LosePacket(sender.ackedPacketNumber + 1) postLossWindow := sender.sender.GetCongestionWindow() require.True(t, initialWindow > postLossWindow) sender.LosePacket(sender.ackedPacketNumber + 3) require.Equal(t, postLossWindow, sender.sender.GetCongestionWindow()) sender.LosePacket(sender.packetNumber - 1) require.Equal(t, postLossWindow, sender.sender.GetCongestionWindow()) // Lose a later packet and ensure the window decreases. sender.LosePacket(sender.packetNumber) require.True(t, postLossWindow > sender.sender.GetCongestionWindow()) } func TestCubicSender1ConnectionCongestionAvoidanceAtEndOfRecovery(t *testing.T) { sender := newTestCubicSender(false) // Ack 10 packets in 5 acks to raise the CWND to 20. const numberOfAcks = 5 for range numberOfAcks { // Send our full send window. 
sender.SendAvailableSendWindow() sender.AckNPackets(2) } sender.SendAvailableSendWindow() expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) sender.LoseNPackets(1) // We should now have fallen out of slow start with a reduced window. expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) // No congestion window growth should occur in recovery phase, i.e., until the // currently outstanding 20 packets are acked. for range 10 { // Send our full send window. sender.SendAvailableSendWindow() require.True(t, sender.sender.InRecovery()) sender.AckNPackets(2) require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) } require.False(t, sender.sender.InRecovery()) // Out of recovery now. Congestion window should not grow during RTT. for i := protocol.ByteCount(0); i < expectedSendWindow/maxDatagramSize-2; i += 2 { // Send our full send window. sender.SendAvailableSendWindow() sender.AckNPackets(2) require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) } // Next ack should cause congestion window to grow by 1MSS. 
sender.SendAvailableSendWindow() sender.AckNPackets(2) expectedSendWindow += maxDatagramSize require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) } func TestCubicSenderNoPRR(t *testing.T) { sender := newTestCubicSender(false) sender.SendAvailableSendWindow() sender.LoseNPackets(9) sender.AckNPackets(1) require.Equal(t, protocol.ByteCount(renoBeta*float32(defaultWindowTCP)), sender.sender.GetCongestionWindow()) windowInPackets := int(renoBeta * float32(defaultWindowTCP) / float32(maxDatagramSize)) numSent := sender.SendAvailableSendWindow() require.Equal(t, windowInPackets, numSent) } func TestCubicSenderResetAfterConnectionMigration(t *testing.T) { sender := newTestCubicSender(false) require.Equal(t, defaultWindowTCP, sender.sender.GetCongestionWindow()) require.Equal(t, protocol.MaxByteCount, sender.sender.slowStartThreshold) // Starts with slow start. const numberOfAcks = 10 for range numberOfAcks { // Send our full send window. sender.SendAvailableSendWindow() sender.AckNPackets(2) } sender.SendAvailableSendWindow() expectedSendWindow := defaultWindowTCP + (maxDatagramSize * 2 * numberOfAcks) require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) // Loses a packet to exit slow start. sender.LoseNPackets(1) // We should now have fallen out of slow start with a reduced window. Slow // start threshold is also updated. expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta) require.Equal(t, expectedSendWindow, sender.sender.GetCongestionWindow()) require.Equal(t, expectedSendWindow, sender.sender.slowStartThreshold) // Resets cwnd and slow start threshold on connection migrations. 
sender.sender.OnConnectionMigration() require.Equal(t, defaultWindowTCP, sender.sender.GetCongestionWindow()) require.Equal(t, MaxCongestionWindow, sender.sender.slowStartThreshold) require.False(t, sender.sender.hybridSlowStart.Started()) } func TestCubicSenderSlowStartsUpToMaximumCongestionWindow(t *testing.T) { var clock mockClock rttStats := utils.RTTStats{} const initialMaxCongestionWindow = protocol.MaxCongestionWindowPackets * initialMaxDatagramSize sender := newCubicSender( &clock, &rttStats, &utils.ConnectionStats{}, true, protocol.InitialPacketSize, initialCongestionWindowPackets*maxDatagramSize, initialMaxCongestionWindow, nil, ) for i := 1; i < protocol.MaxCongestionWindowPackets; i++ { sender.MaybeExitSlowStart() sender.OnPacketAcked(protocol.PacketNumber(i), 1350, sender.GetCongestionWindow(), clock.Now()) } require.Equal(t, initialMaxCongestionWindow, sender.GetCongestionWindow()) } func TestCubicSenderMaximumPacketSizeReduction(t *testing.T) { sender := newTestCubicSender(false) require.Panics(t, func() { sender.sender.SetMaxDatagramSize(initialMaxDatagramSize - 1) }) } func TestCubicSenderSlowStartsPacketSizeIncrease(t *testing.T) { var clock mockClock rttStats := utils.RTTStats{} const initialMaxCongestionWindow = protocol.MaxCongestionWindowPackets * initialMaxDatagramSize sender := newCubicSender( &clock, &rttStats, &utils.ConnectionStats{}, true, protocol.InitialPacketSize, initialCongestionWindowPackets*maxDatagramSize, initialMaxCongestionWindow, nil, ) const packetSize = initialMaxDatagramSize + 100 sender.SetMaxDatagramSize(packetSize) for i := 1; i < protocol.MaxCongestionWindowPackets; i++ { sender.OnPacketAcked(protocol.PacketNumber(i), packetSize, sender.GetCongestionWindow(), clock.Now()) } const maxCwnd = protocol.MaxCongestionWindowPackets * packetSize require.True(t, sender.GetCongestionWindow() > maxCwnd) require.True(t, sender.GetCongestionWindow() <= maxCwnd+packetSize) } func 
TestCubicSenderLimitCwndIncreaseInCongestionAvoidance(t *testing.T) { // Enable Cubic. var clock mockClock rttStats := utils.RTTStats{} sender := newCubicSender( &clock, &rttStats, &utils.ConnectionStats{}, false, protocol.InitialPacketSize, initialCongestionWindowPackets*maxDatagramSize, MaxCongestionWindow, nil, ) testSender := &testCubicSender{ sender: sender, clock: &clock, rttStats: &rttStats, } numSent := testSender.SendAvailableSendWindow() // Make sure we fall out of slow start. savedCwnd := sender.GetCongestionWindow() testSender.LoseNPackets(1) require.Greater(t, savedCwnd, sender.GetCongestionWindow()) // Ack the rest of the outstanding packets to get out of recovery. for i := 1; i < numSent; i++ { testSender.AckNPackets(1) } require.Equal(t, protocol.ByteCount(0), testSender.bytesInFlight) savedCwnd = sender.GetCongestionWindow() testSender.SendAvailableSendWindow() // Ack packets until the CWND increases. for sender.GetCongestionWindow() == savedCwnd { testSender.AckNPackets(1) testSender.SendAvailableSendWindow() } // Bytes in flight may be larger than the CWND if the CWND isn't an exact // multiple of the packet sizes being sent. require.GreaterOrEqual(t, testSender.bytesInFlight, sender.GetCongestionWindow()) savedCwnd = sender.GetCongestionWindow() // Advance time 2 seconds waiting for an ack. clock.Advance(2 * time.Second) // Ack two packets. The CWND should increase by only one packet. 
testSender.AckNPackets(2) require.Equal(t, savedCwnd+maxDatagramSize, sender.GetCongestionWindow()) } quic-go-0.59.0/internal/congestion/cubic_test.go000066400000000000000000000175751513066070600216330ustar00rootroot00000000000000package congestion import ( "math" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) const ( numConnections uint32 = 2 nConnectionBeta float32 = (float32(numConnections) - 1 + beta) / float32(numConnections) nConnectionBetaLastMax float32 = (float32(numConnections) - 1 + betaLastMax) / float32(numConnections) nConnectionAlpha float32 = 3 * float32(numConnections) * float32(numConnections) * (1 - nConnectionBeta) / (1 + nConnectionBeta) maxCubicTimeInterval = 30 * time.Millisecond ) func renoCwnd(currentCwnd protocol.ByteCount) protocol.ByteCount { return currentCwnd + protocol.ByteCount(float32(maxDatagramSize)*nConnectionAlpha*float32(maxDatagramSize)/float32(currentCwnd)) } func cubicConvexCwnd(initialCwnd protocol.ByteCount, rtt, elapsedTime time.Duration) protocol.ByteCount { offset := protocol.ByteCount((elapsedTime+rtt)/time.Microsecond) << 10 / 1000000 deltaCongestionWindow := 410 * offset * offset * offset * maxDatagramSize >> 40 return initialCwnd + deltaCongestionWindow } func TestCubicAboveOriginWithTighterBounds(t *testing.T) { var clock mockClock cubic := NewCubic(&clock) cubic.SetNumConnections(int(numConnections)) // Convex growth. const rttMin = 100 * time.Millisecond const rttMinS = float32(rttMin/time.Millisecond) / 1000.0 currentCwnd := 10 * maxDatagramSize initialCwnd := currentCwnd clock.Advance(time.Millisecond) initialTime := clock.Now() expectedFirstCwnd := renoCwnd(currentCwnd) currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, initialTime) require.Equal(t, expectedFirstCwnd, currentCwnd) // Normal TCP phase. 
// The maximum number of expected reno RTTs can be calculated by // finding the point where the cubic curve and the reno curve meet. maxRenoRtts := int(math.Sqrt(float64(nConnectionAlpha/(0.4*rttMinS*rttMinS*rttMinS))) - 2) for range maxRenoRtts { numAcksThisEpoch := int(float32(currentCwnd/maxDatagramSize) / nConnectionAlpha) initialCwndThisEpoch := currentCwnd for range numAcksThisEpoch { // Call once per ACK. expectedNextCwnd := renoCwnd(currentCwnd) currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) require.Equal(t, expectedNextCwnd, currentCwnd) } cwndChangeThisEpoch := currentCwnd - initialCwndThisEpoch require.InDelta(t, float64(maxDatagramSize), float64(cwndChangeThisEpoch), float64(maxDatagramSize)/2) clock.Advance(100 * time.Millisecond) } for range 54 { maxAcksThisEpoch := currentCwnd / maxDatagramSize interval := time.Duration(100*1000/maxAcksThisEpoch) * time.Microsecond for range int(maxAcksThisEpoch) { clock.Advance(interval) currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime)) require.Equal(t, expectedCwnd, currentCwnd) } } expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime)) currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) require.Equal(t, expectedCwnd, currentCwnd) } func TestCubicAboveOriginWithFineGrainedCubing(t *testing.T) { var clock mockClock cubic := NewCubic(&clock) cubic.SetNumConnections(int(numConnections)) currentCwnd := 1000 * maxDatagramSize initialCwnd := currentCwnd rttMin := 100 * time.Millisecond clock.Advance(time.Millisecond) initialTime := clock.Now() currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) clock.Advance(600 * time.Millisecond) currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) for i 
:= 0; i < 100; i++ { clock.Advance(10 * time.Millisecond) expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime)) nextCwnd := cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) require.Equal(t, expectedCwnd, nextCwnd) require.Greater(t, nextCwnd, currentCwnd) cwndDelta := nextCwnd - currentCwnd require.Less(t, cwndDelta, maxDatagramSize/10) currentCwnd = nextCwnd } } func TestCubicHandlesPerAckUpdates(t *testing.T) { var clock mockClock cubic := NewCubic(&clock) cubic.SetNumConnections(int(numConnections)) initialCwndPackets := 150 currentCwnd := protocol.ByteCount(initialCwndPackets) * maxDatagramSize rttMin := 350 * time.Millisecond clock.Advance(time.Millisecond) rCwnd := renoCwnd(currentCwnd) currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) initialCwnd := currentCwnd maxAcks := int(float32(initialCwndPackets) / nConnectionAlpha) interval := maxCubicTimeInterval / time.Duration(maxAcks+1) clock.Advance(interval) rCwnd = renoCwnd(rCwnd) require.Equal(t, currentCwnd, cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now())) for range maxAcks - 1 { clock.Advance(interval) nextCwnd := cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) rCwnd = renoCwnd(rCwnd) require.Greater(t, nextCwnd, currentCwnd) require.Equal(t, rCwnd, nextCwnd) currentCwnd = nextCwnd } minimumExpectedIncrease := maxDatagramSize * 9 / 10 require.Greater(t, currentCwnd, initialCwnd+minimumExpectedIncrease) } func TestCubicHandlesLossEvents(t *testing.T) { var clock mockClock cubic := NewCubic(&clock) cubic.SetNumConnections(int(numConnections)) rttMin := 100 * time.Millisecond currentCwnd := 422 * maxDatagramSize expectedCwnd := renoCwnd(currentCwnd) clock.Advance(time.Millisecond) require.Equal(t, expectedCwnd, cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now())) preLossCwnd := currentCwnd require.Zero(t, 
cubic.lastMaxCongestionWindow) expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta) require.Equal(t, expectedCwnd, cubic.CongestionWindowAfterPacketLoss(currentCwnd)) require.Equal(t, preLossCwnd, cubic.lastMaxCongestionWindow) currentCwnd = expectedCwnd preLossCwnd = currentCwnd expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta) require.Equal(t, expectedCwnd, cubic.CongestionWindowAfterPacketLoss(currentCwnd)) currentCwnd = expectedCwnd require.Greater(t, preLossCwnd, cubic.lastMaxCongestionWindow) expectedLastMax := protocol.ByteCount(float32(preLossCwnd) * nConnectionBetaLastMax) require.Equal(t, expectedLastMax, cubic.lastMaxCongestionWindow) require.Less(t, expectedCwnd, cubic.lastMaxCongestionWindow) currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) require.Greater(t, cubic.lastMaxCongestionWindow, currentCwnd) currentCwnd = cubic.lastMaxCongestionWindow - 1 preLossCwnd = currentCwnd expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta) require.Equal(t, expectedCwnd, cubic.CongestionWindowAfterPacketLoss(currentCwnd)) expectedLastMax = preLossCwnd require.Equal(t, expectedLastMax, cubic.lastMaxCongestionWindow) } func TestCubicBelowOrigin(t *testing.T) { var clock mockClock cubic := NewCubic(&clock) cubic.SetNumConnections(int(numConnections)) rttMin := 100 * time.Millisecond currentCwnd := 422 * maxDatagramSize expectedCwnd := renoCwnd(currentCwnd) clock.Advance(time.Millisecond) require.Equal(t, expectedCwnd, cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now())) expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta) require.Equal(t, expectedCwnd, cubic.CongestionWindowAfterPacketLoss(currentCwnd)) currentCwnd = expectedCwnd currentCwnd = cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) for range 40 { clock.Advance(100 * time.Millisecond) currentCwnd = 
cubic.CongestionWindowAfterAck(maxDatagramSize, currentCwnd, rttMin, clock.Now()) } expectedCwnd = 553632 * maxDatagramSize / 1460 require.Equal(t, expectedCwnd, currentCwnd) } quic-go-0.59.0/internal/congestion/hybrid_slow_start.go000066400000000000000000000076521513066070600232440ustar00rootroot00000000000000package congestion import ( "time" "github.com/quic-go/quic-go/internal/protocol" ) // Note(pwestin): the magic clamping numbers come from the original code in // tcp_cubic.c. const hybridStartLowWindow = protocol.ByteCount(16) // Number of delay samples for detecting the increase of delay. const hybridStartMinSamples = uint32(8) // Exit slow start if the min rtt has increased by more than 1/8th. const hybridStartDelayFactorExp = 3 // 2^3 = 8 // The original paper specifies 2 and 8ms, but those have changed over time. const ( hybridStartDelayMinThresholdUs = int64(4000) hybridStartDelayMaxThresholdUs = int64(16000) ) // HybridSlowStart implements the TCP hybrid slow start algorithm type HybridSlowStart struct { endPacketNumber protocol.PacketNumber lastSentPacketNumber protocol.PacketNumber started bool currentMinRTT time.Duration rttSampleCount uint32 hystartFound bool } // StartReceiveRound is called for the start of each receive round (burst) in the slow start phase. func (s *HybridSlowStart) StartReceiveRound(lastSent protocol.PacketNumber) { s.endPacketNumber = lastSent s.currentMinRTT = 0 s.rttSampleCount = 0 s.started = true } // IsEndOfRound returns true if this ack is the last packet number of our current slow start round. func (s *HybridSlowStart) IsEndOfRound(ack protocol.PacketNumber) bool { return s.endPacketNumber < ack } // ShouldExitSlowStart should be called on every new ack frame, since a new // RTT measurement can be made then. // rtt: the RTT for this ack packet. // minRTT: is the lowest delay (RTT) we have seen during the session. // congestionWindow: the congestion window in packets. 
func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow protocol.ByteCount) bool { if !s.started { // Time to start the hybrid slow start. s.StartReceiveRound(s.lastSentPacketNumber) } if s.hystartFound { return true } // Second detection parameter - delay increase detection. // Compare the minimum delay (s.currentMinRTT) of the current // burst of packets relative to the minimum delay during the session. // Note: we only look at the first few(8) packets in each burst, since we // only want to compare the lowest RTT of the burst relative to previous // bursts. s.rttSampleCount++ if s.rttSampleCount <= hybridStartMinSamples { if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT { s.currentMinRTT = latestRTT } } // We only need to check this once per round. if s.rttSampleCount == hybridStartMinSamples { // Divide minRTT by 8 to get a rtt increase threshold for exiting. minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp) // Ensure the rtt threshold is never less than 2ms or more than 16ms. minRTTincreaseThresholdUs = min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs) minRTTincreaseThreshold := time.Duration(max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) { s.hystartFound = true } } // Exit from slow start if the cwnd is greater than 16 and // increasing delay is found. return congestionWindow >= hybridStartLowWindow && s.hystartFound } // OnPacketSent is called when a packet was sent func (s *HybridSlowStart) OnPacketSent(packetNumber protocol.PacketNumber) { s.lastSentPacketNumber = packetNumber } // OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end // the round when the final packet of the burst is received and start it on // the next incoming ack. 
func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber protocol.PacketNumber) { if s.IsEndOfRound(ackedPacketNumber) { s.started = false } } // Started returns true if started func (s *HybridSlowStart) Started() bool { return s.started } // Restart the slow start phase func (s *HybridSlowStart) Restart() { s.started = false s.hystartFound = false } quic-go-0.59.0/internal/congestion/hybrid_slow_start_test.go000066400000000000000000000041031513066070600242670ustar00rootroot00000000000000package congestion import ( "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestHybridSlowStartSimpleCase(t *testing.T) { slowStart := HybridSlowStart{} packetNumber := protocol.PacketNumber(1) endPacketNumber := protocol.PacketNumber(3) slowStart.StartReceiveRound(endPacketNumber) packetNumber++ require.False(t, slowStart.IsEndOfRound(packetNumber)) // Test duplicates. require.False(t, slowStart.IsEndOfRound(packetNumber)) packetNumber++ require.False(t, slowStart.IsEndOfRound(packetNumber)) packetNumber++ require.True(t, slowStart.IsEndOfRound(packetNumber)) // Test without a new registered end_packet_number; packetNumber++ require.True(t, slowStart.IsEndOfRound(packetNumber)) endPacketNumber = 20 slowStart.StartReceiveRound(endPacketNumber) for packetNumber < endPacketNumber { packetNumber++ require.False(t, slowStart.IsEndOfRound(packetNumber)) } packetNumber++ require.True(t, slowStart.IsEndOfRound(packetNumber)) } func TestHybridSlowStartWithDelay(t *testing.T) { slowStart := HybridSlowStart{} const rtt = 60 * time.Millisecond // We expect to detect the increase at +1/8 of the RTT; hence at a typical // RTT of 60ms the detection will happen at 67.5 ms. const hybridStartMinSamples = 8 // Number of acks required to trigger. 
endPacketNumber := protocol.PacketNumber(1) endPacketNumber++ slowStart.StartReceiveRound(endPacketNumber) // Will not trigger since our lowest RTT in our burst is the same as the long // term RTT provided. for n := 0; n < hybridStartMinSamples; n++ { require.False(t, slowStart.ShouldExitSlowStart(rtt+time.Duration(n)*time.Millisecond, rtt, 100)) } endPacketNumber++ slowStart.StartReceiveRound(endPacketNumber) for n := 1; n < hybridStartMinSamples; n++ { require.False(t, slowStart.ShouldExitSlowStart(rtt+(time.Duration(n)+10)*time.Millisecond, rtt, 100)) } // Expect to trigger since all packets in this burst was above the long term // RTT provided. require.True(t, slowStart.ShouldExitSlowStart(rtt+10*time.Millisecond, rtt, 100)) } quic-go-0.59.0/internal/congestion/interface.go000066400000000000000000000021221513066070600214250ustar00rootroot00000000000000package congestion import ( "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" ) // A SendAlgorithm performs congestion control type SendAlgorithm interface { TimeUntilSend(bytesInFlight protocol.ByteCount) monotime.Time HasPacingBudget(now monotime.Time) bool OnPacketSent(sentTime monotime.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) CanSend(bytesInFlight protocol.ByteCount) bool MaybeExitSlowStart() OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime monotime.Time) OnCongestionEvent(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount) OnRetransmissionTimeout(packetsRetransmitted bool) SetMaxDatagramSize(protocol.ByteCount) } // A SendAlgorithmWithDebugInfos is a SendAlgorithm that exposes some debug infos type SendAlgorithmWithDebugInfos interface { SendAlgorithm InSlowStart() bool InRecovery() bool GetCongestionWindow() protocol.ByteCount } 
quic-go-0.59.0/internal/congestion/pacer.go000066400000000000000000000064371513066070600205740ustar00rootroot00000000000000package congestion import ( "math" "time" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" ) const maxBurstSizePackets = 10 // The pacer implements a token bucket pacing algorithm. type pacer struct { budgetAtLastSent protocol.ByteCount maxDatagramSize protocol.ByteCount lastSentTime monotime.Time adjustedBandwidth func() uint64 // in bytes/s } func newPacer(getBandwidth func() Bandwidth) *pacer { p := &pacer{ maxDatagramSize: initialMaxDatagramSize, adjustedBandwidth: func() uint64 { // Bandwidth is in bits/s. We need the value in bytes/s. bw := uint64(getBandwidth() / BytesPerSecond) // Use a slightly higher value than the actual measured bandwidth. // RTT variations then won't result in under-utilization of the congestion window. // Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire, // provided the congestion window is fully utilized and acknowledgments arrive at regular intervals. 
return bw * 5 / 4 }, } p.budgetAtLastSent = p.maxBurstSize() return p } func (p *pacer) SentPacket(sendTime monotime.Time, size protocol.ByteCount) { budget := p.Budget(sendTime) if size >= budget { p.budgetAtLastSent = 0 } else { p.budgetAtLastSent = budget - size } p.lastSentTime = sendTime } func (p *pacer) Budget(now monotime.Time) protocol.ByteCount { if p.lastSentTime.IsZero() { return p.maxBurstSize() } delta := now.Sub(p.lastSentTime) var added protocol.ByteCount if delta > 0 { added = p.timeScaledBandwidth(uint64(delta.Nanoseconds())) } budget := p.budgetAtLastSent + added if added > 0 && budget < p.budgetAtLastSent { budget = protocol.MaxByteCount } return min(p.maxBurstSize(), budget) } func (p *pacer) maxBurstSize() protocol.ByteCount { return max( p.timeScaledBandwidth(uint64((protocol.MinPacingDelay + protocol.TimerGranularity).Nanoseconds())), maxBurstSizePackets*p.maxDatagramSize, ) } // timeScaledBandwidth calculates the number of bytes that may be sent within // a given time interval (ns nanoseconds), based on the current bandwidth estimate. // It caps the scaled value to the maximum allowed burst and handles overflows. func (p *pacer) timeScaledBandwidth(ns uint64) protocol.ByteCount { bw := p.adjustedBandwidth() if bw == 0 { return 0 } const nsPerSecond = 1e9 maxBurst := maxBurstSizePackets * p.maxDatagramSize var scaled protocol.ByteCount if ns > math.MaxUint64/bw { scaled = maxBurst } else { scaled = protocol.ByteCount(bw * ns / nsPerSecond) } return scaled } // TimeUntilSend returns when the next packet should be sent. // It returns zero if a packet can be sent immediately. func (p *pacer) TimeUntilSend() monotime.Time { if p.budgetAtLastSent >= p.maxDatagramSize { return 0 } diff := 1e9 * uint64(p.maxDatagramSize-p.budgetAtLastSent) bw := p.adjustedBandwidth() // We might need to round up this value. // Otherwise, we might have a budget (slightly) smaller than the datagram size when the timer expires. 
d := diff / bw // this is effectively a math.Ceil, but using only integer math if diff%bw > 0 { d++ } return p.lastSentTime.Add(max(protocol.MinPacingDelay, time.Duration(d)*time.Nanosecond)) } func (p *pacer) SetMaxDatagramSize(s protocol.ByteCount) { p.maxDatagramSize = s } quic-go-0.59.0/internal/congestion/pacer_test.go000066400000000000000000000120261513066070600216220ustar00rootroot00000000000000package congestion import ( "math" "math/rand/v2" "testing" "time" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestPacerPacing(t *testing.T) { bandwidth := 50 * initialMaxDatagramSize // 50 full-size packets per second p := newPacer(func() Bandwidth { return Bandwidth(bandwidth) * BytesPerSecond * 4 / 5 }) now := monotime.Now() require.Zero(t, p.TimeUntilSend()) budget := p.Budget(now) require.Equal(t, maxBurstSizePackets*initialMaxDatagramSize, budget) // consume the initial budget by sending packets for budget > 0 { require.Zero(t, p.TimeUntilSend()) require.Equal(t, budget, p.Budget(now)) p.SentPacket(now, initialMaxDatagramSize) budget -= initialMaxDatagramSize } // now packets are being paced for range 5 { require.Zero(t, p.Budget(now)) nextPacket := p.TimeUntilSend() require.NotZero(t, nextPacket) require.Equal(t, time.Second/50, nextPacket.Sub(now)) now = nextPacket p.SentPacket(now, initialMaxDatagramSize) } nextPacket := p.TimeUntilSend() require.Equal(t, time.Second/50, nextPacket.Sub(now)) // send this packet a bit later, simulating timer delay p.SentPacket(nextPacket.Add(time.Millisecond), initialMaxDatagramSize) // the next packet should be paced again, without a delay require.Equal(t, time.Second/50, p.TimeUntilSend().Sub(nextPacket)) // now send a half-size packet now = p.TimeUntilSend() p.SentPacket(now, initialMaxDatagramSize/2) require.Equal(t, initialMaxDatagramSize/2, p.Budget(now)) require.Equal(t, time.Second/100, p.TimeUntilSend().Sub(now)) 
p.SentPacket(p.TimeUntilSend(), initialMaxDatagramSize/2) now = p.TimeUntilSend() // budget accumulates if no packets are sent for a while // we should have accumulated budget to send a burst now require.Equal(t, 5*initialMaxDatagramSize, p.Budget(now.Add(4*time.Second/50))) // but the budget is capped at the max burst size require.Equal(t, maxBurstSizePackets*initialMaxDatagramSize, p.Budget(now.Add(time.Hour))) p.SentPacket(now, initialMaxDatagramSize) require.Zero(t, p.Budget(now)) // reduce the bandwidth bandwidth = 10 * initialMaxDatagramSize // 10 full-size packets per second require.Equal(t, time.Second/10, p.TimeUntilSend().Sub(now)) } func TestPacerUpdatePacketSize(t *testing.T) { const bandwidth = 50 * initialMaxDatagramSize // 50 full-size packets per second p := newPacer(func() Bandwidth { return Bandwidth(bandwidth) * BytesPerSecond * 4 / 5 }) // consume the initial budget by sending packets now := monotime.Now() for p.Budget(now) > 0 { p.SentPacket(now, initialMaxDatagramSize) } require.Equal(t, time.Second/50, p.TimeUntilSend().Sub(now)) // Double the packet size. We now need to wait twice as long to send the next packet. const newDatagramSize = 2 * initialMaxDatagramSize p.SetMaxDatagramSize(newDatagramSize) require.Equal(t, 2*time.Second/50, p.TimeUntilSend().Sub(now)) // check that the maximum burst size is updated require.Equal(t, maxBurstSizePackets*newDatagramSize, p.Budget(now.Add(time.Hour))) } func TestPacerFastPacing(t *testing.T) { const bandwidth = 10000 * initialMaxDatagramSize // 10,000 full-size packets per second p := newPacer(func() Bandwidth { return Bandwidth(bandwidth) * BytesPerSecond * 4 / 5 }) // consume the initial budget by sending packets now := monotime.Now() for p.Budget(now) > 0 { p.SentPacket(now, initialMaxDatagramSize) } // If we were pacing by packet, we'd expect the next packet to send in 1/10ms. 
// However, we don't want to arm the pacing timer for less than 1ms, // so we wait for 1ms, and then send 10 packets in a burst. require.Equal(t, time.Millisecond, p.TimeUntilSend().Sub(now)) require.Equal(t, 10*initialMaxDatagramSize, p.Budget(now.Add(time.Millisecond))) now = now.Add(time.Millisecond) for range 10 { require.NotZero(t, p.Budget(now)) p.SentPacket(now, initialMaxDatagramSize) } require.Zero(t, p.Budget(now)) require.Equal(t, time.Millisecond, p.TimeUntilSend().Sub(now)) } func TestPacerNoOverflows(t *testing.T) { p := newPacer(func() Bandwidth { return math.MaxUint64 }) now := monotime.Now() p.SentPacket(now, initialMaxDatagramSize) for range 100000 { require.NotZero(t, p.Budget(now.Add(time.Duration(rand.Int64N(math.MaxInt64))))) } burstCount := 1 for p.Budget(now) > 0 { burstCount++ p.SentPacket(now, initialMaxDatagramSize) } require.Equal(t, maxBurstSizePackets, burstCount) require.Zero(t, p.Budget(now)) next := p.TimeUntilSend() require.Equal(t, next.Sub(now), protocol.MinPacingDelay) require.Greater(t, p.Budget(next), initialMaxDatagramSize) } func BenchmarkPacer(b *testing.B) { const bandwidth = 50 * initialMaxDatagramSize // 50 full-size packets per second p := newPacer(func() Bandwidth { return Bandwidth(bandwidth) * BytesPerSecond * 4 / 5 }) now := monotime.Now() var i int for b.Loop() { i++ for p.Budget(now) > 0 { p.SentPacket(now, initialMaxDatagramSize) } next := p.TimeUntilSend() if i%2 == 0 { now = next } else { now = now.Add(100 * time.Millisecond) } } } quic-go-0.59.0/internal/flowcontrol/000077500000000000000000000000001513066070600173415ustar00rootroot00000000000000quic-go-0.59.0/internal/flowcontrol/base_flow_controller.go000066400000000000000000000074461513066070600241070ustar00rootroot00000000000000package flowcontrol import ( "sync" "time" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" ) type baseFlowController struct { // for sending 
data bytesSent protocol.ByteCount sendWindow protocol.ByteCount lastBlockedAt protocol.ByteCount // for receiving data //nolint:structcheck // The mutex is used both by the stream and the connection flow controller mutex sync.Mutex bytesRead protocol.ByteCount highestReceived protocol.ByteCount receiveWindow protocol.ByteCount receiveWindowSize protocol.ByteCount maxReceiveWindowSize protocol.ByteCount allowWindowIncrease func(size protocol.ByteCount) bool epochStartTime monotime.Time epochStartOffset protocol.ByteCount rttStats *utils.RTTStats logger utils.Logger } // IsNewlyBlocked says if it is newly blocked by flow control. // For every offset, it only returns true once. // If it is blocked, the offset is returned. func (c *baseFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { if c.SendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt { return false, 0 } c.lastBlockedAt = c.sendWindow return true, c.sendWindow } func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) { c.bytesSent += n } // UpdateSendWindow is called after receiving a MAX_{STREAM_}DATA frame. 
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) (updated bool) { if offset > c.sendWindow { c.sendWindow = offset return true } return false } func (c *baseFlowController) SendWindowSize() protocol.ByteCount { // this only happens during connection establishment, when data is sent before we receive the peer's transport parameters if c.bytesSent > c.sendWindow { return 0 } return c.sendWindow - c.bytesSent } // needs to be called with locked mutex func (c *baseFlowController) addBytesRead(n protocol.ByteCount) { c.bytesRead += n } func (c *baseFlowController) hasWindowUpdate() bool { bytesRemaining := c.receiveWindow - c.bytesRead // update the window when more than the threshold was consumed return bytesRemaining <= protocol.ByteCount(float64(c.receiveWindowSize)*(1-protocol.WindowUpdateThreshold)) } // getWindowUpdate updates the receive window, if necessary // it returns the new offset func (c *baseFlowController) getWindowUpdate(now monotime.Time) protocol.ByteCount { if !c.hasWindowUpdate() { return 0 } c.maybeAdjustWindowSize(now) c.receiveWindow = c.bytesRead + c.receiveWindowSize return c.receiveWindow } // maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often. // For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing. 
func (c *baseFlowController) maybeAdjustWindowSize(now monotime.Time) { bytesReadInEpoch := c.bytesRead - c.epochStartOffset // don't do anything if less than half the window has been consumed if bytesReadInEpoch <= c.receiveWindowSize/2 { return } rtt := c.rttStats.SmoothedRTT() if rtt == 0 { return } fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize) if now.Sub(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) { // window is consumed too fast, try to increase the window size newSize := min(2*c.receiveWindowSize, c.maxReceiveWindowSize) if newSize > c.receiveWindowSize && (c.allowWindowIncrease == nil || c.allowWindowIncrease(newSize-c.receiveWindowSize)) { c.receiveWindowSize = newSize } } c.startNewAutoTuningEpoch(now) } func (c *baseFlowController) startNewAutoTuningEpoch(now monotime.Time) { c.epochStartTime = now c.epochStartOffset = c.bytesRead } func (c *baseFlowController) checkFlowControlViolation() bool { return c.highestReceived > c.receiveWindow } quic-go-0.59.0/internal/flowcontrol/connection_flow_controller.go000066400000000000000000000071421513066070600253250ustar00rootroot00000000000000package flowcontrol import ( "errors" "fmt" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" ) type connectionFlowController struct { baseFlowController } var _ ConnectionFlowController = &connectionFlowController{} // NewConnectionFlowController gets a new flow controller for the connection // It is created before we receive the peer's transport parameters, thus it starts with a sendWindow of 0. 
func NewConnectionFlowController( receiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount, allowWindowIncrease func(size protocol.ByteCount) bool, rttStats *utils.RTTStats, logger utils.Logger, ) *connectionFlowController { return &connectionFlowController{ baseFlowController: baseFlowController{ rttStats: rttStats, receiveWindow: receiveWindow, receiveWindowSize: receiveWindow, maxReceiveWindowSize: maxReceiveWindow, allowWindowIncrease: allowWindowIncrease, logger: logger, }, } } // IncrementHighestReceived adds an increment to the highestReceived value func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount, now monotime.Time) error { c.mutex.Lock() defer c.mutex.Unlock() // If this is the first frame received on this connection, start flow-control auto-tuning. if c.highestReceived == 0 { c.startNewAutoTuningEpoch(now) } c.highestReceived += increment if c.checkFlowControlViolation() { return &qerr.TransportError{ ErrorCode: qerr.FlowControlError, ErrorMessage: fmt.Sprintf("received %d bytes for the connection, allowed %d bytes", c.highestReceived, c.receiveWindow), } } return nil } func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) (hasWindowUpdate bool) { c.mutex.Lock() defer c.mutex.Unlock() c.addBytesRead(n) return c.hasWindowUpdate() } func (c *connectionFlowController) GetWindowUpdate(now monotime.Time) protocol.ByteCount { c.mutex.Lock() defer c.mutex.Unlock() oldWindowSize := c.receiveWindowSize offset := c.getWindowUpdate(now) if c.logger.Debug() && oldWindowSize < c.receiveWindowSize { c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10)) } return offset } // EnsureMinimumWindowSize sets a minimum window size // it should make sure that the connection-level window is increased when a stream-level window grows func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount, now monotime.Time) { c.mutex.Lock() 
defer c.mutex.Unlock() if inc <= c.receiveWindowSize { return } newSize := min(inc, c.maxReceiveWindowSize) if delta := newSize - c.receiveWindowSize; delta > 0 && c.allowWindowIncrease(delta) { c.receiveWindowSize = newSize if c.logger.Debug() { c.logger.Debugf("Increasing receive flow control window for the connection to %d, in response to stream flow control window increase", newSize) } } c.startNewAutoTuningEpoch(now) } // Reset rests the flow controller. This happens when 0-RTT is rejected. // All stream data is invalidated, it's as if we had never opened a stream and never sent any data. // At that point, we only have sent stream data, but we didn't have the keys to open 1-RTT keys yet. func (c *connectionFlowController) Reset() error { c.mutex.Lock() defer c.mutex.Unlock() if c.bytesRead > 0 || c.highestReceived > 0 || !c.epochStartTime.IsZero() { return errors.New("flow controller reset after reading data") } c.bytesSent = 0 c.lastBlockedAt = 0 c.sendWindow = 0 return nil } quic-go-0.59.0/internal/flowcontrol/connection_flow_controller_test.go000066400000000000000000000050041513066070600263570ustar00rootroot00000000000000package flowcontrol import ( "testing" "time" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" "github.com/stretchr/testify/require" ) func TestConnectionFlowControlWindowUpdate(t *testing.T) { fc := NewConnectionFlowController( 100, // initial receive window 100, // max receive window nil, utils.NewRTTStats(), utils.DefaultLogger, ) require.False(t, fc.AddBytesRead(1)) require.Zero(t, fc.GetWindowUpdate(monotime.Now())) require.True(t, fc.AddBytesRead(99)) require.Equal(t, protocol.ByteCount(200), fc.GetWindowUpdate(monotime.Now())) } func TestConnectionWindowAutoTuningNotAllowed(t *testing.T) { // the RTT is 1 second rttStats := utils.NewRTTStats() rttStats.UpdateRTT(time.Second, 0) require.Equal(t, 
time.Second, rttStats.SmoothedRTT()) callbackCalledWith := protocol.InvalidByteCount fc := NewConnectionFlowController( 100, // initial receive window 150, // max receive window func(size protocol.ByteCount) bool { callbackCalledWith = size return false }, rttStats, utils.DefaultLogger, ) now := monotime.Now() require.NoError(t, fc.IncrementHighestReceived(100, now)) fc.AddBytesRead(90) require.Equal(t, protocol.InvalidByteCount, callbackCalledWith) require.Equal(t, protocol.ByteCount(90+100), fc.GetWindowUpdate(now.Add(time.Millisecond))) require.Equal(t, protocol.ByteCount(150-100), callbackCalledWith) } func TestConnectionFlowControlViolation(t *testing.T) { fc := NewConnectionFlowController(100, 100, nil, utils.NewRTTStats(), utils.DefaultLogger) require.NoError(t, fc.IncrementHighestReceived(40, monotime.Now())) require.NoError(t, fc.IncrementHighestReceived(60, monotime.Now())) err := fc.IncrementHighestReceived(1, monotime.Now()) var terr *qerr.TransportError require.ErrorAs(t, err, &terr) require.Equal(t, qerr.FlowControlError, terr.ErrorCode) } func TestConnectionFlowControllerReset(t *testing.T) { fc := NewConnectionFlowController(0, 0, nil, utils.NewRTTStats(), utils.DefaultLogger) fc.UpdateSendWindow(100) fc.AddBytesSent(10) require.Equal(t, protocol.ByteCount(90), fc.SendWindowSize()) require.NoError(t, fc.Reset()) require.Zero(t, fc.SendWindowSize()) } func TestConnectionFlowControllerResetAfterReading(t *testing.T) { fc := NewConnectionFlowController(0, 0, nil, utils.NewRTTStats(), utils.DefaultLogger) fc.AddBytesRead(1) require.EqualError(t, fc.Reset(), "flow controller reset after reading data") } quic-go-0.59.0/internal/flowcontrol/interface.go000066400000000000000000000032251513066070600216320ustar00rootroot00000000000000package flowcontrol import ( "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" ) type flowController interface { // for sending SendWindowSize() protocol.ByteCount 
UpdateSendWindow(protocol.ByteCount) (updated bool) AddBytesSent(protocol.ByteCount) // for receiving GetWindowUpdate(monotime.Time) protocol.ByteCount // returns 0 if no update is necessary } // A StreamFlowController is a flow controller for a QUIC stream. type StreamFlowController interface { flowController AddBytesRead(protocol.ByteCount) (hasStreamWindowUpdate, hasConnWindowUpdate bool) // UpdateHighestReceived is called when a new highest offset is received // final has to be to true if this is the final offset of the stream, // as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame UpdateHighestReceived(offset protocol.ByteCount, final bool, now monotime.Time) error // Abandon is called when reading from the stream is aborted early, // and there won't be any further calls to AddBytesRead. Abandon() IsNewlyBlocked() bool } // The ConnectionFlowController is the flow controller for the connection. type ConnectionFlowController interface { flowController AddBytesRead(protocol.ByteCount) (hasWindowUpdate bool) Reset() error IsNewlyBlocked() (bool, protocol.ByteCount) } type connectionFlowControllerI interface { ConnectionFlowController // The following two methods are not supposed to be called from outside this packet, but are needed internally // for sending EnsureMinimumWindowSize(protocol.ByteCount, monotime.Time) // for receiving IncrementHighestReceived(protocol.ByteCount, monotime.Time) error } quic-go-0.59.0/internal/flowcontrol/stream_flow_controller.go000066400000000000000000000114161513066070600244600ustar00rootroot00000000000000package flowcontrol import ( "fmt" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" ) type streamFlowController struct { baseFlowController streamID protocol.StreamID connection connectionFlowControllerI receivedFinalOffset bool } var _ StreamFlowController = 
&streamFlowController{} // NewStreamFlowController gets a new flow controller for a stream func NewStreamFlowController( streamID protocol.StreamID, cfc ConnectionFlowController, receiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount, initialSendWindow protocol.ByteCount, rttStats *utils.RTTStats, logger utils.Logger, ) StreamFlowController { return &streamFlowController{ streamID: streamID, connection: cfc.(connectionFlowControllerI), baseFlowController: baseFlowController{ rttStats: rttStats, receiveWindow: receiveWindow, receiveWindowSize: receiveWindow, maxReceiveWindowSize: maxReceiveWindow, sendWindow: initialSendWindow, logger: logger, }, } } // UpdateHighestReceived updates the highestReceived value, if the offset is higher. func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, final bool, now monotime.Time) error { // If the final offset for this stream is already known, check for consistency. if c.receivedFinalOffset { // If we receive another final offset, check that it's the same. if final && offset != c.highestReceived { return &qerr.TransportError{ ErrorCode: qerr.FinalSizeError, ErrorMessage: fmt.Sprintf("received inconsistent final offset for stream %d (old: %d, new: %d bytes)", c.streamID, c.highestReceived, offset), } } // Check that the offset is below the final offset. if offset > c.highestReceived { return &qerr.TransportError{ ErrorCode: qerr.FinalSizeError, ErrorMessage: fmt.Sprintf("received offset %d for stream %d, but final offset was already received at %d", offset, c.streamID, c.highestReceived), } } } if final { c.receivedFinalOffset = true } if offset == c.highestReceived { return nil } // A higher offset was received before. This can happen due to reordering. 
if offset < c.highestReceived { if final { return &qerr.TransportError{ ErrorCode: qerr.FinalSizeError, ErrorMessage: fmt.Sprintf("received final offset %d for stream %d, but already received offset %d before", offset, c.streamID, c.highestReceived), } } return nil } // If this is the first frame received for this stream, start flow-control auto-tuning. if c.highestReceived == 0 { c.startNewAutoTuningEpoch(now) } increment := offset - c.highestReceived c.highestReceived = offset if c.checkFlowControlViolation() { return &qerr.TransportError{ ErrorCode: qerr.FlowControlError, ErrorMessage: fmt.Sprintf("received %d bytes on stream %d, allowed %d bytes", offset, c.streamID, c.receiveWindow), } } return c.connection.IncrementHighestReceived(increment, now) } func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) (hasStreamWindowUpdate, hasConnWindowUpdate bool) { c.mutex.Lock() c.addBytesRead(n) hasStreamWindowUpdate = c.shouldQueueWindowUpdate() c.mutex.Unlock() hasConnWindowUpdate = c.connection.AddBytesRead(n) return } func (c *streamFlowController) Abandon() { c.mutex.Lock() unread := c.highestReceived - c.bytesRead c.bytesRead = c.highestReceived c.mutex.Unlock() if unread > 0 { c.connection.AddBytesRead(unread) } } func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) { c.baseFlowController.AddBytesSent(n) c.connection.AddBytesSent(n) } func (c *streamFlowController) SendWindowSize() protocol.ByteCount { return min(c.baseFlowController.SendWindowSize(), c.connection.SendWindowSize()) } func (c *streamFlowController) IsNewlyBlocked() bool { blocked, _ := c.baseFlowController.IsNewlyBlocked() return blocked } func (c *streamFlowController) shouldQueueWindowUpdate() bool { return !c.receivedFinalOffset && c.hasWindowUpdate() } func (c *streamFlowController) GetWindowUpdate(now monotime.Time) protocol.ByteCount { // If we already received the final offset for this stream, the peer won't need any additional flow control credit. 
if c.receivedFinalOffset { return 0 } c.mutex.Lock() defer c.mutex.Unlock() oldWindowSize := c.receiveWindowSize offset := c.getWindowUpdate(now) if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size c.logger.Debugf("Increasing receive flow control window for stream %d to %d", c.streamID, c.receiveWindowSize) c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize)*protocol.ConnectionFlowControlMultiplier), now) } return offset } quic-go-0.59.0/internal/flowcontrol/stream_flow_controller_test.go000066400000000000000000000226051513066070600255210ustar00rootroot00000000000000package flowcontrol import ( "testing" "time" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" "github.com/stretchr/testify/require" ) func TestStreamFlowControlReceiving(t *testing.T) { fc := NewStreamFlowController( 42, NewConnectionFlowController( protocol.MaxByteCount, protocol.MaxByteCount, nil, utils.NewRTTStats(), utils.DefaultLogger, ), 100, protocol.MaxByteCount, protocol.MaxByteCount, utils.NewRTTStats(), utils.DefaultLogger, ) require.NoError(t, fc.UpdateHighestReceived(50, false, monotime.Now())) // duplicates are fine require.NoError(t, fc.UpdateHighestReceived(50, false, monotime.Now())) // reordering is fine require.NoError(t, fc.UpdateHighestReceived(40, false, monotime.Now())) require.NoError(t, fc.UpdateHighestReceived(60, false, monotime.Now())) // exceeding the limit is not fine err := fc.UpdateHighestReceived(101, false, monotime.Now()) var terr *qerr.TransportError require.ErrorAs(t, err, &terr) require.Equal(t, qerr.FlowControlError, terr.ErrorCode) require.Equal(t, "received 101 bytes on stream 42, allowed 100 bytes", terr.ErrorMessage) } func TestStreamFlowControllerFinalOffset(t *testing.T) { newFC := func() StreamFlowController { return NewStreamFlowController( 42, 
NewConnectionFlowController( protocol.MaxByteCount, protocol.MaxByteCount, nil, utils.NewRTTStats(), utils.DefaultLogger, ), protocol.MaxByteCount, protocol.MaxByteCount, protocol.MaxByteCount, utils.NewRTTStats(), utils.DefaultLogger, ) } t.Run("duplicate final offset", func(t *testing.T) { fc := newFC() require.NoError(t, fc.UpdateHighestReceived(50, true, monotime.Now())) // it is valid to receive the same final offset multiple times require.NoError(t, fc.UpdateHighestReceived(50, true, monotime.Now())) }) t.Run("inconsistent final offset", func(t *testing.T) { fc := newFC() require.NoError(t, fc.UpdateHighestReceived(50, true, monotime.Now())) err := fc.UpdateHighestReceived(51, true, monotime.Now()) require.Error(t, err) var terr *qerr.TransportError require.ErrorAs(t, err, &terr) require.Equal(t, qerr.FinalSizeError, terr.ErrorCode) require.Equal(t, "received inconsistent final offset for stream 42 (old: 50, new: 51 bytes)", terr.ErrorMessage) }) t.Run("non-final offset past final offset", func(t *testing.T) { fc := newFC() require.NoError(t, fc.UpdateHighestReceived(50, true, monotime.Now())) // No matter the ordering, it's never ok to receive an offset past the final offset. err := fc.UpdateHighestReceived(60, false, monotime.Now()) var terr *qerr.TransportError require.ErrorAs(t, err, &terr) require.Equal(t, qerr.FinalSizeError, terr.ErrorCode) require.Equal(t, "received offset 60 for stream 42, but final offset was already received at 50", terr.ErrorMessage) }) t.Run("final offset smaller than previous offset", func(t *testing.T) { fc := newFC() require.NoError(t, fc.UpdateHighestReceived(50, false, monotime.Now())) // If we received offset already, it's invalid to receive a smaller final offset. 
err := fc.UpdateHighestReceived(40, true, monotime.Now()) var terr *qerr.TransportError require.ErrorAs(t, err, &terr) require.Equal(t, qerr.FinalSizeError, terr.ErrorCode) require.Equal(t, "received final offset 40 for stream 42, but already received offset 50 before", terr.ErrorMessage) }) } func TestStreamAbandoning(t *testing.T) { connFC := NewConnectionFlowController( 100, protocol.MaxByteCount, nil, utils.NewRTTStats(), utils.DefaultLogger, ) require.True(t, connFC.UpdateSendWindow(300)) fc := NewStreamFlowController( 42, connFC, 60, protocol.MaxByteCount, 100, utils.NewRTTStats(), utils.DefaultLogger, ) require.NoError(t, fc.UpdateHighestReceived(50, true, monotime.Now())) require.Zero(t, fc.GetWindowUpdate(monotime.Now())) require.Zero(t, connFC.GetWindowUpdate(monotime.Now())) // Abandon the stream. // This marks all bytes as having been consumed. fc.Abandon() require.Equal(t, protocol.ByteCount(150), connFC.GetWindowUpdate(monotime.Now())) } func TestStreamSendWindow(t *testing.T) { // We set up the connection flow controller with a limit of 300 bytes, // and the stream flow controller with a limit of 100 bytes. connFC := NewConnectionFlowController( protocol.MaxByteCount, protocol.MaxByteCount, nil, utils.NewRTTStats(), utils.DefaultLogger, ) require.True(t, connFC.UpdateSendWindow(300)) fc := NewStreamFlowController( 42, connFC, protocol.MaxByteCount, protocol.MaxByteCount, 100, utils.NewRTTStats(), utils.DefaultLogger, ) // first, we're limited by the stream flow controller require.Equal(t, protocol.ByteCount(100), fc.SendWindowSize()) fc.AddBytesSent(50) require.False(t, fc.IsNewlyBlocked()) require.Equal(t, protocol.ByteCount(50), fc.SendWindowSize()) fc.AddBytesSent(50) require.True(t, fc.IsNewlyBlocked()) require.Zero(t, fc.SendWindowSize()) require.False(t, fc.IsNewlyBlocked()) // we're still blocked, but it's not new // Update the stream flow control limit, but don't update the connection flow control limit. 
// We're now limited by the connection flow controller. require.True(t, fc.UpdateSendWindow(1000)) // reordered updates are ignored require.False(t, fc.UpdateSendWindow(999)) require.False(t, fc.IsNewlyBlocked()) // we're not blocked anymore require.Equal(t, protocol.ByteCount(200), fc.SendWindowSize()) fc.AddBytesSent(200) require.Zero(t, fc.SendWindowSize()) require.False(t, fc.IsNewlyBlocked()) // we're blocked, but not on stream flow control } func TestStreamWindowUpdate(t *testing.T) { fc := NewStreamFlowController( 42, NewConnectionFlowController( protocol.MaxByteCount, protocol.MaxByteCount, nil, utils.NewRTTStats(), utils.DefaultLogger, ), 100, 100, protocol.MaxByteCount, utils.NewRTTStats(), utils.DefaultLogger, ) require.Zero(t, fc.GetWindowUpdate(monotime.Now())) hasStreamWindowUpdate, _ := fc.AddBytesRead(24) require.False(t, hasStreamWindowUpdate) require.Zero(t, fc.GetWindowUpdate(monotime.Now())) // the window is updated when it's 25% filled hasStreamWindowUpdate, _ = fc.AddBytesRead(1) require.True(t, hasStreamWindowUpdate) require.Equal(t, protocol.ByteCount(125), fc.GetWindowUpdate(monotime.Now())) hasStreamWindowUpdate, _ = fc.AddBytesRead(24) require.False(t, hasStreamWindowUpdate) require.Zero(t, fc.GetWindowUpdate(monotime.Now())) // the window is updated when it's 25% filled hasStreamWindowUpdate, _ = fc.AddBytesRead(1) require.True(t, hasStreamWindowUpdate) require.Equal(t, protocol.ByteCount(150), fc.GetWindowUpdate(monotime.Now())) // Receive the final offset. // We don't need to send any more flow control updates. 
require.NoError(t, fc.UpdateHighestReceived(100, true, monotime.Now())) fc.AddBytesRead(50) require.Zero(t, fc.GetWindowUpdate(monotime.Now())) } func TestStreamConnectionWindowUpdate(t *testing.T) { connFC := NewConnectionFlowController( 100, protocol.MaxByteCount, nil, utils.NewRTTStats(), utils.DefaultLogger, ) fc := NewStreamFlowController( 42, connFC, 1000, protocol.MaxByteCount, protocol.MaxByteCount, utils.NewRTTStats(), utils.DefaultLogger, ) hasStreamWindowUpdate, hasConnWindowUpdate := fc.AddBytesRead(50) require.False(t, hasStreamWindowUpdate) require.Zero(t, fc.GetWindowUpdate(monotime.Now())) require.True(t, hasConnWindowUpdate) require.NotZero(t, connFC.GetWindowUpdate(monotime.Now())) } func TestStreamWindowAutoTuning(t *testing.T) { // the RTT is 1 second rttStats := utils.NewRTTStats() rttStats.UpdateRTT(time.Second, 0) require.Equal(t, time.Second, rttStats.SmoothedRTT()) connFC := NewConnectionFlowController( 150, // initial receive window 350, // max receive window func(size protocol.ByteCount) bool { return true }, rttStats, utils.DefaultLogger, ) fc := NewStreamFlowController( 42, connFC, 100, // initial send window 399, // max send window protocol.MaxByteCount, rttStats, utils.DefaultLogger, ) now := monotime.Now() require.NoError(t, fc.UpdateHighestReceived(100, false, now)) // data consumption is too slow, window size is not increased now = now.Add(2500 * time.Millisecond) fc.AddBytesRead(51) // one initial stream window size added require.Equal(t, protocol.ByteCount(51+100), fc.GetWindowUpdate(now)) // one initial connection window size added require.Equal(t, protocol.ByteCount(51+150), connFC.getWindowUpdate(now)) // data consumption is fast enough, window size is increased now = now.Add(2 * time.Second) fc.AddBytesRead(51) // stream window size doubled to 200 bytes require.Equal(t, protocol.ByteCount(102+2*100), fc.GetWindowUpdate(now)) // The connection window is now increased as well, // so that we don't get blocked on connection level 
flow control: // The increase is by 200 bytes * a connection factor of 1.5: 300 bytes. require.Equal(t, protocol.ByteCount(102+300), connFC.GetWindowUpdate(now)) // data consumption is fast enough, window size is increased now = now.Add(2 * time.Second) fc.AddBytesRead(101) // stream window size increased again, but bumps into its maximum value require.Equal(t, protocol.ByteCount(203+399), fc.GetWindowUpdate(now)) // the connection window is also increased, but it bumps into its maximum value require.Equal(t, protocol.ByteCount(203+350), connFC.GetWindowUpdate(now)) } quic-go-0.59.0/internal/handshake/000077500000000000000000000000001513066070600167175ustar00rootroot00000000000000quic-go-0.59.0/internal/handshake/aead.go000066400000000000000000000051501513066070600201410ustar00rootroot00000000000000package handshake import ( "encoding/binary" "github.com/quic-go/quic-go/internal/protocol" ) func createAEAD(suite cipherSuite, trafficSecret []byte, v protocol.Version) *xorNonceAEAD { keyLabel := hkdfLabelKeyV1 ivLabel := hkdfLabelIVV1 if v == protocol.Version2 { keyLabel = hkdfLabelKeyV2 ivLabel = hkdfLabelIVV2 } key := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, keyLabel, suite.KeyLen) iv := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, ivLabel, suite.IVLen()) return suite.AEAD(key, iv) } type longHeaderSealer struct { aead *xorNonceAEAD headerProtector headerProtector nonceBuf [8]byte } var _ LongHeaderSealer = &longHeaderSealer{} func newLongHeaderSealer(aead *xorNonceAEAD, headerProtector headerProtector) LongHeaderSealer { if aead.NonceSize() != 8 { panic("unexpected nonce size") } return &longHeaderSealer{ aead: aead, headerProtector: headerProtector, } } func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { binary.BigEndian.PutUint64(s.nonceBuf[:], uint64(pn)) return s.aead.Seal(dst, s.nonceBuf[:], src, ad) } func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { 
s.headerProtector.EncryptHeader(sample, firstByte, pnBytes) } func (s *longHeaderSealer) Overhead() int { return s.aead.Overhead() } type longHeaderOpener struct { aead *xorNonceAEAD headerProtector headerProtector highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected) // use a single array to avoid allocations nonceBuf [8]byte } var _ LongHeaderOpener = &longHeaderOpener{} func newLongHeaderOpener(aead *xorNonceAEAD, headerProtector headerProtector) LongHeaderOpener { if aead.NonceSize() != 8 { panic("unexpected nonce size") } return &longHeaderOpener{ aead: aead, headerProtector: headerProtector, } } func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber { return protocol.DecodePacketNumber(wirePNLen, o.highestRcvdPN, wirePN) } func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { binary.BigEndian.PutUint64(o.nonceBuf[:], uint64(pn)) dec, err := o.aead.Open(dst, o.nonceBuf[:], src, ad) if err == nil { o.highestRcvdPN = max(o.highestRcvdPN, pn) } else { err = ErrDecryptionFailed } return dec, err } func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { o.headerProtector.DecryptHeader(sample, firstByte, pnBytes) } quic-go-0.59.0/internal/handshake/aead_test.go000066400000000000000000000076041513066070600212060ustar00rootroot00000000000000package handshake import ( "crypto/aes" "crypto/cipher" "crypto/rand" "crypto/tls" "fmt" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func getSealerAndOpener(t *testing.T, cs cipherSuite, v protocol.Version) (LongHeaderSealer, LongHeaderOpener) { t.Helper() key := make([]byte, 16) hpKey := make([]byte, 16) rand.Read(key) rand.Read(hpKey) block, err := aes.NewCipher(key) require.NoError(t, err) aead, err := cipher.NewGCM(block) require.NoError(t, err) return 
newLongHeaderSealer(&xorNonceAEAD{aead: aead}, newHeaderProtector(cs, hpKey, true, v)), newLongHeaderOpener(&xorNonceAEAD{aead: aead}, newHeaderProtector(cs, hpKey, true, v)) } func TestEncryptAndDecryptMessage(t *testing.T) { for _, v := range []protocol.Version{protocol.Version1, protocol.Version2} { for _, cs := range cipherSuites { t.Run(fmt.Sprintf("QUIC %s/%s", v, tls.CipherSuiteName(cs.ID)), func(t *testing.T) { sealer, opener := getSealerAndOpener(t, cs, v) msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") ad := []byte("Donec in velit neque.") encrypted := sealer.Seal(nil, msg, 0x1337, ad) opened, err := opener.Open(nil, encrypted, 0x1337, ad) require.NoError(t, err) require.Equal(t, msg, opened) // incorrect associated data _, err = opener.Open(nil, encrypted, 0x1337, []byte("wrong ad")) require.Equal(t, ErrDecryptionFailed, err) // incorrect packet number _, err = opener.Open(nil, encrypted, 0x42, ad) require.Equal(t, ErrDecryptionFailed, err) }) } } } func TestDecodePacketNumber(t *testing.T) { msg := []byte("Lorem ipsum dolor sit amet") ad := []byte("Donec in velit neque.") sealer, opener := getSealerAndOpener(t, getCipherSuite(tls.TLS_AES_128_GCM_SHA256), protocol.Version1) encrypted := sealer.Seal(nil, msg, 0x1337, ad) // can't decode the packet number if encryption failed _, err := opener.Open(nil, encrypted[:len(encrypted)-1], 0x1337, ad) require.Error(t, err) require.Equal(t, protocol.PacketNumber(0x38), opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)) _, err = opener.Open(nil, encrypted, 0x1337, ad) require.NoError(t, err) require.Equal(t, protocol.PacketNumber(0x1338), opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)) } func TestEncryptAndDecryptHeader(t *testing.T) { for _, v := range []protocol.Version{protocol.Version1, protocol.Version2} { t.Run("QUIC "+v.String(), func(t *testing.T) { for _, cs := range cipherSuites { 
t.Run(tls.CipherSuiteName(cs.ID), func(t *testing.T) { testEncryptAndDecryptHeader(t, cs, v) }) } }) } } func testEncryptAndDecryptHeader(t *testing.T, cs cipherSuite, v protocol.Version) { sealer, opener := getSealerAndOpener(t, cs, v) var lastFourBitsDifferent int for i := 0; i < 100; i++ { sample := make([]byte, 16) rand.Read(sample) header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} sealer.EncryptHeader(sample, &header[0], header[9:13]) if header[0]&0xf != 0xb5&0xf { lastFourBitsDifferent++ } require.Equal(t, byte(0xb5&0xf0), header[0]&0xf0) require.Equal(t, []byte{1, 2, 3, 4, 5, 6, 7, 8}, header[1:9]) require.NotEqual(t, []byte{0xde, 0xad, 0xbe, 0xef}, header[9:13]) opener.DecryptHeader(sample, &header[0], header[9:13]) require.Equal(t, []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}, header) } require.Greater(t, lastFourBitsDifferent, 75) // decryption failure with different sample header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} sample := make([]byte, 16) rand.Read(sample) sealer.EncryptHeader(sample, &header[0], header[9:13]) rand.Read(sample) // use a different sample opener.DecryptHeader(sample, &header[0], header[9:13]) require.NotEqual(t, []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}, header) } quic-go-0.59.0/internal/handshake/cipher_suite.go000066400000000000000000000051761513066070600217420ustar00rootroot00000000000000package handshake import ( "crypto" "crypto/aes" "crypto/cipher" "crypto/tls" "fmt" "golang.org/x/crypto/chacha20poly1305" ) // These cipher suite implementations are copied from the standard library crypto/tls package. 
const aeadNonceLength = 12 type cipherSuite struct { ID uint16 Hash crypto.Hash KeyLen int AEAD func(key, nonceMask []byte) *xorNonceAEAD } func (s cipherSuite) IVLen() int { return aeadNonceLength } func getCipherSuite(id uint16) cipherSuite { switch id { case tls.TLS_AES_128_GCM_SHA256: return cipherSuite{ID: tls.TLS_AES_128_GCM_SHA256, Hash: crypto.SHA256, KeyLen: 16, AEAD: aeadAESGCMTLS13} case tls.TLS_CHACHA20_POLY1305_SHA256: return cipherSuite{ID: tls.TLS_CHACHA20_POLY1305_SHA256, Hash: crypto.SHA256, KeyLen: 32, AEAD: aeadChaCha20Poly1305} case tls.TLS_AES_256_GCM_SHA384: return cipherSuite{ID: tls.TLS_AES_256_GCM_SHA384, Hash: crypto.SHA384, KeyLen: 32, AEAD: aeadAESGCMTLS13} default: panic(fmt.Sprintf("unknown cypher suite: %d", id)) } } func aeadAESGCMTLS13(key, nonceMask []byte) *xorNonceAEAD { if len(nonceMask) != aeadNonceLength { panic("tls: internal error: wrong nonce length") } aes, err := aes.NewCipher(key) if err != nil { panic(err) } aead, err := cipher.NewGCM(aes) if err != nil { panic(err) } ret := &xorNonceAEAD{aead: aead} copy(ret.nonceMask[:], nonceMask) return ret } func aeadChaCha20Poly1305(key, nonceMask []byte) *xorNonceAEAD { if len(nonceMask) != aeadNonceLength { panic("tls: internal error: wrong nonce length") } aead, err := chacha20poly1305.New(key) if err != nil { panic(err) } ret := &xorNonceAEAD{aead: aead} copy(ret.nonceMask[:], nonceMask) return ret } // xorNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce // before each call. 
type xorNonceAEAD struct {
	// nonceMask holds the fixed per-key IV. The per-call nonce is XORed into its last 8 bytes.
	nonceMask [aeadNonceLength]byte
	aead      cipher.AEAD
}

// NonceSize returns the length of the nonce callers pass to Seal and Open.
func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number

// Overhead returns the ciphertext expansion of the wrapped AEAD.
func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() }

// explicitNonceLen returns 0: no part of the nonce is transmitted explicitly.
func (f *xorNonceAEAD) explicitNonceLen() int { return 0 }

// Seal encrypts and authenticates plaintext using (nonceMask XOR nonce) as the full AEAD nonce.
// nonce is XORed into bytes 4..11 of nonceMask before sealing and XORed back out afterwards.
// NOTE(review): f.nonceMask is modified in place for the duration of the call, so concurrent
// calls on the same value would race. A nonce longer than 8 bytes would index past the end
// of nonceMask and panic.
func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
	for i, b := range nonce {
		f.nonceMask[4+i] ^= b
	}
	result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData)
	// Undo the XOR to restore the original mask.
	for i, b := range nonce {
		f.nonceMask[4+i] ^= b
	}

	return result
}

// Open decrypts and authenticates ciphertext using (nonceMask XOR nonce) as the full AEAD nonce.
// Like Seal, it temporarily XORs nonce into bytes 4..11 of f.nonceMask. It restores the mask
// before returning, whether or not decryption succeeded.
func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
	for i, b := range nonce {
		f.nonceMask[4+i] ^= b
	}
	result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData)
	// Undo the XOR to restore the original mask.
	for i, b := range nonce {
		f.nonceMask[4+i] ^= b
	}

	return result, err
}
quic-go-0.59.0/internal/handshake/crypto_setup.go000066400000000000000000000511321513066070600220100ustar00rootroot00000000000000package handshake

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"strings"
	"sync/atomic"
	"time"

	"github.com/quic-go/quic-go/internal/protocol"
	"github.com/quic-go/quic-go/internal/qerr"
	"github.com/quic-go/quic-go/internal/utils"
	"github.com/quic-go/quic-go/internal/wire"
	"github.com/quic-go/quic-go/qlog"
	"github.com/quic-go/quic-go/qlogwriter"
	"github.com/quic-go/quic-go/quicvarint"
)

// quicVersionContextKey is the unexported type of QUICVersionContextKey.
// Using a private type prevents collisions with context keys defined by other packages.
type quicVersionContextKey struct{}

// QUICVersionContextKey is the context key under which StartHandshake stores the QUIC version
// in the context passed to the TLS stack.
var QUICVersionContextKey = &quicVersionContextKey{}

// clientSessionStateRevision is the revision number prefixed to the data stored in session state.
// When the data is decoded, a different revision number is rejected.
const clientSessionStateRevision = 5

// cryptoSetup drives the TLS 1.3 handshake through crypto/tls's QUIC API.
// It holds the packet protection keys for each encryption level.
type cryptoSetup struct {
	tlsConf *tls.Config
	conn    *tls.QUICConn

	// events is the queue of pending events handed out by NextEvent.
	events []Event

	version protocol.Version

	ourParams  *wire.TransportParameters
	peerParams *wire.TransportParameters

	// zeroRTTParameters are the server's transport parameters restored from a session ticket (client only).
	zeroRTTParameters *wire.TransportParameters
	// allow0RTT enables 0-RTT: offering/accepting it on the server, using restored parameters on the client.
	allow0RTT bool

	rttStats *utils.RTTStats

	qlogger qlogwriter.Recorder
	logger  utils.Logger

	perspective protocol.Perspective

	// handshakeCompleteTime is set when TLS reports that the handshake is done; it is the zero time before that.
	handshakeCompleteTime time.Time

	zeroRTTOpener LongHeaderOpener // only set for the server
	zeroRTTSealer LongHeaderSealer // only set for the client

	initialOpener LongHeaderOpener
	initialSealer LongHeaderSealer

	handshakeOpener LongHeaderOpener
	handshakeSealer LongHeaderSealer

	// used0RTT records whether 0-RTT was used on this connection; it is read by ConnectionState.
	// NOTE(review): presumably atomic because ConnectionState may be called from another goroutine; confirm.
	used0RTT atomic.Bool

	// aead handles 1-RTT packet protection. has1RTTSealer / has1RTTOpener record
	// whether its write / read keys have been installed yet.
	aead          *updatableAEAD
	has1RTTSealer bool
	has1RTTOpener bool
}

var _ CryptoSetup = &cryptoSetup{}

// NewCryptoSetupClient creates a new crypto setup for the client.
// The TLS config is cloned and its MinVersion is forced to TLS 1.3.
// Our transport parameters are set on the TLS connection immediately.
// enable0RTT controls whether transport parameters restored from a session ticket are used for 0-RTT.
func NewCryptoSetupClient(
	connID protocol.ConnectionID,
	tp *wire.TransportParameters,
	tlsConf *tls.Config,
	enable0RTT bool,
	rttStats *utils.RTTStats,
	qlogger qlogwriter.Recorder,
	logger utils.Logger,
	version protocol.Version,
) CryptoSetup {
	cs := newCryptoSetup(
		connID,
		tp,
		rttStats,
		qlogger,
		logger,
		protocol.PerspectiveClient,
		version,
	)

	// Clone so that setting MinVersion doesn't modify the caller's config.
	tlsConf = tlsConf.Clone()
	tlsConf.MinVersion = tls.VersionTLS13
	cs.tlsConf = tlsConf
	cs.allow0RTT = enable0RTT

	cs.conn = tls.QUICClient(&tls.QUICConfig{
		TLSConfig:           tlsConf,
		EnableSessionEvents: true,
	})
	cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient))

	return cs
}

// NewCryptoSetupServer creates a new crypto setup for the server.
// The TLS config is prepared by setupConfigForServer, which also receives the local and remote addresses.
// allow0RTT controls whether session tickets enable 0-RTT and whether 0-RTT is accepted.
func NewCryptoSetupServer(
	connID protocol.ConnectionID,
	localAddr, remoteAddr net.Addr,
	tp *wire.TransportParameters,
	tlsConf *tls.Config,
	allow0RTT bool,
	rttStats *utils.RTTStats,
	qlogger qlogwriter.Recorder,
	logger utils.Logger,
	version protocol.Version,
) CryptoSetup {
	cs := newCryptoSetup(
		connID,
		tp,
		rttStats,
		qlogger,
		logger,
		protocol.PerspectiveServer,
		version,
	)
	cs.allow0RTT = allow0RTT

	tlsConf = setupConfigForServer(tlsConf, localAddr, remoteAddr)

	cs.tlsConf = tlsConf
	cs.conn = tls.QUICServer(&tls.QUICConfig{
		TLSConfig:           tlsConf,
		EnableSessionEvents: true,
	})
	return cs
}

// newCryptoSetup creates the state shared by client and server.
// It derives the Initial keys from connID and records qlog KeyUpdated events
// for both the client and the server Initial keys.
func newCryptoSetup(
	connID protocol.ConnectionID,
	tp *wire.TransportParameters,
	rttStats *utils.RTTStats,
	qlogger qlogwriter.Recorder,
	logger utils.Logger,
	perspective protocol.Perspective,
	version protocol.Version,
) *cryptoSetup {
	initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version)
	if qlogger != nil {
		qlogger.RecordEvent(qlog.KeyUpdated{
			Trigger: qlog.KeyUpdateTLS,
			KeyType: encLevelToKeyType(protocol.EncryptionInitial, protocol.PerspectiveClient),
		})
		qlogger.RecordEvent(qlog.KeyUpdated{
			Trigger: qlog.KeyUpdateTLS,
			KeyType: encLevelToKeyType(protocol.EncryptionInitial, protocol.PerspectiveServer),
		})
	}
	return &cryptoSetup{
		initialSealer: initialSealer,
		initialOpener: initialOpener,
		aead:          newUpdatableAEAD(rttStats, qlogger, logger, version),
		events:        make([]Event, 0, 16),
		ourParams:     tp,
		rttStats:      rttStats,
		qlogger:       qlogger,
		logger:        logger,
		perspective:   perspective,
		version:       version,
	}
}

// ChangeConnectionID replaces the Initial sealer and opener with ones derived from id.
// It records qlog KeyUpdated events for both Initial keys.
func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) {
	initialSealer, initialOpener := NewInitialAEAD(id, h.perspective, h.version)
	h.initialSealer = initialSealer
	h.initialOpener = initialOpener
	if h.qlogger != nil {
		h.qlogger.RecordEvent(qlog.KeyUpdated{
			Trigger: qlog.KeyUpdateTLS,
			KeyType: encLevelToKeyType(protocol.EncryptionInitial, protocol.PerspectiveClient),
		})
		h.qlogger.RecordEvent(qlog.KeyUpdated{
			Trigger: qlog.KeyUpdateTLS,
			KeyType: encLevelToKeyType(protocol.EncryptionInitial, protocol.PerspectiveServer),
		})
	}
}

// SetLargest1RTTAcked forwards the largest acknowledged 1-RTT packet number to the 1-RTT AEAD.
func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error {
	return h.aead.SetLargestAcked(pn)
}

// StartHandshake starts the TLS handshake and processes every TLS event that is immediately available.
// The QUIC version is passed to the TLS stack in the context under QUICVersionContextKey.
// For the client, an EventRestoredTransportParameters event is queued if both 0-RTT write keys
// and restored transport parameters are present. Errors are converted by wrapError.
func (h *cryptoSetup) StartHandshake(ctx context.Context) error {
	err := h.conn.Start(context.WithValue(ctx, QUICVersionContextKey, h.version))
	if err != nil {
		return wrapError(err)
	}
	for {
		ev := h.conn.NextEvent()
		if err := h.handleEvent(ev); err != nil {
			return wrapError(err)
		}
		if ev.Kind == tls.QUICNoEvent {
			break
		}
	}
	if h.perspective == protocol.PerspectiveClient {
		if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil {
			h.logger.Debugf("Doing 0-RTT.")
			h.events = append(h.events, Event{Kind: EventRestoredTransportParameters, TransportParameters: h.zeroRTTParameters})
		} else {
			h.logger.Debugf("Not doing 0-RTT. Has sealer: %t, has params: %t", h.zeroRTTSealer != nil, h.zeroRTTParameters != nil)
		}
	}
	return nil
}

// Close closes the crypto setup.
// It aborts the handshake, if it is still running.
func (h *cryptoSetup) Close() error { return h.conn.Close() } // HandleMessage handles a TLS handshake message. // It is called by the crypto streams when a new message is available. func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) error { if err := h.handleMessage(data, encLevel); err != nil { return wrapError(err) } return nil } func (h *cryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLevel) error { if err := h.conn.HandleData(encLevel.ToTLSEncryptionLevel(), data); err != nil { return err } for { ev := h.conn.NextEvent() if err := h.handleEvent(ev); err != nil { return err } if ev.Kind == tls.QUICNoEvent { return nil } } } func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (err error) { switch ev.Kind { case tls.QUICNoEvent: return nil case tls.QUICSetReadSecret: h.setReadKey(ev.Level, ev.Suite, ev.Data) return nil case tls.QUICSetWriteSecret: h.setWriteKey(ev.Level, ev.Suite, ev.Data) return nil case tls.QUICTransportParameters: return h.handleTransportParameters(ev.Data) case tls.QUICTransportParametersRequired: h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective)) return nil case tls.QUICRejectedEarlyData: h.rejected0RTT() return nil case tls.QUICWriteData: h.writeRecord(ev.Level, ev.Data) return nil case tls.QUICHandshakeDone: h.handshakeComplete() return nil case tls.QUICStoreSession: if h.perspective == protocol.PerspectiveServer { panic("cryptoSetup BUG: unexpected QUICStoreSession event for the server") } ev.SessionState.Extra = append( ev.SessionState.Extra, addSessionStateExtraPrefix(h.marshalDataForSessionState(ev.SessionState.EarlyData)), ) return h.conn.StoreSession(ev.SessionState) case tls.QUICResumeSession: var allowEarlyData bool switch h.perspective { case protocol.PerspectiveClient: // for clients, this event occurs when a session ticket is selected allowEarlyData = h.handleDataFromSessionState( findSessionStateExtraData(ev.SessionState.Extra), ev.SessionState.EarlyData, ) 
case protocol.PerspectiveServer: // for servers, this event occurs when receiving the client's session ticket allowEarlyData = h.handleSessionTicket( findSessionStateExtraData(ev.SessionState.Extra), ev.SessionState.EarlyData, ) } if ev.SessionState.EarlyData { ev.SessionState.EarlyData = allowEarlyData } return nil default: // Unknown events should be ignored. // crypto/tls will ensure that this is safe to do. // See the discussion following https://github.com/golang/go/issues/68124#issuecomment-2187042510 for details. return nil } } func (h *cryptoSetup) NextEvent() Event { if len(h.events) == 0 { return Event{Kind: EventNoEvent} } ev := h.events[0] h.events = h.events[1:] return ev } func (h *cryptoSetup) handleTransportParameters(data []byte) error { var tp wire.TransportParameters if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil { return err } h.peerParams = &tp h.events = append(h.events, Event{Kind: EventReceivedTransportParameters, TransportParameters: h.peerParams}) return nil } // must be called after receiving the transport parameters func (h *cryptoSetup) marshalDataForSessionState(earlyData bool) []byte { b := make([]byte, 0, 256) b = quicvarint.Append(b, clientSessionStateRevision) if earlyData { // only save the transport parameters for 0-RTT enabled session tickets return h.peerParams.MarshalForSessionTicket(b) } return b } func (h *cryptoSetup) handleDataFromSessionState(data []byte, earlyData bool) (allowEarlyData bool) { tp, err := decodeDataFromSessionState(data, earlyData) if err != nil { h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error()) return } // The session ticket might have been saved from a connection that allowed 0-RTT, // and therefore contain transport parameters. // Only use them if 0-RTT is actually used on the new connection. 
if tp != nil && h.allow0RTT { h.zeroRTTParameters = tp return true } return false } func decodeDataFromSessionState(b []byte, earlyData bool) (*wire.TransportParameters, error) { ver, l, err := quicvarint.Parse(b) if err != nil { return nil, err } b = b[l:] if ver != clientSessionStateRevision { return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision) } if !earlyData { return nil, nil } var tp wire.TransportParameters if err := tp.UnmarshalFromSessionTicket(b); err != nil { return nil, err } return &tp, nil } func (h *cryptoSetup) getDataForSessionTicket() []byte { return (&sessionTicket{ Parameters: h.ourParams, }).Marshal() } // GetSessionTicket generates a new session ticket. // Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection. // It is only valid for the server. func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{ EarlyData: h.allow0RTT, Extra: [][]byte{addSessionStateExtraPrefix(h.getDataForSessionTicket())}, }); err != nil { // Session tickets might be disabled by tls.Config.SessionTicketsDisabled. // We can't check h.tlsConfig here, since the actual config might have been obtained from // the GetConfigForClient callback. // See https://github.com/golang/go/issues/62032. // This error assertion can be removed once we drop support for Go 1.25. if strings.Contains(err.Error(), "session ticket keys unavailable") { return nil, nil } return nil, err } // If session tickets are disabled, NextEvent will immediately return QUICNoEvent, // and we will return a nil ticket. 
var ticket []byte for { ev := h.conn.NextEvent() if ev.Kind == tls.QUICNoEvent { break } if ev.Kind == tls.QUICWriteData && ev.Level == tls.QUICEncryptionLevelApplication { if ticket != nil { h.logger.Errorf("unexpected multiple session tickets") continue } ticket = ev.Data } else { h.logger.Errorf("unexpected event: %v", ev.Kind) } } return ticket, nil } // handleSessionTicket is called for the server when receiving the client's session ticket. // It reads parameters from the session ticket and checks whether to accept 0-RTT if the session ticket enabled 0-RTT. // Note that the fact that the session ticket allows 0-RTT doesn't mean that the actual TLS handshake enables 0-RTT: // A client may use a 0-RTT enabled session to resume a TLS session without using 0-RTT. func (h *cryptoSetup) handleSessionTicket(data []byte, using0RTT bool) (allowEarlyData bool) { var t sessionTicket if err := t.Unmarshal(data); err != nil { h.logger.Debugf("Unmarshalling session ticket failed: %s", err.Error()) return false } if !using0RTT { return false } valid := h.ourParams.ValidFor0RTT(t.Parameters) if !valid { h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.") return false } if !h.allow0RTT { h.logger.Debugf("0-RTT not allowed. Rejecting 0-RTT.") return false } return true } // rejected0RTT is called for the client when the server rejects 0-RTT. func (h *cryptoSetup) rejected0RTT() { h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.") had0RTTKeys := h.zeroRTTSealer != nil h.zeroRTTSealer = nil if had0RTTKeys { h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys}) } } func (h *cryptoSetup) setReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { suite := getCipherSuite(suiteID) //nolint:exhaustive // The TLS stack doesn't export Initial keys. 
switch el { case tls.QUICEncryptionLevelEarly: if h.perspective == protocol.PerspectiveClient { panic("Received 0-RTT read key for the client") } h.zeroRTTOpener = newLongHeaderOpener( createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), ) h.used0RTT.Store(true) if h.logger.Debug() { h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) } case tls.QUICEncryptionLevelHandshake: h.handshakeOpener = newLongHeaderOpener( createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), ) if h.logger.Debug() { h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID)) } case tls.QUICEncryptionLevelApplication: h.aead.SetReadKey(suite, trafficSecret) h.has1RTTOpener = true if h.logger.Debug() { h.logger.Debugf("Installed 1-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) } default: panic("unexpected read encryption level") } h.events = append(h.events, Event{Kind: EventReceivedReadKeys}) if h.qlogger != nil { h.qlogger.RecordEvent(qlog.KeyUpdated{ Trigger: qlog.KeyUpdateTLS, KeyType: encLevelToKeyType(protocol.FromTLSEncryptionLevel(el), h.perspective.Opposite()), }) } } func (h *cryptoSetup) setWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { suite := getCipherSuite(suiteID) //nolint:exhaustive // The TLS stack doesn't export Initial keys. 
switch el { case tls.QUICEncryptionLevelEarly: if h.perspective == protocol.PerspectiveServer { panic("Received 0-RTT write key for the server") } h.zeroRTTSealer = newLongHeaderSealer( createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), ) if h.logger.Debug() { h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) } if h.qlogger != nil { h.qlogger.RecordEvent(qlog.KeyUpdated{ Trigger: qlog.KeyUpdateTLS, KeyType: encLevelToKeyType(protocol.Encryption0RTT, h.perspective), }) } // don't set used0RTT here. 0-RTT might still get rejected. return case tls.QUICEncryptionLevelHandshake: h.handshakeSealer = newLongHeaderSealer( createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), ) if h.logger.Debug() { h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID)) } case tls.QUICEncryptionLevelApplication: h.aead.SetWriteKey(suite, trafficSecret) h.has1RTTSealer = true if h.logger.Debug() { h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) } if h.zeroRTTSealer != nil { // Once we receive handshake keys, we know that 0-RTT was not rejected. h.used0RTT.Store(true) h.zeroRTTSealer = nil h.logger.Debugf("Dropping 0-RTT keys.") if h.qlogger != nil { h.qlogger.RecordEvent(qlog.KeyDiscarded{KeyType: qlog.KeyTypeClient0RTT}) } } default: panic("unexpected write encryption level") } if h.qlogger != nil { h.qlogger.RecordEvent(qlog.KeyUpdated{ Trigger: qlog.KeyUpdateTLS, KeyType: encLevelToKeyType(protocol.FromTLSEncryptionLevel(el), h.perspective), }) } } // writeRecord is called when TLS writes data func (h *cryptoSetup) writeRecord(encLevel tls.QUICEncryptionLevel, p []byte) { //nolint:exhaustive // handshake records can only be written for Initial and Handshake. 
switch encLevel { case tls.QUICEncryptionLevelInitial: h.events = append(h.events, Event{Kind: EventWriteInitialData, Data: p}) case tls.QUICEncryptionLevelHandshake: h.events = append(h.events, Event{Kind: EventWriteHandshakeData, Data: p}) case tls.QUICEncryptionLevelApplication: panic("unexpected write") default: panic(fmt.Sprintf("unexpected write encryption level: %s", encLevel)) } } func (h *cryptoSetup) DiscardInitialKeys() { dropped := h.initialOpener != nil h.initialOpener = nil h.initialSealer = nil if dropped { h.logger.Debugf("Dropping Initial keys.") if h.qlogger != nil { h.qlogger.RecordEvent(qlog.KeyDiscarded{KeyType: qlog.KeyTypeClientInitial}) h.qlogger.RecordEvent(qlog.KeyDiscarded{KeyType: qlog.KeyTypeServerInitial}) } } } func (h *cryptoSetup) handshakeComplete() { h.handshakeCompleteTime = time.Now() h.events = append(h.events, Event{Kind: EventHandshakeComplete}) } func (h *cryptoSetup) SetHandshakeConfirmed() { h.aead.SetHandshakeConfirmed() // drop Handshake keys var dropped bool if h.handshakeOpener != nil { h.handshakeOpener = nil h.handshakeSealer = nil dropped = true } if dropped { h.logger.Debugf("Dropping Handshake keys.") if h.qlogger != nil { h.qlogger.RecordEvent(qlog.KeyDiscarded{KeyType: qlog.KeyTypeClientHandshake}) h.qlogger.RecordEvent(qlog.KeyDiscarded{KeyType: qlog.KeyTypeServerHandshake}) } } } func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { if h.initialSealer == nil { return nil, ErrKeysDropped } return h.initialSealer, nil } func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) { if h.zeroRTTSealer == nil { return nil, ErrKeysDropped } return h.zeroRTTSealer, nil } func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) { if h.handshakeSealer == nil { if h.initialSealer == nil { return nil, ErrKeysDropped } return nil, ErrKeysNotYetAvailable } return h.handshakeSealer, nil } func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) { if !h.has1RTTSealer { return nil, 
ErrKeysNotYetAvailable } return h.aead, nil } func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) { if h.initialOpener == nil { return nil, ErrKeysDropped } return h.initialOpener, nil } func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) { if h.zeroRTTOpener == nil { if h.initialOpener != nil { return nil, ErrKeysNotYetAvailable } // if the initial opener is also not available, the keys were already dropped return nil, ErrKeysDropped } return h.zeroRTTOpener, nil } func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) { if h.handshakeOpener == nil { if h.initialOpener != nil { return nil, ErrKeysNotYetAvailable } // if the initial opener is also not available, the keys were already dropped return nil, ErrKeysDropped } return h.handshakeOpener, nil } func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) { h.zeroRTTOpener = nil h.logger.Debugf("Dropping 0-RTT keys.") if h.qlogger != nil { h.qlogger.RecordEvent(qlog.KeyDiscarded{KeyType: qlog.KeyTypeClient0RTT}) } } if !h.has1RTTOpener { return nil, ErrKeysNotYetAvailable } return h.aead, nil } func (h *cryptoSetup) ConnectionState() ConnectionState { return ConnectionState{ ConnectionState: h.conn.ConnectionState(), Used0RTT: h.used0RTT.Load(), } } func wrapError(err error) error { if alertErr := tls.AlertError(0); errors.As(err, &alertErr) { return qerr.NewLocalCryptoError(uint8(alertErr), err) } return &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: err.Error()} } func encLevelToKeyType(encLevel protocol.EncryptionLevel, pers protocol.Perspective) qlog.KeyType { if pers == protocol.PerspectiveServer { switch encLevel { case protocol.EncryptionInitial: return qlog.KeyTypeServerInitial case protocol.EncryptionHandshake: return qlog.KeyTypeServerHandshake case protocol.Encryption0RTT: return qlog.KeyTypeServer0RTT case protocol.Encryption1RTT: return 
qlog.KeyTypeServer1RTT default: return "" } } switch encLevel { case protocol.EncryptionInitial: return qlog.KeyTypeClientInitial case protocol.EncryptionHandshake: return qlog.KeyTypeClientHandshake case protocol.Encryption0RTT: return qlog.KeyTypeClient0RTT case protocol.Encryption1RTT: return qlog.KeyTypeClient1RTT default: return "" } } quic-go-0.59.0/internal/handshake/crypto_setup_test.go000066400000000000000000000421531513066070600230520ustar00rootroot00000000000000package handshake import ( "context" "crypto/ed25519" "crypto/rand" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "errors" "math/big" "net" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/testdata" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/require" ) const ( typeClientHello = 1 typeNewSessionTicket = 4 ) type mockClientSessionCache struct { cache tls.ClientSessionCache puts chan *tls.ClientSessionState } var _ tls.ClientSessionCache = &mockClientSessionCache{} func newMockClientSessionCache() *mockClientSessionCache { return &mockClientSessionCache{ puts: make(chan *tls.ClientSessionState, 1), cache: tls.NewLRUClientSessionCache(1), } } func (m *mockClientSessionCache) Get(sessionKey string) (session *tls.ClientSessionState, ok bool) { return m.cache.Get(sessionKey) } func (m *mockClientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) { m.puts <- cs m.cache.Put(sessionKey, cs) } func getTLSConfigs() (clientConf, serverConf *tls.Config) { clientConf = &tls.Config{ ServerName: "localhost", RootCAs: testdata.GetRootCA(), NextProtos: []string{"crypto-setup"}, } serverConf = testdata.GetTLSConfig() serverConf.NextProtos = []string{"crypto-setup"} return clientConf, serverConf } func TestErrorBeforeClientHelloGeneration(t *testing.T) { tlsConf := testdata.GetTLSConfig() tlsConf.InsecureSkipVerify = true 
tlsConf.NextProtos = []string{""} cl := NewCryptoSetupClient( protocol.ConnectionID{}, &wire.TransportParameters{}, tlsConf, false, utils.NewRTTStats(), nil, utils.DefaultLogger.WithPrefix("client"), protocol.Version1, ) var terr *qerr.TransportError err := cl.StartHandshake(context.Background()) require.True(t, errors.As(err, &terr)) require.Equal(t, uint64(0x100+0x50), uint64(terr.ErrorCode)) require.Contains(t, err.Error(), "tls: invalid NextProtos value") } func TestMessageReceivedAtWrongEncryptionLevel(t *testing.T) { var token protocol.StatelessResetToken server := NewCryptoSetupServer( protocol.ConnectionID{}, &net.UDPAddr{IP: net.IPv6loopback, Port: 1234}, &net.UDPAddr{IP: net.IPv6loopback, Port: 4321}, &wire.TransportParameters{StatelessResetToken: &token}, testdata.GetTLSConfig(), false, utils.NewRTTStats(), nil, utils.DefaultLogger.WithPrefix("server"), protocol.Version1, ) require.NoError(t, server.StartHandshake(context.Background())) fakeCH := append([]byte{typeClientHello, 0, 0, 6}, []byte("foobar")...) // wrong encryption level err := server.HandleMessage(fakeCH, protocol.EncryptionHandshake) require.Error(t, err) require.Contains(t, err.Error(), "tls: handshake data received at wrong level") } // The clientEvents and serverEvents contain all events that were not processed by the function, // i.e. not EventWriteInitialData, EventWriteHandshakeData, EventHandshakeComplete. 
func handshake(t *testing.T, client, server CryptoSetup) (clientEvents []Event, clientErr error, serverEvents []Event, serverErr error) { t.Helper() require.NoError(t, client.StartHandshake(context.Background())) require.NoError(t, server.StartHandshake(context.Background())) var clientHandshakeComplete, serverHandshakeComplete bool for { clientLoop: for { ev := client.NextEvent() switch ev.Kind { case EventNoEvent: break clientLoop case EventWriteInitialData: serverErr = server.HandleMessage(ev.Data, protocol.EncryptionInitial) if serverErr != nil { return } case EventWriteHandshakeData: serverErr = server.HandleMessage(ev.Data, protocol.EncryptionHandshake) if serverErr != nil { return } case EventHandshakeComplete: clientHandshakeComplete = true default: clientEvents = append(clientEvents, ev) } } serverLoop: for { ev := server.NextEvent() switch ev.Kind { case EventNoEvent: break serverLoop case EventWriteInitialData: clientErr = client.HandleMessage(ev.Data, protocol.EncryptionInitial) if clientErr != nil { return } case EventWriteHandshakeData: clientErr = client.HandleMessage(ev.Data, protocol.EncryptionHandshake) if clientErr != nil { return } case EventHandshakeComplete: serverHandshakeComplete = true ticket, err := server.GetSessionTicket() require.NoError(t, err) if ticket != nil { require.NoError(t, client.HandleMessage(ticket, protocol.Encryption1RTT)) } default: serverEvents = append(serverEvents, ev) } } if clientHandshakeComplete && serverHandshakeComplete { break } } return } func handshakeWithTLSConf( t *testing.T, clientConf, serverConf *tls.Config, clientRTTStats, serverRTTStats *utils.RTTStats, clientTransportParameters, serverTransportParameters *wire.TransportParameters, enable0RTT bool, ) (CryptoSetup /* client */, []Event /* more client events */, error, /* client error */ CryptoSetup /* server */, []Event /* more server events */, error, /* server error */ ) { t.Helper() client := NewCryptoSetupClient( protocol.ConnectionID{}, 
clientTransportParameters, clientConf, enable0RTT, clientRTTStats, nil, utils.DefaultLogger.WithPrefix("client"), protocol.Version1, ) if serverTransportParameters.StatelessResetToken == nil { var token protocol.StatelessResetToken serverTransportParameters.StatelessResetToken = &token } server := NewCryptoSetupServer( protocol.ConnectionID{}, &net.UDPAddr{IP: net.IPv6loopback, Port: 1234}, &net.UDPAddr{IP: net.IPv6loopback, Port: 4321}, serverTransportParameters, serverConf, enable0RTT, serverRTTStats, nil, utils.DefaultLogger.WithPrefix("server"), protocol.Version1, ) cEvents, cErr, sEvents, sErr := handshake(t, client, server) return client, cEvents, cErr, server, sEvents, sErr } func TestHandshake(t *testing.T) { clientConf, serverConf := getTLSConfigs() _, _, clientErr, _, _, serverErr := handshakeWithTLSConf( t, clientConf, serverConf, utils.NewRTTStats(), utils.NewRTTStats(), &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, false, ) require.NoError(t, clientErr) require.NoError(t, serverErr) } func TestHelloRetryRequest(t *testing.T) { clientConf, serverConf := getTLSConfigs() serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384} _, _, clientErr, _, _, serverErr := handshakeWithTLSConf( t, clientConf, serverConf, utils.NewRTTStats(), utils.NewRTTStats(), &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, false, ) require.NoError(t, clientErr) require.NoError(t, serverErr) } func TestWithClientAuth(t *testing.T) { pub, priv, err := ed25519.GenerateKey(rand.Reader) require.NoError(t, err) tmpl := &x509.Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{}, SignatureAlgorithm: x509.PureEd25519, NotBefore: time.Now(), NotAfter: time.Now().Add(time.Hour), BasicConstraintsValid: true, } certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, pub, priv) require.NoError(t, err) clientCert := tls.Certificate{ PrivateKey: priv, 
Certificate: [][]byte{certDER}, } clientConf, serverConf := getTLSConfigs() clientConf.Certificates = []tls.Certificate{clientCert} serverConf.ClientAuth = tls.RequireAnyClientCert _, _, clientErr, _, _, serverErr := handshakeWithTLSConf( t, clientConf, serverConf, utils.NewRTTStats(), utils.NewRTTStats(), &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, false, ) require.NoError(t, clientErr) require.NoError(t, serverErr) } func TestTransportParameters(t *testing.T) { clientConf, serverConf := getTLSConfigs() cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 42 * time.Second} client := NewCryptoSetupClient( protocol.ConnectionID{}, cTransportParameters, clientConf, false, utils.NewRTTStats(), nil, utils.DefaultLogger.WithPrefix("client"), protocol.Version1, ) var token protocol.StatelessResetToken sTransportParameters := &wire.TransportParameters{ MaxIdleTimeout: 1337 * time.Second, StatelessResetToken: &token, ActiveConnectionIDLimit: 2, } server := NewCryptoSetupServer( protocol.ConnectionID{}, &net.UDPAddr{IP: net.IPv6loopback, Port: 1234}, &net.UDPAddr{IP: net.IPv6loopback, Port: 4321}, sTransportParameters, serverConf, false, utils.NewRTTStats(), nil, utils.DefaultLogger.WithPrefix("server"), protocol.Version1, ) clientEvents, cErr, serverEvents, sErr := handshake(t, client, server) require.NoError(t, cErr) require.NoError(t, sErr) var clientReceivedTransportParameters *wire.TransportParameters for _, ev := range clientEvents { if ev.Kind == EventReceivedTransportParameters { clientReceivedTransportParameters = ev.TransportParameters } } require.NotNil(t, clientReceivedTransportParameters) require.Equal(t, 1337*time.Second, clientReceivedTransportParameters.MaxIdleTimeout) var serverReceivedTransportParameters *wire.TransportParameters for _, ev := range serverEvents { if ev.Kind == EventReceivedTransportParameters { serverReceivedTransportParameters = 
ev.TransportParameters } } require.NotNil(t, serverReceivedTransportParameters) require.Equal(t, 42*time.Second, serverReceivedTransportParameters.MaxIdleTimeout) } func TestNewSessionTicketAtWrongEncryptionLevel(t *testing.T) { clientConf, serverConf := getTLSConfigs() client, _, clientErr, _, _, serverErr := handshakeWithTLSConf( t, clientConf, serverConf, utils.NewRTTStats(), utils.NewRTTStats(), &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, false, ) require.NoError(t, clientErr) require.NoError(t, serverErr) // inject an invalid session ticket b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) err := client.HandleMessage(b, protocol.EncryptionHandshake) require.Error(t, err) require.Contains(t, err.Error(), "tls: handshake data received at wrong level") } func TestHandlingNewSessionTicketFails(t *testing.T) { clientConf, serverConf := getTLSConfigs() client, _, clientErr, _, _, serverErr := handshakeWithTLSConf( t, clientConf, serverConf, utils.NewRTTStats(), utils.NewRTTStats(), &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, false, ) require.NoError(t, clientErr) require.NoError(t, serverErr) // inject an invalid session ticket b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) 
err := client.HandleMessage(b, protocol.Encryption1RTT) require.IsType(t, &qerr.TransportError{}, err) require.True(t, err.(*qerr.TransportError).ErrorCode.IsCryptoError()) } func TestSessionResumption(t *testing.T) { clientConf, serverConf := getTLSConfigs() csc := newMockClientSessionCache() clientConf.ClientSessionCache = csc client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( t, clientConf, serverConf, utils.NewRTTStats(), utils.NewRTTStats(), &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, false, ) require.NoError(t, clientErr) require.NoError(t, serverErr) select { case <-csc.puts: case <-time.After(time.Second): t.Fatal("didn't receive a session ticket") } require.False(t, server.ConnectionState().DidResume) require.False(t, client.ConnectionState().DidResume) clientRTTStats := utils.NewRTTStats() serverRTTStats := utils.NewRTTStats() client, _, clientErr, server, _, serverErr = handshakeWithTLSConf( t, clientConf, serverConf, clientRTTStats, serverRTTStats, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, false, ) require.NoError(t, clientErr) require.NoError(t, serverErr) select { case <-csc.puts: case <-time.After(time.Second): t.Fatal("didn't receive a session ticket") } require.True(t, server.ConnectionState().DidResume) require.True(t, client.ConnectionState().DidResume) } func TestSessionResumptionDisabled(t *testing.T) { clientConf, serverConf := getTLSConfigs() csc := newMockClientSessionCache() clientConf.ClientSessionCache = csc client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( t, clientConf, serverConf, utils.NewRTTStats(), utils.NewRTTStats(), &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, false, ) require.NoError(t, clientErr) require.NoError(t, serverErr) select { case <-csc.puts: case <-time.After(time.Second): t.Fatal("didn't 
receive a session ticket") } require.False(t, server.ConnectionState().DidResume) require.False(t, client.ConnectionState().DidResume) serverConf.SessionTicketsDisabled = true client, _, clientErr, server, _, serverErr = handshakeWithTLSConf( t, clientConf, serverConf, utils.NewRTTStats(), utils.NewRTTStats(), &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, false, ) require.NoError(t, clientErr) require.NoError(t, serverErr) select { case <-csc.puts: t.Fatal("didn't expect to receive a session ticket") case <-time.After(25 * time.Millisecond): } require.False(t, server.ConnectionState().DidResume) require.False(t, client.ConnectionState().DidResume) } func Test0RTT(t *testing.T) { clientConf, serverConf := getTLSConfigs() csc := newMockClientSessionCache() clientConf.ClientSessionCache = csc const initialMaxData protocol.ByteCount = 1337 client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( t, clientConf, serverConf, utils.NewRTTStats(), utils.NewRTTStats(), &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData}, true, ) require.NoError(t, clientErr) require.NoError(t, serverErr) select { case <-csc.puts: case <-time.After(time.Second): t.Fatal("didn't receive a session ticket") } require.False(t, server.ConnectionState().DidResume) require.False(t, client.ConnectionState().DidResume) client, clientEvents, clientErr, server, serverEvents, serverErr := handshakeWithTLSConf( t, clientConf, serverConf, utils.NewRTTStats(), utils.NewRTTStats(), &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData}, true, ) require.NoError(t, clientErr) require.NoError(t, serverErr) var tp *wire.TransportParameters var clientReceived0RTTKeys bool for _, ev := range clientEvents { switch ev.Kind { case EventRestoredTransportParameters: tp = 
ev.TransportParameters case EventReceivedReadKeys: clientReceived0RTTKeys = true } } require.True(t, clientReceived0RTTKeys) require.NotNil(t, tp) require.Equal(t, initialMaxData, tp.InitialMaxData) var serverReceived0RTTKeys bool for _, ev := range serverEvents { switch ev.Kind { case EventReceivedReadKeys: serverReceived0RTTKeys = true } } require.True(t, serverReceived0RTTKeys) require.True(t, server.ConnectionState().DidResume) require.True(t, client.ConnectionState().DidResume) require.True(t, server.ConnectionState().Used0RTT) require.True(t, client.ConnectionState().Used0RTT) } func Test0RTTRejectionOnTransportParametersChanged(t *testing.T) { clientConf, serverConf := getTLSConfigs() csc := newMockClientSessionCache() clientConf.ClientSessionCache = csc const initialMaxData protocol.ByteCount = 1337 client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( t, clientConf, serverConf, utils.NewRTTStats(), utils.NewRTTStats(), &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData}, true, ) require.NoError(t, clientErr) require.NoError(t, serverErr) select { case <-csc.puts: case <-time.After(time.Second): t.Fatal("didn't receive a session ticket") } require.False(t, server.ConnectionState().DidResume) require.False(t, client.ConnectionState().DidResume) clientRTTStats := utils.NewRTTStats() client, clientEvents, clientErr, server, _, serverErr := handshakeWithTLSConf( t, clientConf, serverConf, clientRTTStats, utils.NewRTTStats(), &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData - 1}, true, ) require.NoError(t, clientErr) require.NoError(t, serverErr) var tp *wire.TransportParameters var clientReceived0RTTKeys bool for _, ev := range clientEvents { switch ev.Kind { case EventRestoredTransportParameters: tp = ev.TransportParameters case EventReceivedReadKeys: 
clientReceived0RTTKeys = true } } require.True(t, clientReceived0RTTKeys) require.NotNil(t, tp) require.Equal(t, initialMaxData, tp.InitialMaxData) require.True(t, server.ConnectionState().DidResume) require.True(t, client.ConnectionState().DidResume) require.False(t, server.ConnectionState().Used0RTT) require.False(t, client.ConnectionState().Used0RTT) } quic-go-0.59.0/internal/handshake/fake_conn.go000066400000000000000000000012251513066070600211710ustar00rootroot00000000000000package handshake import ( "net" "time" ) type conn struct { localAddr, remoteAddr net.Addr } var _ net.Conn = &conn{} func (c *conn) Read([]byte) (int, error) { return 0, nil } func (c *conn) Write([]byte) (int, error) { return 0, nil } func (c *conn) Close() error { return nil } func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } func (c *conn) LocalAddr() net.Addr { return c.localAddr } func (c *conn) SetReadDeadline(time.Time) error { return nil } func (c *conn) SetWriteDeadline(time.Time) error { return nil } func (c *conn) SetDeadline(time.Time) error { return nil } quic-go-0.59.0/internal/handshake/handshake_helpers_test.go000066400000000000000000000015231513066070600237560ustar00rootroot00000000000000package handshake import ( "crypto/tls" "encoding/hex" "strings" "testing" "github.com/stretchr/testify/require" ) func splitHexString(t *testing.T, s string) (slice []byte) { t.Helper() for _, ss := range strings.Split(s, " ") { if ss[0:2] == "0x" { ss = ss[2:] } d, err := hex.DecodeString(ss) require.NoError(t, err) slice = append(slice, d...) 
} return } func TestSplitHexString(t *testing.T) { require.Equal(t, []byte{0xde, 0xad, 0xbe, 0xef}, splitHexString(t, "0xdeadbeef")) require.Equal(t, []byte{0xde, 0xad, 0xbe, 0xef}, splitHexString(t, "deadbeef")) require.Equal(t, []byte{0xde, 0xad, 0xbe, 0xef}, splitHexString(t, "dead beef")) } var cipherSuites = []cipherSuite{ getCipherSuite(tls.TLS_AES_128_GCM_SHA256), getCipherSuite(tls.TLS_AES_256_GCM_SHA384), getCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256), } quic-go-0.59.0/internal/handshake/header_protector.go000066400000000000000000000070741513066070600226070ustar00rootroot00000000000000package handshake import ( "crypto/aes" "crypto/cipher" "crypto/tls" "encoding/binary" "fmt" "golang.org/x/crypto/chacha20" "github.com/quic-go/quic-go/internal/protocol" ) type headerProtector interface { EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) } func hkdfHeaderProtectionLabel(v protocol.Version) string { if v == protocol.Version2 { return "quicv2 hp" } return "quic hp" } func newHeaderProtector(suite cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.Version) headerProtector { hkdfLabel := hkdfHeaderProtectionLabel(v) switch suite.ID { case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: return newAESHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel) case tls.TLS_CHACHA20_POLY1305_SHA256: return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel) default: panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID)) } } type aesHeaderProtector struct { mask [16]byte // AES always has a 16 byte block size block cipher.Block isLongHeader bool } var _ headerProtector = &aesHeaderProtector{} func newAESHeaderProtector(suite cipherSuite, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector { hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen) block, err := aes.NewCipher(hpKey) if err 
!= nil { panic(fmt.Sprintf("error creating new AES cipher: %s", err)) } return &aesHeaderProtector{ block: block, isLongHeader: isLongHeader, } } func (p *aesHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { p.apply(sample, firstByte, hdrBytes) } func (p *aesHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { p.apply(sample, firstByte, hdrBytes) } func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) { if len(sample) != len(p.mask) { panic("invalid sample size") } p.block.Encrypt(p.mask[:], sample) if p.isLongHeader { *firstByte ^= p.mask[0] & 0xf } else { *firstByte ^= p.mask[0] & 0x1f } for i := range hdrBytes { hdrBytes[i] ^= p.mask[i+1] } } type chachaHeaderProtector struct { mask [5]byte key [32]byte isLongHeader bool } var _ headerProtector = &chachaHeaderProtector{} func newChaChaHeaderProtector(suite cipherSuite, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector { hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen) p := &chachaHeaderProtector{ isLongHeader: isLongHeader, } copy(p.key[:], hpKey) return p } func (p *chachaHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { p.apply(sample, firstByte, hdrBytes) } func (p *chachaHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { p.apply(sample, firstByte, hdrBytes) } func (p *chachaHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) { if len(sample) != 16 { panic("invalid sample size") } for i := 0; i < 5; i++ { p.mask[i] = 0 } cipher, err := chacha20.NewUnauthenticatedCipher(p.key[:], sample[4:]) if err != nil { panic(err) } cipher.SetCounter(binary.LittleEndian.Uint32(sample[:4])) cipher.XORKeyStream(p.mask[:], p.mask[:]) p.applyMask(firstByte, hdrBytes) } func (p *chachaHeaderProtector) applyMask(firstByte *byte, hdrBytes []byte) { if p.isLongHeader { *firstByte ^= p.mask[0] & 0xf 
} else { *firstByte ^= p.mask[0] & 0x1f } for i := range hdrBytes { hdrBytes[i] ^= p.mask[i+1] } } quic-go-0.59.0/internal/handshake/hkdf.go000066400000000000000000000013641513066070600201660ustar00rootroot00000000000000package handshake import ( "crypto" "encoding/binary" "golang.org/x/crypto/hkdf" ) // hkdfExpandLabel HKDF expands a label as defined in RFC 8446, section 7.1. func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte { b := make([]byte, 3, 3+6+len(label)+1+len(context)) binary.BigEndian.PutUint16(b, uint16(length)) b[2] = uint8(6 + len(label)) b = append(b, []byte("tls13 ")...) b = append(b, []byte(label)...) b = b[:3+6+len(label)+1] b[3+6+len(label)] = uint8(len(context)) b = append(b, context...) out := make([]byte, length) n, err := hkdf.Expand(hash.New, secret, b).Read(out) if err != nil || n != length { panic("quic: HKDF-Expand-Label invocation failed unexpectedly") } return out } quic-go-0.59.0/internal/handshake/hkdf_test.go000066400000000000000000000040111513066070600212150ustar00rootroot00000000000000package handshake import ( "crypto" "crypto/cipher" "crypto/rand" "crypto/tls" "testing" "unsafe" "github.com/stretchr/testify/require" ) var tls13CipherSuites = []uint16{tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384, tls.TLS_CHACHA20_POLY1305_SHA256} type cipherSuiteTLS13 struct { ID uint16 KeyLen int AEAD func(key, fixedNonce []byte) cipher.AEAD Hash crypto.Hash } //go:linkname cipherSuitesTLS13 crypto/tls.cipherSuitesTLS13 var cipherSuitesTLS13 []unsafe.Pointer func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 { for _, v := range cipherSuitesTLS13 { cs := (*cipherSuiteTLS13)(v) if cs.ID == id { return cs } } return nil } //go:linkname nextTrafficSecret crypto/tls.(*cipherSuiteTLS13).nextTrafficSecret func nextTrafficSecret(cs *cipherSuiteTLS13, trafficSecret []byte) []byte func TestHKDF(t *testing.T) { for _, id := range tls13CipherSuites { t.Run(tls.CipherSuiteName(id), func(t *testing.T) 
{ cs := cipherSuiteTLS13ByID(id) expected := nextTrafficSecret(cs, []byte("foobar")) expanded := hkdfExpandLabel(cs.Hash, []byte("foobar"), nil, "traffic upd", cs.Hash.Size()) require.Equal(t, expected, expanded) }) } } // As of Go 1.24, the standard library and our implementation of hkdfExpandLabel should provide the same performance. func BenchmarkHKDFExpandLabelStandardLibrary(b *testing.B) { for _, id := range tls13CipherSuites { b.Run(tls.CipherSuiteName(id), func(b *testing.B) { benchmarkHKDFExpandLabel(b, id, true) }) } } func BenchmarkHKDFExpandLabelOurs(b *testing.B) { for _, id := range tls13CipherSuites { b.Run(tls.CipherSuiteName(id), func(b *testing.B) { benchmarkHKDFExpandLabel(b, id, false) }) } } func benchmarkHKDFExpandLabel(b *testing.B, cipherSuite uint16, useStdLib bool) { b.ReportAllocs() cs := cipherSuiteTLS13ByID(cipherSuite) secret := make([]byte, 32) rand.Read(secret) for b.Loop() { if useStdLib { nextTrafficSecret(cs, secret) } else { hkdfExpandLabel(cs.Hash, secret, nil, "traffic upd", cs.Hash.Size()) } } } quic-go-0.59.0/internal/handshake/initial_aead.go000066400000000000000000000045441513066070600216600ustar00rootroot00000000000000package handshake import ( "crypto" "crypto/tls" "golang.org/x/crypto/hkdf" "github.com/quic-go/quic-go/internal/protocol" ) var ( quicSaltV1 = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a} quicSaltV2 = []byte{0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb, 0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26, 0x9d, 0xcb, 0xf9, 0xbd, 0x2e, 0xd9} ) const ( hkdfLabelKeyV1 = "quic key" hkdfLabelKeyV2 = "quicv2 key" hkdfLabelIVV1 = "quic iv" hkdfLabelIVV2 = "quicv2 iv" ) func getSalt(v protocol.Version) []byte { if v == protocol.Version2 { return quicSaltV2 } return quicSaltV1 } var initialSuite = getCipherSuite(tls.TLS_AES_128_GCM_SHA256) // NewInitialAEAD creates a new AEAD for Initial encryption / decryption. 
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.Version) (LongHeaderSealer, LongHeaderOpener) { clientSecret, serverSecret := computeSecrets(connID, v) var mySecret, otherSecret []byte if pers == protocol.PerspectiveClient { mySecret = clientSecret otherSecret = serverSecret } else { mySecret = serverSecret otherSecret = clientSecret } myKey, myIV := computeInitialKeyAndIV(mySecret, v) otherKey, otherIV := computeInitialKeyAndIV(otherSecret, v) encrypter := initialSuite.AEAD(myKey, myIV) decrypter := initialSuite.AEAD(otherKey, otherIV) return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true, v)), newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v))) } func computeSecrets(connID protocol.ConnectionID, v protocol.Version) (clientSecret, serverSecret []byte) { initialSecret := hkdf.Extract(crypto.SHA256.New, connID.Bytes(), getSalt(v)) clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size()) serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size()) return } func computeInitialKeyAndIV(secret []byte, v protocol.Version) (key, iv []byte) { keyLabel := hkdfLabelKeyV1 ivLabel := hkdfLabelIVV1 if v == protocol.Version2 { keyLabel = hkdfLabelKeyV2 ivLabel = hkdfLabelIVV2 } key = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, keyLabel, 16) iv = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, ivLabel, 12) return } quic-go-0.59.0/internal/handshake/initial_aead_test.go000066400000000000000000000435351513066070600227220ustar00rootroot00000000000000package handshake import ( "bytes" "crypto/rand" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestComputeClientKeyAndIV(t *testing.T) { connID := protocol.ParseConnectionID(splitHexString(t, "0x8394c8f03e515708")) tests := []struct { name string 
version protocol.Version expectedClientSecret []byte expectedKey []byte expectedIV []byte }{ { name: "QUIC v1", version: protocol.Version1, expectedClientSecret: splitHexString(t, "c00cf151ca5be075ed0ebfb5c80323c4 2d6b7db67881289af4008f1f6c357aea"), expectedKey: splitHexString(t, "1f369613dd76d5467730efcbe3b1a22d"), expectedIV: splitHexString(t, "fa044b2f42a3fd3b46fb255c"), }, { name: "QUIC v2", version: protocol.Version2, expectedClientSecret: splitHexString(t, "14ec9d6eb9fd7af83bf5a668bc17a7e2 83766aade7ecd0891f70f9ff7f4bf47b"), expectedKey: splitHexString(t, "8b1a0bc121284290a29e0971b5cd045d"), expectedIV: splitHexString(t, "91f73e2351d8fa91660e909f"), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { clientSecret, _ := computeSecrets(connID, tt.version) require.Equal(t, tt.expectedClientSecret, clientSecret) key, iv := computeInitialKeyAndIV(clientSecret, tt.version) require.Equal(t, tt.expectedKey, key) require.Equal(t, tt.expectedIV, iv) }) } } func TestComputeServerKeyAndIV(t *testing.T) { connID := protocol.ParseConnectionID(splitHexString(t, "0x8394c8f03e515708")) tests := []struct { name string version protocol.Version expectedServerSecret []byte expectedKey []byte expectedIV []byte }{ { name: "QUIC v1", version: protocol.Version1, expectedServerSecret: splitHexString(t, "3c199828fd139efd216c155ad844cc81 fb82fa8d7446fa7d78be803acdda951b"), expectedKey: splitHexString(t, "cf3a5331653c364c88f0f379b6067e37"), expectedIV: splitHexString(t, "0ac1493ca1905853b0bba03e"), }, { name: "QUIC v2", version: protocol.Version2, expectedServerSecret: splitHexString(t, "0263db1782731bf4588e7e4d93b74639 07cb8cd8200b5da55a8bd488eafc37c1"), expectedKey: splitHexString(t, "82db637861d55e1d011f19ea71d5d2a7"), expectedIV: splitHexString(t, "dd13c276499c0249d3310652"), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { _, serverSecret := computeSecrets(connID, tt.version) require.Equal(t, tt.expectedServerSecret, serverSecret) key, iv := 
computeInitialKeyAndIV(serverSecret, tt.version) require.Equal(t, tt.expectedKey, key) require.Equal(t, tt.expectedIV, iv) }) } } func TestClientInitial(t *testing.T) { connID := protocol.ParseConnectionID(splitHexString(t, "0x8394c8f03e515708")) tests := []struct { version protocol.Version header []byte data []byte expectedSample []byte expectedHdrFirstByte byte expectedHdr []byte expectedPacket []byte }{ { version: protocol.Version1, header: splitHexString(t, "c300000001088394c8f03e5157080000449e00000002"), data: splitHexString(t, "060040f1010000ed0303ebf8fa56f129 39b9584a3896472ec40bb863cfd3e868 04fe3a47f06a2b69484c000004130113 02010000c000000010000e00000b6578 616d706c652e636f6dff01000100000a 00080006001d00170018001000070005 04616c706e0005000501000000000033 00260024001d00209370b2c9caa47fba baf4559fedba753de171fa71f50f1ce1 5d43e994ec74d748002b000302030400 0d0010000e0403050306030203080408 050806002d00020101001c0002400100 3900320408ffffffffffffffff050480 00ffff07048000ffff08011001048000 75300901100f088394c8f03e51570806 048000ffff"), expectedSample: splitHexString(t, "d1b1c98dd7689fb8ec11d242b123dc9b"), expectedHdrFirstByte: byte(0xc0), expectedHdr: splitHexString(t, "7b9aec34"), expectedPacket: splitHexString(t, "c000000001088394c8f03e5157080000 449e7b9aec34d1b1c98dd7689fb8ec11 d242b123dc9bd8bab936b47d92ec356c 0bab7df5976d27cd449f63300099f399 1c260ec4c60d17b31f8429157bb35a12 82a643a8d2262cad67500cadb8e7378c 8eb7539ec4d4905fed1bee1fc8aafba1 7c750e2c7ace01e6005f80fcb7df6212 30c83711b39343fa028cea7f7fb5ff89 eac2308249a02252155e2347b63d58c5 457afd84d05dfffdb20392844ae81215 4682e9cf012f9021a6f0be17ddd0c208 4dce25ff9b06cde535d0f920a2db1bf3 62c23e596d11a4f5a6cf3948838a3aec 4e15daf8500a6ef69ec4e3feb6b1d98e 610ac8b7ec3faf6ad760b7bad1db4ba3 485e8a94dc250ae3fdb41ed15fb6a8e5 eba0fc3dd60bc8e30c5c4287e53805db 059ae0648db2f64264ed5e39be2e20d8 2df566da8dd5998ccabdae053060ae6c 7b4378e846d29f37ed7b4ea9ec5d82e7 961b7f25a9323851f681d582363aa5f8 9937f5a67258bf63ad6f1a0b1d96dbd4 
faddfcefc5266ba6611722395c906556 be52afe3f565636ad1b17d508b73d874 3eeb524be22b3dcbc2c7468d54119c74 68449a13d8e3b95811a198f3491de3e7 fe942b330407abf82a4ed7c1b311663a c69890f4157015853d91e923037c227a 33cdd5ec281ca3f79c44546b9d90ca00 f064c99e3dd97911d39fe9c5d0b23a22 9a234cb36186c4819e8b9c5927726632 291d6a418211cc2962e20fe47feb3edf 330f2c603a9d48c0fcb5699dbfe58964 25c5bac4aee82e57a85aaf4e2513e4f0 5796b07ba2ee47d80506f8d2c25e50fd 14de71e6c418559302f939b0e1abd576 f279c4b2e0feb85c1f28ff18f58891ff ef132eef2fa09346aee33c28eb130ff2 8f5b766953334113211996d20011a198 e3fc433f9f2541010ae17c1bf202580f 6047472fb36857fe843b19f5984009dd c324044e847a4f4a0ab34f719595de37 252d6235365e9b84392b061085349d73 203a4a13e96f5432ec0fd4a1ee65accd d5e3904df54c1da510b0ff20dcc0c77f cb2c0e0eb605cb0504db87632cf3d8b4 dae6e705769d1de354270123cb11450e fc60ac47683d7b8d0f811365565fd98c 4c8eb936bcab8d069fc33bd801b03ade a2e1fbc5aa463d08ca19896d2bf59a07 1b851e6c239052172f296bfb5e724047 90a2181014f3b94a4e97d117b4381303 68cc39dbb2d198065ae3986547926cd2 162f40a29f0c3c8745c0f50fba3852e5 66d44575c29d39a03f0cda721984b6f4 40591f355e12d439ff150aab7613499d bd49adabc8676eef023b15b65bfc5ca0 6948109f23f350db82123535eb8a7433 bdabcb909271a6ecbcb58b936a88cd4e 8f2e6ff5800175f113253d8fa9ca8885 c2f552e657dc603f252e1a8e308f76f0 be79e2fb8f5d5fbbe2e30ecadd220723 c8c0aea8078cdfcb3868263ff8f09400 54da48781893a7e49ad5aff4af300cd8 04a6b6279ab3ff3afb64491c85194aab 760d58a606654f9f4400e8b38591356f bf6425aca26dc85244259ff2b19c41b9 f96f3ca9ec1dde434da7d2d392b905dd f3d1f9af93d1af5950bd493f5aa731b4 056df31bd267b6b90a079831aaf579be 0a39013137aac6d404f518cfd4684064 7e78bfe706ca4cf5e9c5453e9f7cfd2b 8b4c8d169a44e55c88d4a9a7f9474241 e221af44860018ab0856972e194cd934"), }, { version: protocol.Version2, header: splitHexString(t, "d36b3343cf088394c8f03e5157080000449e00000002"), data: splitHexString(t, "060040f1010000ed0303ebf8fa56f129 39b9584a3896472ec40bb863cfd3e868 04fe3a47f06a2b69484c000004130113 02010000c000000010000e00000b6578 
616d706c652e636f6dff01000100000a 00080006001d00170018001000070005 04616c706e0005000501000000000033 00260024001d00209370b2c9caa47fba baf4559fedba753de171fa71f50f1ce1 5d43e994ec74d748002b000302030400 0d0010000e0403050306030203080408 050806002d00020101001c0002400100 3900320408ffffffffffffffff050480 00ffff07048000ffff08011001048000 75300901100f088394c8f03e51570806 048000ffff"), expectedSample: splitHexString(t, "ffe67b6abcdb4298b485dd04de806071"), expectedHdrFirstByte: byte(0xd7), expectedHdr: splitHexString(t, "a0c95e82"), expectedPacket: splitHexString(t, "d76b3343cf088394c8f03e5157080000 449ea0c95e82ffe67b6abcdb4298b485 dd04de806071bf03dceebfa162e75d6c 96058bdbfb127cdfcbf903388e99ad04 9f9a3dd4425ae4d0992cfff18ecf0fdb 5a842d09747052f17ac2053d21f57c5d 250f2c4f0e0202b70785b7946e992e58 a59ac52dea6774d4f03b55545243cf1a 12834e3f249a78d395e0d18f4d766004 f1a2674802a747eaa901c3f10cda5500 cb9122faa9f1df66c392079a1b40f0de 1c6054196a11cbea40afb6ef5253cd68 18f6625efce3b6def6ba7e4b37a40f77 32e093daa7d52190935b8da58976ff33 12ae50b187c1433c0f028edcc4c2838b 6a9bfc226ca4b4530e7a4ccee1bfa2a3 d396ae5a3fb512384b2fdd851f784a65 e03f2c4fbe11a53c7777c023462239dd 6f7521a3f6c7d5dd3ec9b3f233773d4b 46d23cc375eb198c63301c21801f6520 bcfb7966fc49b393f0061d974a2706df 8c4a9449f11d7f3d2dcbb90c6b877045 636e7c0c0fe4eb0f697545460c806910 d2c355f1d253bc9d2452aaa549e27a1f ac7cf4ed77f322e8fa894b6a83810a34 b361901751a6f5eb65a0326e07de7c12 16ccce2d0193f958bb3850a833f7ae43 2b65bc5a53975c155aa4bcb4f7b2c4e5 4df16efaf6ddea94e2c50b4cd1dfe060 17e0e9d02900cffe1935e0491d77ffb4 fdf85290fdd893d577b1131a610ef6a5 c32b2ee0293617a37cbb08b847741c3b 8017c25ca9052ca1079d8b78aebd4787 6d330a30f6a8c6d61dd1ab5589329de7 14d19d61370f8149748c72f132f0fc99 f34d766c6938597040d8f9e2bb522ff9 9c63a344d6a2ae8aa8e51b7b90a4a806 105fcbca31506c446151adfeceb51b91 abfe43960977c87471cf9ad4074d30e1 0d6a7f03c63bd5d4317f68ff325ba3bd 80bf4dc8b52a0ba031758022eb025cdd 770b44d6d6cf0670f4e990b22347a7db 848265e3e5eb72dfe8299ad7481a4083 
22cac55786e52f633b2fb6b614eaed18 d703dd84045a274ae8bfa73379661388 d6991fe39b0d93debb41700b41f90a15 c4d526250235ddcd6776fc77bc97e7a4 17ebcb31600d01e57f32162a8560cacc 7e27a096d37a1a86952ec71bd89a3e9a 30a2a26162984d7740f81193e8238e61 f6b5b984d4d3dfa033c1bb7e4f0037fe bf406d91c0dccf32acf423cfa1e70710 10d3f270121b493ce85054ef58bada42 310138fe081adb04e2bd901f2f13458b 3d6758158197107c14ebb193230cd115 7380aa79cae1374a7c1e5bbcb80ee23e 06ebfde206bfb0fcbc0edc4ebec30966 1bdd908d532eb0c6adc38b7ca7331dce 8dfce39ab71e7c32d318d136b6100671 a1ae6a6600e3899f31f0eed19e3417d1 34b90c9058f8632c798d4490da498730 7cba922d61c39805d072b589bd52fdf1 e86215c2d54e6670e07383a27bbffb5a ddf47d66aa85a0c6f9f32e59d85a44dd 5d3b22dc2be80919b490437ae4f36a0a e55edf1d0b5cb4e9a3ecabee93dfc6e3 8d209d0fa6536d27a5d6fbb17641cde2 7525d61093f1b28072d111b2b4ae5f89 d5974ee12e5cf7d5da4d6a31123041f3 3e61407e76cffcdcfd7e19ba58cf4b53 6f4c4938ae79324dc402894b44faf8af bab35282ab659d13c93f70412e85cb19 9a37ddec600545473cfb5a05e08d0b20 9973b2172b4d21fb69745a262ccde96b a18b2faa745b6fe189cf772a9f84cbfc"), }, } for _, tt := range tests { t.Run(tt.version.String(), func(t *testing.T) { sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveClient, tt.version) tt.data = append(tt.data, make([]byte, 1162-len(tt.data))...) // add PADDING sealed := sealer.Seal(nil, tt.data, 2, tt.header) sample := sealed[0:16] require.Equal(t, tt.expectedSample, sample) sealer.EncryptHeader(sample, &tt.header[0], tt.header[len(tt.header)-4:]) require.Equal(t, tt.expectedHdrFirstByte, tt.header[0]) require.Equal(t, tt.expectedHdr, tt.header[len(tt.header)-4:]) packet := append(tt.header, sealed...) 
require.Equal(t, tt.expectedPacket, packet) }) } } func TestServersInitial(t *testing.T) { connID := protocol.ParseConnectionID(splitHexString(t, "0x8394c8f03e515708")) testCases := []struct { name string version protocol.Version header []byte data []byte expectedSample []byte expectedHdr []byte expectedPacket []byte }{ { name: "QUIC v1", version: protocol.Version1, header: splitHexString(t, "c1000000010008f067a5502a4262b50040750001"), data: splitHexString(t, "02000000000600405a020000560303ee fce7f7b37ba1d1632e96677825ddf739 88cfc79825df566dc5430b9a045a1200 130100002e00330024001d00209d3c94 0d89690b84d08a60993c144eca684d10 81287c834d5311bcf32bb9da1a002b00 020304"), expectedSample: splitHexString(t, "2cd0991cd25b0aac406a5816b6394100"), expectedHdr: splitHexString(t, "cf000000010008f067a5502a4262b5004075c0d9"), expectedPacket: splitHexString(t, "cf000000010008f067a5502a4262b500 4075c0d95a482cd0991cd25b0aac406a 5816b6394100f37a1c69797554780bb3 8cc5a99f5ede4cf73c3ec2493a1839b3 dbcba3f6ea46c5b7684df3548e7ddeb9 c3bf9c73cc3f3bded74b562bfb19fb84 022f8ef4cdd93795d77d06edbb7aaf2f 58891850abbdca3d20398c276456cbc4 2158407dd074ee"), }, { name: "QUIC v2", version: protocol.Version2, header: splitHexString(t, "d16b3343cf0008f067a5502a4262b50040750001"), data: splitHexString(t, "02000000000600405a020000560303ee fce7f7b37ba1d1632e96677825ddf739 88cfc79825df566dc5430b9a045a1200 130100002e00330024001d00209d3c94 0d89690b84d08a60993c144eca684d10 81287c834d5311bcf32bb9da1a002b00 020304"), expectedSample: splitHexString(t, "6f05d8a4398c47089698baeea26b91eb"), expectedHdr: splitHexString(t, "dc6b3343cf0008f067a5502a4262b5004075d92f"), expectedPacket: splitHexString(t, "dc6b3343cf0008f067a5502a4262b500 4075d92faaf16f05d8a4398c47089698 baeea26b91eb761d9b89237bbf872630 17915358230035f7fd3945d88965cf17 f9af6e16886c61bfc703106fbaf3cb4c fa52382dd16a393e42757507698075b2 c984c707f0a0812d8cd5a6881eaf21ce da98f4bd23f6fe1a3e2c43edd9ce7ca8 4bed8521e2e140"), }, } for _, tc := range testCases { 
t.Run(tc.name, func(t *testing.T) { sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveServer, tc.version) sealed := sealer.Seal(nil, tc.data, 1, tc.header) sample := sealed[2 : 2+16] require.Equal(t, tc.expectedSample, sample) sealer.EncryptHeader(sample, &tc.header[0], tc.header[len(tc.header)-2:]) require.Equal(t, tc.expectedHdr, tc.header) packet := append(tc.header, sealed...) require.Equal(t, tc.expectedPacket, packet) }) } } func TestInitialAEADSealsAndOpens(t *testing.T) { for _, ver := range []protocol.Version{protocol.Version1, protocol.Version2} { t.Run(ver.String(), func(t *testing.T) { connectionID := protocol.ParseConnectionID([]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef}) clientSealer, clientOpener := NewInitialAEAD(connectionID, protocol.PerspectiveClient, ver) serverSealer, serverOpener := NewInitialAEAD(connectionID, protocol.PerspectiveServer, ver) clientMessage := clientSealer.Seal(nil, []byte("foobar"), 42, []byte("aad")) m, err := serverOpener.Open(nil, clientMessage, 42, []byte("aad")) require.NoError(t, err) require.Equal(t, []byte("foobar"), m) serverMessage := serverSealer.Seal(nil, []byte("raboof"), 99, []byte("daa")) m, err = clientOpener.Open(nil, serverMessage, 99, []byte("daa")) require.NoError(t, err) require.Equal(t, []byte("raboof"), m) }) } } func TestInitialAEADFailsWithDifferentConnectionIDs(t *testing.T) { for _, ver := range []protocol.Version{protocol.Version1, protocol.Version2} { t.Run(ver.String(), func(t *testing.T) { c1 := protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0, 1}) c2 := protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0, 2}) clientSealer, _ := NewInitialAEAD(c1, protocol.PerspectiveClient, ver) _, serverOpener := NewInitialAEAD(c2, protocol.PerspectiveServer, ver) clientMessage := clientSealer.Seal(nil, []byte("foobar"), 42, []byte("aad")) _, err := serverOpener.Open(nil, clientMessage, 42, []byte("aad")) require.Equal(t, ErrDecryptionFailed, err) }) } } func 
TestInitialAEADEncryptsAndDecryptsHeader(t *testing.T) { for _, ver := range []protocol.Version{protocol.Version1, protocol.Version2} { t.Run(ver.String(), func(t *testing.T) { connID := protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}) clientSealer, clientOpener := NewInitialAEAD(connID, protocol.PerspectiveClient, ver) serverSealer, serverOpener := NewInitialAEAD(connID, protocol.PerspectiveServer, ver) header := []byte{0x5e, 0, 1, 2, 3, 4, 0xde, 0xad, 0xbe, 0xef} sample := make([]byte, 16) rand.Read(sample) clientSealer.EncryptHeader(sample, &header[0], header[6:10]) require.Equal(t, byte(0x5e&0xf0), header[0]&0xf0) require.Equal(t, []byte{0, 1, 2, 3, 4}, header[1:6]) require.NotEqual(t, []byte{0xde, 0xad, 0xbe, 0xef}, header[6:10]) serverOpener.DecryptHeader(sample, &header[0], header[6:10]) require.Equal(t, byte(0x5e), header[0]) require.Equal(t, []byte{0, 1, 2, 3, 4}, header[1:6]) require.Equal(t, []byte{0xde, 0xad, 0xbe, 0xef}, header[6:10]) serverSealer.EncryptHeader(sample, &header[0], header[6:10]) require.Equal(t, byte(0x5e&0xf0), header[0]&0xf0) require.Equal(t, []byte{0, 1, 2, 3, 4}, header[1:6]) require.NotEqual(t, []byte{0xde, 0xad, 0xbe, 0xef}, header[6:10]) clientOpener.DecryptHeader(sample, &header[0], header[6:10]) require.Equal(t, byte(0x5e), header[0]) require.Equal(t, []byte{0, 1, 2, 3, 4}, header[1:6]) require.Equal(t, []byte{0xde, 0xad, 0xbe, 0xef}, header[6:10]) }) } } func BenchmarkInitialAEADCreate(b *testing.B) { b.ReportAllocs() connID := protocol.ParseConnectionID([]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef}) for b.Loop() { NewInitialAEAD(connID, protocol.PerspectiveServer, protocol.Version1) } } func BenchmarkInitialAEAD(b *testing.B) { connectionID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd}) clientSealer, _ := NewInitialAEAD(connectionID, protocol.PerspectiveClient, protocol.Version1) _, serverOpener := NewInitialAEAD(connectionID, protocol.PerspectiveServer, 
protocol.Version1) packetData := make([]byte, 1200) rand.Read(packetData) hdr := make([]byte, 50) rand.Read(hdr) msg := clientSealer.Seal(nil, packetData, 42, hdr) m, err := serverOpener.Open(nil, msg, 42, hdr) if err != nil { b.Fatalf("opening failed: %s", err) } if !bytes.Equal(m, packetData) { b.Fatal("decrypted data doesn't match") } b.ResetTimer() b.Run("opening 100 bytes", func(b *testing.B) { benchmarkOpen(b, serverOpener, clientSealer.Seal(nil, packetData[:100], 42, hdr), hdr) }) b.Run("opening 1200 bytes", func(b *testing.B) { benchmarkOpen(b, serverOpener, msg, hdr) }) b.Run("sealing 100 bytes", func(b *testing.B) { benchmarkSeal(b, clientSealer, packetData[:100], hdr) }) b.Run("sealing 1200 bytes", func(b *testing.B) { benchmarkSeal(b, clientSealer, packetData, hdr) }) } func benchmarkOpen(b *testing.B, aead LongHeaderOpener, msg, hdr []byte) { b.ReportAllocs() dst := make([]byte, 0, 1500) for b.Loop() { dst = dst[:0] if _, err := aead.Open(dst, msg, 42, hdr); err != nil { b.Fatalf("opening failed: %s", err) } } } func benchmarkSeal(b *testing.B, aead LongHeaderSealer, msg, hdr []byte) { b.ReportAllocs() dst := make([]byte, 0, 1500) var pn protocol.PacketNumber for b.Loop() { dst = dst[:0] aead.Seal(dst, msg, pn, hdr) pn++ } } quic-go-0.59.0/internal/handshake/interface.go000066400000000000000000000112151513066070600212060ustar00rootroot00000000000000package handshake import ( "context" "crypto/tls" "errors" "io" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" ) var ( // ErrKeysNotYetAvailable is returned when an opener or a sealer is requested for an encryption level, // but the corresponding opener has not yet been initialized // This can happen when packets arrive out of order. 
ErrKeysNotYetAvailable = errors.New("CryptoSetup: keys at this encryption level not yet available") // ErrKeysDropped is returned when an opener or a sealer is requested for an encryption level, // but the corresponding keys have already been dropped. ErrKeysDropped = errors.New("CryptoSetup: keys were already dropped") // ErrDecryptionFailed is returned when the AEAD fails to open the packet. ErrDecryptionFailed = errors.New("decryption failed") ) type headerDecryptor interface { DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) } // LongHeaderOpener opens a long header packet type LongHeaderOpener interface { headerDecryptor DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error) } // ShortHeaderOpener opens a short header packet type ShortHeaderOpener interface { headerDecryptor DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber Open(dst, src []byte, rcvTime monotime.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error) } // LongHeaderSealer seals a long header packet type LongHeaderSealer interface { Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) Overhead() int } // ShortHeaderSealer seals a short header packet type ShortHeaderSealer interface { LongHeaderSealer KeyPhase() protocol.KeyPhaseBit } type ConnectionState struct { tls.ConnectionState Used0RTT bool } // EventKind is the kind of handshake event. 
type EventKind uint8

// The EventKind constants start at iota + 1, so the zero value EventKind(0)
// is not one of the defined kinds; String reports it as "Unknown EventKind".
const (
	// EventNoEvent signals that there are no new handshake events
	EventNoEvent EventKind = iota + 1
	// EventWriteInitialData contains new CRYPTO data to send at the Initial encryption level
	EventWriteInitialData
	// EventWriteHandshakeData contains new CRYPTO data to send at the Handshake encryption level
	EventWriteHandshakeData
	// EventReceivedReadKeys signals that new decryption keys are available.
	// It doesn't say which encryption level those keys are for.
	EventReceivedReadKeys
	// EventDiscard0RTTKeys signals that the 0-RTT keys were discarded.
	EventDiscard0RTTKeys
	// EventReceivedTransportParameters contains the transport parameters sent by the peer.
	EventReceivedTransportParameters
	// EventRestoredTransportParameters contains the transport parameters restored from the session ticket.
	// It is only used for the client.
	EventRestoredTransportParameters
	// EventHandshakeComplete signals that the TLS handshake was completed.
	EventHandshakeComplete
)

// String returns the name of the constant k is equal to, or "Unknown EventKind"
// for any value that doesn't match a defined constant (including the zero value).
func (k EventKind) String() string {
	switch k {
	case EventNoEvent:
		return "EventNoEvent"
	case EventWriteInitialData:
		return "EventWriteInitialData"
	case EventWriteHandshakeData:
		return "EventWriteHandshakeData"
	case EventReceivedReadKeys:
		return "EventReceivedReadKeys"
	case EventDiscard0RTTKeys:
		return "EventDiscard0RTTKeys"
	case EventReceivedTransportParameters:
		return "EventReceivedTransportParameters"
	case EventRestoredTransportParameters:
		return "EventRestoredTransportParameters"
	case EventHandshakeComplete:
		return "EventHandshakeComplete"
	default:
		return "Unknown EventKind"
	}
}

// Event is a handshake event.
type Event struct {
	// Kind says what happened. The EventKind constants document which payload
	// accompanies each kind.
	Kind EventKind
	// Data holds byte payloads. Per the EventKind docs, this is the new CRYPTO
	// data of EventWriteInitialData and EventWriteHandshakeData.
	Data []byte
	// TransportParameters holds the peer's transport parameters for
	// EventReceivedTransportParameters, or the parameters restored from the
	// session ticket for EventRestoredTransportParameters.
	TransportParameters *wire.TransportParameters
}

// CryptoSetup handles the handshake and protecting / unprotecting packets
type CryptoSetup interface {
	StartHandshake(context.Context) error
	io.Closer
	ChangeConnectionID(protocol.ConnectionID)
	GetSessionTicket() ([]byte, error)

	// HandleMessage feeds a handshake message received at the given encryption
	// level. NextEvent returns the next pending Event; a Kind of EventNoEvent
	// means nothing is pending.
	HandleMessage([]byte, protocol.EncryptionLevel) error
	NextEvent() Event

	SetLargest1RTTAcked(protocol.PacketNumber) error
	DiscardInitialKeys()
	SetHandshakeConfirmed()
	ConnectionState() ConnectionState

	// Openers per encryption level. Per the docs of ErrKeysNotYetAvailable and
	// ErrKeysDropped, requesting one whose keys don't exist yet, or were already
	// dropped, is expected to fail with one of those errors.
	GetInitialOpener() (LongHeaderOpener, error)
	GetHandshakeOpener() (LongHeaderOpener, error)
	Get0RTTOpener() (LongHeaderOpener, error)
	Get1RTTOpener() (ShortHeaderOpener, error)

	// Sealers per encryption level; the same error conventions as for the
	// openers apply.
	GetInitialSealer() (LongHeaderSealer, error)
	GetHandshakeSealer() (LongHeaderSealer, error)
	Get0RTTSealer() (LongHeaderSealer, error)
	Get1RTTSealer() (ShortHeaderSealer, error)
}
quic-go-0.59.0/internal/handshake/retry.go000066400000000000000000000035521513066070600204200ustar00rootroot00000000000000package handshake

import (
	"bytes"
	"crypto/aes"
	"crypto/cipher"
	"fmt"
	"sync"

	"github.com/quic-go/quic-go/internal/protocol"
)

// Instead of using an init function, the AEADs are created lazily.
// For more details see https://github.com/quic-go/quic-go/issues/4894.
var ( retryAEADv1 cipher.AEAD // used for QUIC v1 (RFC 9000) retryAEADv2 cipher.AEAD // used for QUIC v2 (RFC 9369) ) func initAEAD(key [16]byte) cipher.AEAD { aes, err := aes.NewCipher(key[:]) if err != nil { panic(err) } aead, err := cipher.NewGCM(aes) if err != nil { panic(err) } return aead } var ( retryBuf bytes.Buffer retryMutex sync.Mutex retryNonceV1 = [12]byte{0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb} retryNonceV2 = [12]byte{0xd8, 0x69, 0x69, 0xbc, 0x2d, 0x7c, 0x6d, 0x99, 0x90, 0xef, 0xb0, 0x4a} ) // GetRetryIntegrityTag calculates the integrity tag on a Retry packet func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.Version) *[16]byte { retryMutex.Lock() defer retryMutex.Unlock() retryBuf.WriteByte(uint8(origDestConnID.Len())) retryBuf.Write(origDestConnID.Bytes()) retryBuf.Write(retry) defer retryBuf.Reset() var tag [16]byte var sealed []byte if version == protocol.Version2 { if retryAEADv2 == nil { retryAEADv2 = initAEAD([16]byte{0x8f, 0xb4, 0xb0, 0x1b, 0x56, 0xac, 0x48, 0xe2, 0x60, 0xfb, 0xcb, 0xce, 0xad, 0x7c, 0xcc, 0x92}) } sealed = retryAEADv2.Seal(tag[:0], retryNonceV2[:], nil, retryBuf.Bytes()) } else { if retryAEADv1 == nil { retryAEADv1 = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e}) } sealed = retryAEADv1.Seal(tag[:0], retryNonceV1[:], nil, retryBuf.Bytes()) } if len(sealed) != 16 { panic(fmt.Sprintf("unexpected Retry integrity tag length: %d", len(sealed))) } return &tag } quic-go-0.59.0/internal/handshake/retry_test.go000066400000000000000000000033221513066070600214520ustar00rootroot00000000000000package handshake import ( "encoding/binary" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestRetryIntegrityTagCalculation(t *testing.T) { connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) fooTag := GetRetryIntegrityTag([]byte("foo"), connID, 
protocol.Version1) barTag := GetRetryIntegrityTag([]byte("bar"), connID, protocol.Version1) require.NotNil(t, fooTag) require.NotNil(t, barTag) require.NotEqual(t, *fooTag, *barTag) } func TestRetryIntegrityTagWithDifferentConnectionIDs(t *testing.T) { connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) t1 := GetRetryIntegrityTag([]byte("foobar"), connID1, protocol.Version1) t2 := GetRetryIntegrityTag([]byte("foobar"), connID2, protocol.Version1) require.NotEqual(t, *t1, *t2) } func TestRetryIntegrityTagWithTestVectors(t *testing.T) { tests := []struct { name string version protocol.Version data []byte }{ { name: "v1", version: protocol.Version1, data: splitHexString(t, "ff000000010008f067a5502a4262b574 6f6b656e04a265ba2eff4d829058fb3f 0f2496ba"), }, { name: "v2", version: protocol.Version2, data: splitHexString(t, "cf6b3343cf0008f067a5502a4262b574 6f6b656ec8646ce8bfe33952d9555436 65dcc7b6"), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { v := binary.BigEndian.Uint32(tt.data[1:5]) require.Equal(t, tt.version, protocol.Version(v)) connID := protocol.ParseConnectionID(splitHexString(t, "0x8394c8f03e515708")) tag := GetRetryIntegrityTag(tt.data[:len(tt.data)-16], connID, tt.version) require.Equal(t, tt.data[len(tt.data)-16:], tag[:]) }) } } quic-go-0.59.0/internal/handshake/session_ticket.go000066400000000000000000000024431513066070600222770ustar00rootroot00000000000000package handshake import ( "bytes" "errors" "fmt" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/quicvarint" ) const sessionTicketRevision = 5 type sessionTicket struct { Parameters *wire.TransportParameters } func (t *sessionTicket) Marshal() []byte { b := make([]byte, 0, 256) b = quicvarint.Append(b, sessionTicketRevision) return t.Parameters.MarshalForSessionTicket(b) } func (t *sessionTicket) Unmarshal(b []byte) error { rev, l, err := quicvarint.Parse(b) if err != nil { return 
errors.New("failed to read session ticket revision") } b = b[l:] if rev != sessionTicketRevision { return fmt.Errorf("unknown session ticket revision: %d", rev) } var tp wire.TransportParameters if err := tp.UnmarshalFromSessionTicket(b); err != nil { return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error()) } t.Parameters = &tp return nil } const extraPrefix = "quic-go1" func addSessionStateExtraPrefix(b []byte) []byte { return append([]byte(extraPrefix), b...) } func findSessionStateExtraData(extras [][]byte) []byte { prefix := []byte(extraPrefix) for _, extra := range extras { if len(extra) < len(prefix) || !bytes.Equal(prefix, extra[:len(prefix)]) { continue } return extra[len(prefix):] } return nil } quic-go-0.59.0/internal/handshake/session_ticket_test.go000066400000000000000000000027451513066070600233430ustar00rootroot00000000000000package handshake import ( "testing" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/quicvarint" "github.com/stretchr/testify/require" ) func TestMarshalUnmarshalSessionTicket(t *testing.T) { ticket := &sessionTicket{ Parameters: &wire.TransportParameters{ InitialMaxStreamDataBidiLocal: 1, InitialMaxStreamDataBidiRemote: 2, ActiveConnectionIDLimit: 10, MaxDatagramFrameSize: 20, }, } var t2 sessionTicket require.NoError(t, t2.Unmarshal(ticket.Marshal())) require.EqualValues(t, 1, t2.Parameters.InitialMaxStreamDataBidiLocal) require.EqualValues(t, 2, t2.Parameters.InitialMaxStreamDataBidiRemote) require.EqualValues(t, 10, t2.Parameters.ActiveConnectionIDLimit) require.EqualValues(t, 20, t2.Parameters.MaxDatagramFrameSize) } func TestUnmarshalRefusesTooShortTicket(t *testing.T) { err := (&sessionTicket{}).Unmarshal([]byte{}) require.EqualError(t, err, "failed to read session ticket revision") } func TestUnmarshalRefusesUnknownRevision(t *testing.T) { b := quicvarint.Append(nil, 1337) err := (&sessionTicket{}).Unmarshal(b) require.EqualError(t, err, "unknown session 
ticket revision: 1337") } func TestUnmarshal0RTTRefusesInvalidTransportParameters(t *testing.T) { b := quicvarint.Append(nil, sessionTicketRevision) b = append(b, []byte("foobar")...) err := (&sessionTicket{}).Unmarshal(b) require.Error(t, err) require.Contains(t, err.Error(), "unmarshaling transport parameters from session ticket failed") } quic-go-0.59.0/internal/handshake/tls_config.go000066400000000000000000000024401513066070600213750ustar00rootroot00000000000000package handshake import ( "crypto/tls" "net" ) func setupConfigForServer(conf *tls.Config, localAddr, remoteAddr net.Addr) *tls.Config { // Workaround for https://github.com/golang/go/issues/60506. // This initializes the session tickets _before_ cloning the config. _, _ = conf.DecryptTicket(nil, tls.ConnectionState{}) conf = conf.Clone() conf.MinVersion = tls.VersionTLS13 // The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo. // Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn // that allows the caller to get the local and the remote address. 
if conf.GetConfigForClient != nil { gcfc := conf.GetConfigForClient conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr} c, err := gcfc(info) if c != nil { // we're returning a tls.Config here, so we need to apply this recursively c = setupConfigForServer(c, localAddr, remoteAddr) } return c, err } } if conf.GetCertificate != nil { gc := conf.GetCertificate conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr} return gc(info) } } return conf } quic-go-0.59.0/internal/handshake/tls_config_test.go000066400000000000000000000064401513066070600224400ustar00rootroot00000000000000package handshake import ( "crypto/tls" "net" "reflect" "testing" "github.com/stretchr/testify/require" ) func TestMinimumTLSVersion(t *testing.T) { local := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42} remote := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} orig := &tls.Config{MinVersion: tls.VersionTLS12} conf := setupConfigForServer(orig, local, remote) require.EqualValues(t, tls.VersionTLS13, conf.MinVersion) // check that the original config wasn't modified require.EqualValues(t, tls.VersionTLS12, orig.MinVersion) } func TestServerConfigGetCertificate(t *testing.T) { local := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42} remote := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} var localAddr, remoteAddr net.Addr tlsConf := &tls.Config{ GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { localAddr = info.Conn.LocalAddr() remoteAddr = info.Conn.RemoteAddr() return &tls.Certificate{}, nil }, } conf := setupConfigForServer(tlsConf, local, remote) _, err := conf.GetCertificate(&tls.ClientHelloInfo{}) require.NoError(t, err) require.Equal(t, local, localAddr) require.Equal(t, remote, remoteAddr) } func TestServerConfigGetConfigForClient(t *testing.T) { local := 
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42} remote := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} var localAddr, remoteAddr net.Addr tlsConf := setupConfigForServer( &tls.Config{ GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { localAddr = info.Conn.LocalAddr() remoteAddr = info.Conn.RemoteAddr() return &tls.Config{}, nil }, }, local, remote, ) conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) require.NoError(t, err) require.Equal(t, local, localAddr) require.Equal(t, remote, remoteAddr) require.NotNil(t, conf) require.EqualValues(t, tls.VersionTLS13, conf.MinVersion) } func TestServerConfigGetConfigForClientRecursively(t *testing.T) { local := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42} remote := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} var localAddr, remoteAddr net.Addr tlsConf := &tls.Config{} var innerConf *tls.Config getCert := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { localAddr = info.Conn.LocalAddr() remoteAddr = info.Conn.RemoteAddr() return &tls.Certificate{}, nil } tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { innerConf = tlsConf.Clone() // set the MaxVersion, so we can check that quic-go doesn't overwrite the user's config innerConf.MaxVersion = tls.VersionTLS12 innerConf.GetCertificate = getCert return innerConf, nil } tlsConf = setupConfigForServer(tlsConf, local, remote) conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) require.NoError(t, err) require.NotNil(t, conf) require.EqualValues(t, tls.VersionTLS13, conf.MinVersion) _, err = conf.GetCertificate(&tls.ClientHelloInfo{}) require.NoError(t, err) require.Equal(t, local, localAddr) require.Equal(t, remote, remoteAddr) // make sure that the tls.Config returned by GetConfigForClient isn't modified require.True(t, reflect.ValueOf(innerConf.GetCertificate).Pointer() == reflect.ValueOf(getCert).Pointer()) require.EqualValues(t, tls.VersionTLS12, 
innerConf.MaxVersion) } quic-go-0.59.0/internal/handshake/token_generator.go000066400000000000000000000070301513066070600224340ustar00rootroot00000000000000package handshake import ( "bytes" "encoding/asn1" "fmt" "net" "time" "github.com/quic-go/quic-go/internal/protocol" ) const ( tokenPrefixIP byte = iota tokenPrefixString ) // A Token is derived from the client address and can be used to verify the ownership of this address. type Token struct { IsRetryToken bool SentTime time.Time encodedRemoteAddr []byte // only set for tokens sent in NEW_TOKEN frames RTT time.Duration // only set for retry tokens OriginalDestConnectionID protocol.ConnectionID RetrySrcConnectionID protocol.ConnectionID } // ValidateRemoteAddr validates the address, but does not check expiration func (t *Token) ValidateRemoteAddr(addr net.Addr) bool { return bytes.Equal(encodeRemoteAddr(addr), t.encodedRemoteAddr) } // token is the struct that is used for ASN1 serialization and deserialization type token struct { IsRetryToken bool RemoteAddr []byte Timestamp int64 RTT int64 // in mus OriginalDestConnectionID []byte RetrySrcConnectionID []byte } // A TokenGenerator generates tokens type TokenGenerator struct { tokenProtector tokenProtector } // NewTokenGenerator initializes a new TokenGenerator func NewTokenGenerator(key TokenProtectorKey) *TokenGenerator { return &TokenGenerator{tokenProtector: *newTokenProtector(key)} } // NewRetryToken generates a new token for a Retry for a given source address func (g *TokenGenerator) NewRetryToken( raddr net.Addr, origDestConnID protocol.ConnectionID, retrySrcConnID protocol.ConnectionID, ) ([]byte, error) { data, err := asn1.Marshal(token{ IsRetryToken: true, RemoteAddr: encodeRemoteAddr(raddr), OriginalDestConnectionID: origDestConnID.Bytes(), RetrySrcConnectionID: retrySrcConnID.Bytes(), Timestamp: time.Now().UnixNano(), }) if err != nil { return nil, err } return g.tokenProtector.NewToken(data) } // NewToken generates a new token to be sent in a 
// NEW_TOKEN frame.
// The token records the encoded remote address, the current time, and rtt
// (stored in microseconds).
func (g *TokenGenerator) NewToken(raddr net.Addr, rtt time.Duration) ([]byte, error) {
	data, err := asn1.Marshal(token{
		RemoteAddr: encodeRemoteAddr(raddr),
		Timestamp:  time.Now().UnixNano(),
		RTT:        rtt.Microseconds(),
	})
	if err != nil {
		return nil, err
	}
	return g.tokenProtector.NewToken(data)
}

// DecodeToken decodes a token.
// An empty input yields (nil, nil). Decryption failures, ASN.1 errors, and
// trailing bytes after the ASN.1 structure are reported as errors. Expiration is
// not checked here. Connection IDs are only restored for Retry tokens, and RTT
// only for NEW_TOKEN tokens.
func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) {
	// if the client didn't send any token, DecodeToken will be called with a nil-slice
	if len(encrypted) == 0 {
		return nil, nil
	}

	data, err := g.tokenProtector.DecodeToken(encrypted)
	if err != nil {
		return nil, err
	}
	t := &token{}
	rest, err := asn1.Unmarshal(data, t)
	if err != nil {
		return nil, err
	}
	if len(rest) != 0 {
		return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
	}
	// NOTE: from here on, the local variable `token` shadows the `token` type.
	token := &Token{
		IsRetryToken:      t.IsRetryToken,
		SentTime:          time.Unix(0, t.Timestamp),
		encodedRemoteAddr: t.RemoteAddr,
	}
	if t.IsRetryToken {
		token.OriginalDestConnectionID = protocol.ParseConnectionID(t.OriginalDestConnectionID)
		token.RetrySrcConnectionID = protocol.ParseConnectionID(t.RetrySrcConnectionID)
	} else {
		token.RTT = time.Duration(t.RTT) * time.Microsecond
	}
	return token, nil
}

// encodeRemoteAddr encodes a remote address such that it can be saved in the token.
// For *net.UDPAddr, only the IP is encoded and the port is ignored. For any other
// net.Addr, String() is used, which includes the port.
// NOTE(review): the IP bytes are used as stored (4- or 16-byte form), so the same
// IPv4 address in different in-memory representations encodes differently.
// Confirm that callers always pass consistently formed addresses.
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
	if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
		return append([]byte{tokenPrefixIP}, udpAddr.IP...)
	}
	return append([]byte{tokenPrefixString}, []byte(remoteAddr.String())...)
}
quic-go-0.59.0/internal/handshake/token_generator_test.go000066400000000000000000000105461513066070600235010ustar00rootroot00000000000000package handshake

import (
	"crypto/rand"
	"encoding/asn1"
	"net"
	"testing"
	"time"

	"github.com/quic-go/quic-go/internal/protocol"

	"github.com/stretchr/testify/require"
)

// newTokenGenerator returns a TokenGenerator with a random key.
func newTokenGenerator(t *testing.T) *TokenGenerator {
	var key TokenProtectorKey
	_, err := rand.Read(key[:])
	require.NoError(t, err)
	return NewTokenGenerator(key)
}

func TestTokenGeneratorNilTokens(t *testing.T) {
	tokenGen := newTokenGenerator(t)
	nilToken, err := tokenGen.DecodeToken(nil)
	require.NoError(t, err)
	require.Nil(t, nilToken)
}

func TestTokenGeneratorValidToken(t *testing.T) {
	tokenGen := newTokenGenerator(t)
	addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
	connID1 := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})
	connID2 := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde})
	tokenEnc, err := tokenGen.NewRetryToken(addr, connID1, connID2)
	require.NoError(t, err)
	decodedToken, err := tokenGen.DecodeToken(tokenEnc)
	require.NoError(t, err)
	require.True(t, decodedToken.ValidateRemoteAddr(addr))
	require.False(t, decodedToken.ValidateRemoteAddr(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 2), Port: 1337}))
	require.WithinDuration(t, time.Now(), decodedToken.SentTime, 100*time.Millisecond)
	require.Equal(t, connID1, decodedToken.OriginalDestConnectionID)
	require.Equal(t, connID2, decodedToken.RetrySrcConnectionID)
}

func TestTokenGeneratorRejectsInvalidTokens(t *testing.T) {
	tokenGen := newTokenGenerator(t)
	_, err := tokenGen.DecodeToken([]byte("invalid token"))
	require.Error(t, err)
	require.Contains(t, err.Error(), "too short")
}

// The payload decrypts fine but is not valid ASN.1.
func TestTokenGeneratorDecodingFailed(t *testing.T) {
	tokenGen := newTokenGenerator(t)
	invalidToken, err := tokenGen.tokenProtector.NewToken([]byte("foobar"))
	require.NoError(t, err)
	_, err = tokenGen.DecodeToken(invalidToken)
	require.Error(t, err)
	require.Contains(t, err.Error(), "asn1")
}

// Trailing bytes after a valid ASN.1 structure must be rejected.
func TestTokenGeneratorAdditionalPayload(t *testing.T) {
	tokenGen := newTokenGenerator(t)
	tok, err := asn1.Marshal(token{RemoteAddr: []byte("foobar")})
	require.NoError(t, err)
	tok = append(tok, []byte("rest")...)
	enc, err := tokenGen.tokenProtector.NewToken(tok)
	require.NoError(t, err)
	_, err = tokenGen.DecodeToken(enc)
	require.EqualError(t, err, "rest when unpacking token: 4")
}

func TestTokenGeneratorEmptyTokens(t *testing.T) {
	tokenGen := newTokenGenerator(t)
	emptyTok, err := asn1.Marshal(token{RemoteAddr: []byte("")})
	require.NoError(t, err)
	emptyEnc, err := tokenGen.tokenProtector.NewToken(emptyTok)
	require.NoError(t, err)
	_, err = tokenGen.DecodeToken(emptyEnc)
	require.NoError(t, err)
}

func TestTokenGeneratorIPv6(t *testing.T) {
	tokenGen := newTokenGenerator(t)
	addresses := []string{
		"2001:db8::68",
		"2001:0000:4136:e378:8000:63bf:3fff:fdd2",
		"2001::1",
		"ff01:0:0:0:0:0:0:2",
	}
	for _, addr := range addresses {
		ip := net.ParseIP(addr)
		require.NotNil(t, ip)
		raddr := &net.UDPAddr{IP: ip, Port: 1337}
		tokenEnc, err := tokenGen.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
		require.NoError(t, err)
		token, err := tokenGen.DecodeToken(tokenEnc)
		require.NoError(t, err)
		require.True(t, token.ValidateRemoteAddr(raddr))
		require.WithinDuration(t, time.Now(), token.SentTime, 100*time.Millisecond)
	}
}

// For non-UDP addresses the String() form is encoded, so the port matters.
func TestTokenGeneratorNonUDPAddr(t *testing.T) {
	tokenGen := newTokenGenerator(t)
	raddr := &net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
	tokenEnc, err := tokenGen.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
	require.NoError(t, err)
	token, err := tokenGen.DecodeToken(tokenEnc)
	require.NoError(t, err)
	require.True(t, token.ValidateRemoteAddr(raddr))
	require.False(t, token.ValidateRemoteAddr(&net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1338}))
	require.WithinDuration(t, time.Now(), token.SentTime, 100*time.Millisecond)
}

func BenchmarkTokenGeneratorDecodeToken(b *testing.B) {
	b.ReportAllocs()

	var key TokenProtectorKey
	_, err := rand.Read(key[:])
	require.NoError(b, err)
	tokenGen := NewTokenGenerator(key)

	addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
	connID1 := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})
	connID2 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
	tokenEnc, err := tokenGen.NewRetryToken(addr, connID1, connID2)
	require.NoError(b, err)

	for b.Loop() {
		if _, err := tokenGen.DecodeToken(tokenEnc); err != nil {
			b.Fatal(err)
		}
	}
}
quic-go-0.59.0/internal/handshake/token_protector.go000066400000000000000000000035341513066070600224740ustar00rootroot00000000000000package handshake

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"crypto/sha256"
	"fmt"
	"io"

	"golang.org/x/crypto/hkdf"
)

// TokenProtectorKey is the key used to encrypt both Retry and session resumption tokens.
type TokenProtectorKey [32]byte

// tokenNonceSize is the length of the random per-token nonce. NewToken prepends
// the nonce to every token, and createAEAD uses it as the HKDF salt.
const tokenNonceSize = 32

// tokenProtector is used to create and verify a token
type tokenProtector struct {
	key TokenProtectorKey
}

// newTokenProtector creates a source for source address tokens
func newTokenProtector(key TokenProtectorKey) *tokenProtector {
	return &tokenProtector{key: key}
}

// NewToken encodes data into a new token.
// The output is the random nonce followed by the AEAD-sealed data. The AEAD
// key and nonce are derived from that random nonce in createAEAD.
func (s *tokenProtector) NewToken(data []byte) ([]byte, error) {
	var nonce [tokenNonceSize]byte
	if _, err := rand.Read(nonce[:]); err != nil {
		return nil, err
	}
	aead, aeadNonce, err := s.createAEAD(nonce[:])
	if err != nil {
		return nil, err
	}
	return append(nonce[:], aead.Seal(nil, aeadNonce, data, nil)...), nil
}

// DecodeToken decodes a token.
// It reverses NewToken: it splits off the leading nonce, re-derives the AEAD,
// and opens the remainder. Inputs shorter than tokenNonceSize are rejected, and
// authentication failures are reported by aead.Open.
func (s *tokenProtector) DecodeToken(p []byte) ([]byte, error) {
	if len(p) < tokenNonceSize {
		return nil, fmt.Errorf("token too short: %d", len(p))
	}
	nonce := p[:tokenNonceSize]
	aead, aeadNonce, err := s.createAEAD(nonce)
	if err != nil {
		return nil, err
	}
	return aead.Open(nil, aeadNonce, p[tokenNonceSize:], nil)
}

// createAEAD derives a per-token AES-256-GCM AEAD and its 12-byte nonce.
// Both come from HKDF-SHA256 over the protector key, with the token nonce as
// salt and "quic-go token source" as info. The key bytes are read first, then
// the nonce bytes.
func (s *tokenProtector) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
	h := hkdf.New(sha256.New, s.key[:], nonce, []byte("quic-go token source"))
	key := make([]byte, 32) // use a 32 byte key, in order to select AES-256
	if _, err := io.ReadFull(h, key); err != nil {
		return nil, nil, err
	}
	aeadNonce := make([]byte, 12)
	if _, err := io.ReadFull(h, aeadNonce); err != nil {
		return nil, nil, err
	}
	c, err := aes.NewCipher(key)
	if err != nil {
		return nil, nil, err
	}
	aead, err := cipher.NewGCM(c)
	if err != nil {
		return nil, nil, err
	}
	return aead, aeadNonce, nil
}
quic-go-0.59.0/internal/handshake/token_protector_test.go000066400000000000000000000030701513066070600235260ustar00rootroot00000000000000package handshake

import (
	"crypto/rand"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestTokenProtectorEncodeAndDecode(t *testing.T) {
	var key TokenProtectorKey
	rand.Read(key[:])
	tp := newTokenProtector(key)
	token, err := tp.NewToken([]byte("foobar"))
	require.NoError(t, err)
	require.NotContains(t, string(token), "foobar")
	decoded, err := tp.DecodeToken(token)
	require.NoError(t, err)
	require.Equal(t, []byte("foobar"), decoded)
}

// Tokens only decode with a protector that uses the same key.
func TestTokenProtectorDifferentKeys(t *testing.T) {
	var key1, key2 TokenProtectorKey
	rand.Read(key1[:])
	rand.Read(key2[:])
	tp1 := newTokenProtector(key1)
	tp2 := newTokenProtector(key2)
	t1, err := tp1.NewToken([]byte("foo"))
	require.NoError(t, err)
	t2, err := tp2.NewToken([]byte("foo"))
	require.NoError(t, err)

	_, err = tp1.DecodeToken(t1)
	require.NoError(t, err)
	_, err = tp1.DecodeToken(t2)
	require.Error(t, err)

	tp3 := newTokenProtector(key1)
	_, err = tp3.DecodeToken(t1)
	require.NoError(t, err)
	_, err = tp3.DecodeToken(t2)
	require.Error(t, err)
}

// Dropping the first byte shifts the nonce, so authentication must fail.
func TestTokenProtectorInvalidTokens(t *testing.T) {
	var key TokenProtectorKey
	rand.Read(key[:])
	tp := newTokenProtector(key)
	token, err := tp.NewToken([]byte("foobar"))
	require.NoError(t, err)
	_, err = tp.DecodeToken(token[1:])
	require.Error(t, err)
	require.Contains(t, err.Error(), "message authentication failed")
}

func TestTokenProtectorTooShortTokens(t *testing.T) {
	var key TokenProtectorKey
	rand.Read(key[:])
	tp := newTokenProtector(key)
	_, err := tp.DecodeToken([]byte("foobar"))
	require.EqualError(t, err, "token too short: 6")
}
quic-go-0.59.0/internal/handshake/updatable_aead.go000066400000000000000000000306721513066070600221710ustar00rootroot00000000000000package handshake

import (
	"crypto"
	"crypto/cipher"
	"crypto/tls"
	"encoding/binary"
	"fmt"
	"sync/atomic"

	"github.com/quic-go/quic-go/internal/monotime"
	"github.com/quic-go/quic-go/internal/protocol"
	"github.com/quic-go/quic-go/internal/qerr"
	"github.com/quic-go/quic-go/internal/utils"
	"github.com/quic-go/quic-go/qlog"
	"github.com/quic-go/quic-go/qlogwriter"
)

// keyUpdateInterval holds the configured key update interval. It is stored
// atomically, initialized to protocol.KeyUpdateInterval in init, and can be
// overridden with SetKeyUpdateInterval.
var keyUpdateInterval atomic.Uint64

func init() {
	keyUpdateInterval.Store(protocol.KeyUpdateInterval)
}

// SetKeyUpdateInterval replaces the key update interval with v. It returns a
// function that restores the previous value.
func SetKeyUpdateInterval(v uint64) (reset func()) {
	old := keyUpdateInterval.Swap(v)
	return func() { keyUpdateInterval.Store(old) }
}

// FirstKeyUpdateInterval is the maximum number of packets we send or receive before initiating the first key update.
// It's a package-level variable to allow modifying it for testing purposes.
var FirstKeyUpdateInterval uint64 = 100

// updatableAEAD is the 1-RTT opener and sealer. It tracks the current key
// phase, plus the previous receive key and the pre-computed next-phase keys.
type updatableAEAD struct {
	suite cipherSuite

	keyPhase           protocol.KeyPhase
	largestAcked       protocol.PacketNumber
	firstPacketNumber  protocol.PacketNumber
	handshakeConfirmed bool

	invalidPacketLimit uint64
	invalidPacketCount uint64

	// Time when the keys should be dropped. Keys are dropped on the next call to Open().
	prevRcvAEADExpiry monotime.Time
	prevRcvAEAD       cipher.AEAD

	firstRcvdWithCurrentKey protocol.PacketNumber
	firstSentWithCurrentKey protocol.PacketNumber
	highestRcvdPN           protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
	numRcvdWithCurrentKey   uint64
	numSentWithCurrentKey   uint64
	rcvAEAD                 cipher.AEAD
	sendAEAD                cipher.AEAD
	// caches cipher.AEAD.Overhead(). This speeds up calls to Overhead().
	aeadOverhead int

	// keys and traffic secrets for the next key phase, derived ahead of time
	nextRcvAEAD           cipher.AEAD
	nextSendAEAD          cipher.AEAD
	nextRcvTrafficSecret  []byte
	nextSendTrafficSecret []byte

	headerDecrypter headerProtector
	headerEncrypter headerProtector

	rttStats *utils.RTTStats

	qlogger qlogwriter.Recorder
	logger  utils.Logger
	version protocol.Version

	// use a single slice to avoid allocations
	nonceBuf []byte
}

// compile-time checks that updatableAEAD implements both 1-RTT interfaces
var (
	_ ShortHeaderOpener = &updatableAEAD{}
	_ ShortHeaderSealer = &updatableAEAD{}
)

// newUpdatableAEAD returns an updatableAEAD without keys; they are installed
// via SetReadKey and SetWriteKey. Packet-number trackers start as
// protocol.InvalidPacketNumber.
func newUpdatableAEAD(rttStats *utils.RTTStats, qlogger qlogwriter.Recorder, logger utils.Logger, version protocol.Version) *updatableAEAD {
	return &updatableAEAD{
		firstPacketNumber:       protocol.InvalidPacketNumber,
		largestAcked:            protocol.InvalidPacketNumber,
		firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
		firstSentWithCurrentKey: protocol.InvalidPacketNumber,
		rttStats:                rttStats,
		qlogger:                 qlogger,
		logger:                  logger,
		version:                 version,
	}
}

// rollKeys advances to the next key phase.
// If a previous receive key is still held, it is discarded early (logged and
// recorded to qlog). The current receive key becomes the previous one, the
// pre-computed next keys become current, and a fresh next generation is
// derived from the next traffic secrets.
func (a *updatableAEAD) rollKeys() {
	if a.prevRcvAEAD != nil {
		a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry)
		if a.qlogger != nil {
			a.qlogger.RecordEvent(qlog.KeyDiscarded{
				KeyType:  qlog.KeyTypeClient1RTT,
				KeyPhase: a.keyPhase - 1,
			})
			a.qlogger.RecordEvent(qlog.KeyDiscarded{
				KeyType:  qlog.KeyTypeServer1RTT,
				KeyPhase: a.keyPhase - 1,
			})
		}
		a.prevRcvAEADExpiry = 0
	}

	a.keyPhase++
	a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
	a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
	a.numRcvdWithCurrentKey = 0
	a.numSentWithCurrentKey = 0
	a.prevRcvAEAD = a.rcvAEAD
	a.rcvAEAD = a.nextRcvAEAD
	a.sendAEAD = a.nextSendAEAD

	a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret)
	a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret)
	a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret, a.version)
	a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret, a.version)
}

// startKeyDropTimer schedules the previous phase's receive key to be dropped
// 3 PTOs after now. The drop itself happens on a later call to Open.
func (a *updatableAEAD) startKeyDropTimer(now monotime.Time) {
	d := 3 * a.rttStats.PTO(true)
	a.logger.Debugf("Starting key drop timer to drop key phase %d (in %s)", a.keyPhase-1, d)
	a.prevRcvAEADExpiry = now.Add(d)
}

// getNextTrafficSecret derives the next-generation traffic secret from ts via
// HKDF-Expand-Label with the "quic ku" label.
func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte {
	return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size())
}

// SetReadKey sets the read key.
// For the client, this function is called before SetWriteKey.
// For the server, this function is called after SetWriteKey.
// It also pre-computes the receive key for the next key phase.
func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) {
	a.rcvAEAD = createAEAD(suite, trafficSecret, a.version)
	a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
	if a.suite.ID == 0 { // suite is not set yet
		a.setAEADParameters(a.rcvAEAD, suite)
	}

	a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
	a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version)
}

// SetWriteKey sets the write key.
// For the client, this function is called after SetReadKey.
// For the server, this function is called before SetReadKey. func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) { a.sendAEAD = createAEAD(suite, trafficSecret, a.version) a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version) if a.suite.ID == 0 { // suite is not set yet a.setAEADParameters(a.sendAEAD, suite) } a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret, a.version) } func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite cipherSuite) { a.nonceBuf = make([]byte, aead.NonceSize()) a.aeadOverhead = aead.Overhead() a.suite = suite switch suite.ID { case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: a.invalidPacketLimit = protocol.InvalidPacketLimitAES case tls.TLS_CHACHA20_POLY1305_SHA256: a.invalidPacketLimit = protocol.InvalidPacketLimitChaCha default: panic(fmt.Sprintf("unknown cipher suite %d", suite.ID)) } } func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber { return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN) } func (a *updatableAEAD) Open(dst, src []byte, rcvTime monotime.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { dec, err := a.open(dst, src, rcvTime, pn, kp, ad) if err == ErrDecryptionFailed { a.invalidPacketCount++ if a.invalidPacketCount >= a.invalidPacketLimit { return nil, &qerr.TransportError{ErrorCode: qerr.AEADLimitReached} } } if err == nil { a.highestRcvdPN = max(a.highestRcvdPN, pn) } return dec, err } func (a *updatableAEAD) open(dst, src []byte, rcvTime monotime.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) { a.prevRcvAEAD = nil a.logger.Debugf("Dropping key phase %d", a.keyPhase-1) a.prevRcvAEADExpiry = 0 if 
a.qlogger != nil { a.qlogger.RecordEvent(qlog.KeyDiscarded{ KeyType: qlog.KeyTypeClient1RTT, KeyPhase: a.keyPhase - 1, }) a.qlogger.RecordEvent(qlog.KeyDiscarded{ KeyType: qlog.KeyTypeServer1RTT, KeyPhase: a.keyPhase - 1, }) } } binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) if kp != a.keyPhase.Bit() { if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey { if a.prevRcvAEAD == nil { return nil, ErrKeysDropped } // we updated the key, but the peer hasn't updated yet dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad) if err != nil { err = ErrDecryptionFailed } return dec, err } // try opening the packet with the next key phase dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad) if err != nil { return nil, ErrDecryptionFailed } // Opening succeeded. Check if the peer was allowed to update. if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { return nil, &qerr.TransportError{ ErrorCode: qerr.KeyUpdateError, ErrorMessage: "keys updated too quickly", } } a.rollKeys() a.logger.Debugf("Peer updated keys to %d", a.keyPhase) // The peer initiated this key update. It's safe to drop the keys for the previous generation now. // Start a timer to drop the previous key generation. a.startKeyDropTimer(rcvTime) if a.qlogger != nil { a.qlogger.RecordEvent(qlog.KeyUpdated{ Trigger: qlog.KeyUpdateRemote, KeyType: qlog.KeyTypeClient1RTT, KeyPhase: a.keyPhase, }) a.qlogger.RecordEvent(qlog.KeyUpdated{ Trigger: qlog.KeyUpdateRemote, KeyType: qlog.KeyTypeServer1RTT, KeyPhase: a.keyPhase, }) } a.firstRcvdWithCurrentKey = pn return dec, err } // The AEAD we're using here will be the qtls.aeadAESGCM13. // It uses the nonce provided here and XOR it with the IV. 
dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad) if err != nil { return dec, ErrDecryptionFailed } a.numRcvdWithCurrentKey++ if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber { // We initiated the key updated, and now we received the first packet protected with the new key phase. // Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys. if a.keyPhase > 0 { a.logger.Debugf("Peer confirmed key update to phase %d", a.keyPhase) a.startKeyDropTimer(rcvTime) } a.firstRcvdWithCurrentKey = pn } return dec, err } func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { a.firstSentWithCurrentKey = pn } if a.firstPacketNumber == protocol.InvalidPacketNumber { a.firstPacketNumber = pn } a.numSentWithCurrentKey++ binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) // The AEAD we're using here will be the qtls.aeadAESGCM13. // It uses the nonce provided here and XOR it with the IV. 
return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad) } func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error { if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 { return &qerr.TransportError{ ErrorCode: qerr.KeyUpdateError, ErrorMessage: fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase), } } a.largestAcked = pn return nil } func (a *updatableAEAD) SetHandshakeConfirmed() { a.handshakeConfirmed = true } func (a *updatableAEAD) updateAllowed() bool { if !a.handshakeConfirmed { return false } // the first key update is allowed as soon as the handshake is confirmed return a.keyPhase == 0 || // subsequent key updates as soon as a packet sent with that key phase has been acknowledged (a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && a.largestAcked != protocol.InvalidPacketNumber && a.largestAcked >= a.firstSentWithCurrentKey) } func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { if !a.updateAllowed() { return false } // Initiate the first key update shortly after the handshake, in order to exercise the key update mechanism. if a.keyPhase == 0 { if a.numRcvdWithCurrentKey >= FirstKeyUpdateInterval || a.numSentWithCurrentKey >= FirstKeyUpdateInterval { return true } } if a.numRcvdWithCurrentKey >= keyUpdateInterval.Load() { a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1) return true } if a.numSentWithCurrentKey >= keyUpdateInterval.Load() { a.logger.Debugf("Sent %d packets with current key phase. 
Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1) return true } return false } func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit { if a.shouldInitiateKeyUpdate() { a.rollKeys() if a.qlogger != nil { a.qlogger.RecordEvent(qlog.KeyUpdated{ Trigger: qlog.KeyUpdateLocal, KeyType: qlog.KeyTypeClient1RTT, KeyPhase: a.keyPhase, }) a.qlogger.RecordEvent(qlog.KeyUpdated{ Trigger: qlog.KeyUpdateLocal, KeyType: qlog.KeyTypeServer1RTT, KeyPhase: a.keyPhase, }) } } return a.keyPhase.Bit() } func (a *updatableAEAD) Overhead() int { return a.aeadOverhead } func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes) } func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes) } func (a *updatableAEAD) FirstPacketNumber() protocol.PacketNumber { return a.firstPacketNumber } quic-go-0.59.0/internal/handshake/updatable_aead_test.go000066400000000000000000000635311513066070600232300ustar00rootroot00000000000000package handshake import ( "crypto/rand" "crypto/tls" "fmt" mrand "math/rand/v2" "testing" "time" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/quic-go/testutils/events" "github.com/stretchr/testify/require" ) const ( msg = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." ad = "Donec in velit neque." 
) func randomCipherSuite() cipherSuite { return cipherSuites[mrand.IntN(len(cipherSuites))] } func setupEndpoints(t *testing.T, serverRTTStats *utils.RTTStats) (client, server *updatableAEAD, serverEventRecorder *events.Recorder) { cs := randomCipherSuite() var eventRecorder events.Recorder trafficSecret1 := make([]byte, 16) trafficSecret2 := make([]byte, 16) rand.Read(trafficSecret1) rand.Read(trafficSecret2) client = newUpdatableAEAD(utils.NewRTTStats(), nil, utils.DefaultLogger, protocol.Version1) server = newUpdatableAEAD(serverRTTStats, &eventRecorder, utils.DefaultLogger, protocol.Version1) client.SetReadKey(cs, trafficSecret2) client.SetWriteKey(cs, trafficSecret1) server.SetReadKey(cs, trafficSecret1) server.SetWriteKey(cs, trafficSecret2) return client, server, &eventRecorder } func bothSides(ev qlogwriter.Event) []qlogwriter.Event { switch ev := ev.(type) { case qlog.KeyDiscarded: return []qlogwriter.Event{ qlog.KeyDiscarded{ KeyType: qlog.KeyTypeClient1RTT, KeyPhase: ev.KeyPhase, }, qlog.KeyDiscarded{ KeyType: qlog.KeyTypeServer1RTT, KeyPhase: ev.KeyPhase, }, } case qlog.KeyUpdated: return []qlogwriter.Event{ qlog.KeyUpdated{ KeyType: qlog.KeyTypeClient1RTT, KeyPhase: ev.KeyPhase, Trigger: ev.Trigger, }, qlog.KeyUpdated{ KeyType: qlog.KeyTypeServer1RTT, KeyPhase: ev.KeyPhase, Trigger: ev.Trigger, }, } default: panic("unexpected event type: " + ev.Name()) } } func TestChaChaTestVector(t *testing.T) { testCases := []struct { name string version protocol.Version expectedPayload []byte expectedPacket []byte }{ { version: protocol.Version1, expectedPayload: splitHexString(t, "655e5cd55c41f69080575d7999c25a5bfb"), expectedPacket: splitHexString(t, "4cfe4189655e5cd55c41f69080575d7999c25a5bfb"), }, { version: protocol.Version2, expectedPayload: splitHexString(t, "0ae7b6b932bc27d786f4bc2bb20f2162ba"), expectedPacket: splitHexString(t, "5558b1c60ae7b6b932bc27d786f4bc2bb20f2162ba"), }, } for _, tc := range testCases { t.Run(fmt.Sprintf("QUIC %s", tc.version), 
func(t *testing.T) { secret := splitHexString(t, "9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b") aead := newUpdatableAEAD(utils.NewRTTStats(), nil, nil, tc.version) chacha := cipherSuites[2] require.Equal(t, tls.TLS_CHACHA20_POLY1305_SHA256, chacha.ID) aead.SetWriteKey(chacha, secret) const pnOffset = 1 header := splitHexString(t, "4200bff4") payloadOffset := len(header) plaintext := splitHexString(t, "01") payload := aead.Seal(nil, plaintext, 654360564, header) require.Equal(t, tc.expectedPayload, payload) packet := append(header, payload...) aead.EncryptHeader(packet[pnOffset+4:pnOffset+4+16], &packet[0], packet[pnOffset:payloadOffset]) require.Equal(t, tc.expectedPacket, packet) }) } } func TestUpdatableAEADHeaderProtection(t *testing.T) { for _, v := range []protocol.Version{protocol.Version1, protocol.Version2} { for _, cs := range cipherSuites { t.Run(fmt.Sprintf("QUIC %s/%s", v, tls.CipherSuiteName(cs.ID)), func(t *testing.T) { trafficSecret1 := make([]byte, 16) trafficSecret2 := make([]byte, 16) rand.Read(trafficSecret1) rand.Read(trafficSecret2) client := newUpdatableAEAD(utils.NewRTTStats(), nil, utils.DefaultLogger, v) server := newUpdatableAEAD(utils.NewRTTStats(), nil, utils.DefaultLogger, v) client.SetReadKey(cs, trafficSecret2) client.SetWriteKey(cs, trafficSecret1) server.SetReadKey(cs, trafficSecret1) server.SetWriteKey(cs, trafficSecret2) var lastFiveBitsDifferent int for i := 0; i < 100; i++ { sample := make([]byte, 16) rand.Read(sample) header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} client.EncryptHeader(sample, &header[0], header[9:13]) if header[0]&0x1f != 0xb5&0x1f { lastFiveBitsDifferent++ } require.Equal(t, byte(0xb5&0xe0), header[0]&0xe0) require.Equal(t, []byte{1, 2, 3, 4, 5, 6, 7, 8}, header[1:9]) require.NotEqual(t, []byte{0xde, 0xad, 0xbe, 0xef}, header[9:13]) server.DecryptHeader(sample, &header[0], header[9:13]) require.Equal(t, []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}, 
header) } require.Greater(t, lastFiveBitsDifferent, 75) }) } } } func TestUpdatableAEADEncryptDecryptMessage(t *testing.T) { for _, v := range []protocol.Version{protocol.Version1, protocol.Version2} { for _, cs := range cipherSuites { t.Run(fmt.Sprintf("QUIC %s/%s", v, tls.CipherSuiteName(cs.ID)), func(t *testing.T) { rttStats := utils.RTTStats{} trafficSecret1 := make([]byte, 16) trafficSecret2 := make([]byte, 16) rand.Read(trafficSecret1) rand.Read(trafficSecret2) client := newUpdatableAEAD(&rttStats, nil, utils.DefaultLogger, v) server := newUpdatableAEAD(&rttStats, nil, utils.DefaultLogger, v) client.SetReadKey(cs, trafficSecret2) client.SetWriteKey(cs, trafficSecret1) server.SetReadKey(cs, trafficSecret1) server.SetWriteKey(cs, trafficSecret2) msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") ad := []byte("Donec in velit neque.") encrypted := server.Seal(nil, msg, 0x1337, ad) opened, err := client.Open(nil, encrypted, monotime.Now(), 0x1337, protocol.KeyPhaseZero, ad) require.NoError(t, err) require.Equal(t, msg, opened) _, err = client.Open(nil, encrypted, monotime.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad")) require.Equal(t, ErrDecryptionFailed, err) _, err = client.Open(nil, encrypted, monotime.Now(), 0x42, protocol.KeyPhaseZero, ad) require.Equal(t, ErrDecryptionFailed, err) }) } } } func TestUpdatableAEADPacketNumbers(t *testing.T) { client, server, _ := setupEndpoints(t, utils.NewRTTStats()) msg := []byte("Lorem ipsum") ad := []byte("Donec in velit neque.") encrypted := server.Seal(nil, msg, 0x1337, ad) require.Equal(t, protocol.PacketNumber(0x1337), server.FirstPacketNumber()) // make sure we save the first packet number _ = server.Seal(nil, msg, 0x1338, ad) require.Equal(t, protocol.PacketNumber(0x1337), server.FirstPacketNumber()) // make sure we save the first packet number // check that decoding the packet number works as expected _, err := 
client.Open(nil, encrypted[:len(encrypted)-1], monotime.Now(), 0x1337, protocol.KeyPhaseZero, ad) require.Error(t, err) require.Equal(t, protocol.PacketNumber(0x38), client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)) _, err = client.Open(nil, encrypted, monotime.Now(), 0x1337, protocol.KeyPhaseZero, ad) require.NoError(t, err) require.Equal(t, protocol.PacketNumber(0x1338), client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)) } func TestAEADLimitReached(t *testing.T) { client, _, _ := setupEndpoints(t, utils.NewRTTStats()) client.invalidPacketLimit = 10 for i := 0; i < 9; i++ { _, err := client.Open(nil, []byte("foobar"), monotime.Now(), protocol.PacketNumber(i), protocol.KeyPhaseZero, []byte("ad")) require.Equal(t, ErrDecryptionFailed, err) } _, err := client.Open(nil, []byte("foobar"), monotime.Now(), 10, protocol.KeyPhaseZero, []byte("ad")) require.Error(t, err) var transportErr *qerr.TransportError require.ErrorAs(t, err, &transportErr) require.Equal(t, qerr.AEADLimitReached, transportErr.ErrorCode) } func TestKeyUpdates(t *testing.T) { client, server, _ := setupEndpoints(t, utils.NewRTTStats()) now := monotime.Now() require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase()) encrypted0 := server.Seal(nil, []byte(msg), 0x1337, []byte(ad)) server.rollKeys() require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase()) encrypted1 := server.Seal(nil, []byte(msg), 0x1337, []byte(ad)) require.NotEqual(t, encrypted0, encrypted1) _, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, []byte(ad)) require.Equal(t, ErrDecryptionFailed, err) client.rollKeys() decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, []byte(ad)) require.NoError(t, err) require.Equal(t, msg, string(decrypted)) } // func TestUpdatesKeysWhenReceivingPacketWithNextKeyPhase(t *testing.T) { // rttStats := utils.RTTStats{} // mockCtrl := gomock.NewController(t) // serverTracer := mocklogging.NewMockConnectionTracer(mockCtrl) // trafficSecret1 
:= make([]byte, 16) // trafficSecret2 := make([]byte, 16) // rand.Read(trafficSecret1) // rand.Read(trafficSecret2) // client := newUpdatableAEAD(&rttStats, nil, utils.DefaultLogger, protocol.Version1) // server := newUpdatableAEAD(&rttStats, serverTracer, utils.DefaultLogger, protocol.Version1) // client.SetReadKey(cs, trafficSecret2) // client.SetWriteKey(cs, trafficSecret1) // server.SetReadKey(cs, trafficSecret1) // server.SetWriteKey(cs, trafficSecret2) // now := monotime.Now() // encrypted0 := client.Seal(nil, []byte(msg), 0x42, ad) // decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad) // require.NoError(t, err) // require.Equal(t, msg, decrypted) // require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase()) // _ = server.Seal(nil, msg, 0x1, ad) // client.rollKeys() // encrypted1 := client.Seal(nil, msg, 0x43, ad) // serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) // decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad) // require.NoError(t, err) // require.Equal(t, msg, decrypted) // require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase()) // } func TestReorderedPacketAfterKeyUpdate(t *testing.T) { client, server, eventRecorder := setupEndpoints(t, utils.NewRTTStats()) now := monotime.Now() encrypted01 := client.Seal(nil, []byte(msg), 0x42, []byte(ad)) encrypted02 := client.Seal(nil, []byte(msg), 0x43, []byte(ad)) _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, []byte(ad)) require.NoError(t, err) _ = server.Seal(nil, []byte(msg), 0x1, []byte(ad)) client.rollKeys() encrypted1 := client.Seal(nil, []byte(msg), 0x44, []byte(ad)) require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase()) _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, []byte(ad)) require.NoError(t, err) require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase()) require.Equal(t, bothSides(qlog.KeyUpdated{Trigger: qlog.KeyUpdateRemote, KeyPhase: 1}), eventRecorder.Events(), ) 
// now receive a reordered packet decrypted, err := server.Open(nil, encrypted02, now, 0x43, protocol.KeyPhaseZero, []byte(ad)) require.NoError(t, err) require.Equal(t, msg, string(decrypted)) require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase()) } func TestDropsKeys3PTOsAfterKeyUpdate(t *testing.T) { rttStats := utils.NewRTTStats() client, server, eventRecorder := setupEndpoints(t, rttStats) now := monotime.Now() rttStats.UpdateRTT(10*time.Millisecond, 0) pto := rttStats.PTO(true) encrypted01 := client.Seal(nil, []byte(msg), 0x42, []byte(ad)) encrypted02 := client.Seal(nil, []byte(msg), 0x43, []byte(ad)) _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, []byte(ad)) require.NoError(t, err) _ = server.Seal(nil, []byte(msg), 0x1, []byte(ad)) client.rollKeys() encrypted1 := client.Seal(nil, []byte(msg), 0x44, []byte(ad)) require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase()) _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, []byte(ad)) require.NoError(t, err) require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase()) require.Equal(t, bothSides(qlog.KeyUpdated{KeyPhase: 1, Trigger: qlog.KeyUpdateRemote}), eventRecorder.Events(), ) eventRecorder.Clear() // packet arrived too late, the key was already dropped _, err = server.Open(nil, encrypted02, now.Add(3*pto).Add(time.Nanosecond), 0x43, protocol.KeyPhaseZero, []byte(ad)) require.Equal(t, ErrKeysDropped, err) require.Equal(t, bothSides(qlog.KeyDiscarded{KeyPhase: 0}), eventRecorder.Events(), ) } func TestAllowsFirstKeyUpdateImmediately(t *testing.T) { client, server, serverTracer := setupEndpoints(t, utils.NewRTTStats()) client.rollKeys() encrypted := client.Seal(nil, []byte(msg), 0x1337, []byte(ad)) // if decryption failed, we don't expect a key phase update _, err := server.Open(nil, encrypted[:len(encrypted)-1], monotime.Now(), 0x1337, protocol.KeyPhaseOne, []byte(ad)) require.Equal(t, ErrDecryptionFailed, err) // the key phase is updated on first successful 
decryption _, err = server.Open(nil, encrypted, monotime.Now(), 0x1337, protocol.KeyPhaseOne, []byte(ad)) require.NoError(t, err) require.Equal(t, bothSides(qlog.KeyUpdated{KeyPhase: 1, Trigger: qlog.KeyUpdateRemote}), serverTracer.Events(), ) } func TestRejectFrequentKeyUpdates(t *testing.T) { client, server, _ := setupEndpoints(t, utils.NewRTTStats()) server.rollKeys() client.rollKeys() encrypted0 := client.Seal(nil, []byte(msg), 0x42, []byte(ad)) _, err := server.Open(nil, encrypted0, monotime.Now(), 0x42, protocol.KeyPhaseOne, []byte(ad)) require.NoError(t, err) client.rollKeys() encrypted1 := client.Seal(nil, []byte(msg), 0x42, []byte(ad)) _, err = server.Open(nil, encrypted1, monotime.Now(), 0x42, protocol.KeyPhaseZero, []byte(ad)) require.Equal(t, &qerr.TransportError{ ErrorCode: qerr.KeyUpdateError, ErrorMessage: "keys updated too quickly", }, err) } func setKeyUpdateIntervals(t *testing.T, firstKeyUpdateInterval, keyUpdateInterval uint64) { reset := SetKeyUpdateInterval(keyUpdateInterval) t.Cleanup(reset) origFirstKeyUpdateInterval := FirstKeyUpdateInterval FirstKeyUpdateInterval = firstKeyUpdateInterval t.Cleanup(func() { FirstKeyUpdateInterval = origFirstKeyUpdateInterval }) } func TestInitiateKeyUpdateAfterSendingMaxPackets(t *testing.T) { const firstKeyUpdateInterval = 5 const keyUpdateInterval = 20 setKeyUpdateIntervals(t, firstKeyUpdateInterval, keyUpdateInterval) client, server, eventRecorder := setupEndpoints(t, utils.NewRTTStats()) server.SetHandshakeConfirmed() var pn protocol.PacketNumber // first key update for range firstKeyUpdateInterval { require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase()) server.Seal(nil, []byte(msg), pn, []byte(ad)) pn++ } // the first update is allowed without receiving an acknowledgement require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase()) require.Equal(t, bothSides(qlog.KeyUpdated{KeyPhase: 1, Trigger: qlog.KeyUpdateLocal}), eventRecorder.Events(), ) eventRecorder.Clear() // subsequent key update for range 2 
* keyUpdateInterval { require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase()) server.Seal(nil, []byte(msg), pn, []byte(ad)) pn++ } // no update allowed before receiving an acknowledgement for the current key phase require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase()) // receive an ACK for a packet sent in key phase 1 client.rollKeys() b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) _, err := server.Open(nil, b, monotime.Now(), 1, protocol.KeyPhaseOne, []byte("ad")) require.NoError(t, err) require.NoError(t, server.SetLargestAcked(firstKeyUpdateInterval)) require.Empty(t, eventRecorder.Events()) require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase()) require.Equal(t, append( bothSides(qlog.KeyDiscarded{KeyPhase: 0}), bothSides(qlog.KeyUpdated{KeyPhase: 2, Trigger: qlog.KeyUpdateLocal})..., ), eventRecorder.Events(), ) } func TestKeyUpdateEnforceACKKeyPhase(t *testing.T) { const firstKeyUpdateInterval = 5 setKeyUpdateIntervals(t, firstKeyUpdateInterval, protocol.KeyUpdateInterval) _, server, eventRecorder := setupEndpoints(t, utils.NewRTTStats()) server.SetHandshakeConfirmed() // First make sure that we update our keys. for i := range firstKeyUpdateInterval { pn := protocol.PacketNumber(i) require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase()) server.Seal(nil, []byte(msg), pn, []byte(ad)) } require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase()) require.Equal(t, bothSides(qlog.KeyUpdated{KeyPhase: 1, Trigger: qlog.KeyUpdateLocal}), eventRecorder.Events(), ) eventRecorder.Clear() // Now that our keys are updated, send a packet using the new keys. const nextPN = firstKeyUpdateInterval + 1 server.Seal(nil, []byte(msg), nextPN, []byte(ad)) for i := range firstKeyUpdateInterval { // We haven't decrypted any packet in the new key phase yet. // This means that the ACK must have been sent in the old key phase. require.NoError(t, server.SetLargestAcked(protocol.PacketNumber(i))) } // We haven't decrypted any packet in the new key phase yet. 
// This means that the ACK must have been sent in the old key phase. err := server.SetLargestAcked(nextPN) require.Error(t, err) var transportErr *qerr.TransportError require.ErrorAs(t, err, &transportErr) require.Equal(t, qerr.KeyUpdateError, transportErr.ErrorCode) require.Equal(t, "received ACK for key phase 1, but peer didn't update keys", transportErr.ErrorMessage) require.Empty(t, eventRecorder.Events()) } func TestKeyUpdateAfterOpeningMaxPackets(t *testing.T) { const firstKeyUpdateInterval = 5 const keyUpdateInterval = 20 setKeyUpdateIntervals(t, firstKeyUpdateInterval, keyUpdateInterval) client, server, eventRecorder := setupEndpoints(t, utils.NewRTTStats()) server.SetHandshakeConfirmed() msg := []byte("message") ad := []byte("additional data") // first key update var pn protocol.PacketNumber for range firstKeyUpdateInterval { require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase()) encrypted := client.Seal(nil, msg, pn, ad) _, err := server.Open(nil, encrypted, monotime.Now(), pn, protocol.KeyPhaseZero, ad) require.NoError(t, err) pn++ } // the first update is allowed without receiving an acknowledgement require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase()) require.Equal(t, bothSides(qlog.KeyUpdated{KeyPhase: 1, Trigger: qlog.KeyUpdateLocal}), eventRecorder.Events(), ) eventRecorder.Clear() // subsequent key update client.rollKeys() for range keyUpdateInterval { require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase()) encrypted := client.Seal(nil, msg, pn, ad) _, err := server.Open(nil, encrypted, monotime.Now(), pn, protocol.KeyPhaseOne, ad) require.NoError(t, err) pn++ } // No update allowed before receiving an acknowledgement for the current key phase require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase()) server.Seal(nil, msg, 1, ad) require.NoError(t, server.SetLargestAcked(firstKeyUpdateInterval+1)) require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase()) require.Equal(t, append( bothSides(qlog.KeyDiscarded{KeyPhase: 0}), 
bothSides(qlog.KeyUpdated{KeyPhase: 2, Trigger: qlog.KeyUpdateLocal})..., ), eventRecorder.Events(), ) } func TestKeyUpdateKeyPhaseSkipping(t *testing.T) { const firstKeyUpdateInterval = 5 const keyUpdateInterval = 20 setKeyUpdateIntervals(t, firstKeyUpdateInterval, keyUpdateInterval) rttStats := utils.NewRTTStats() rttStats.UpdateRTT(10*time.Millisecond, 0) client, server, eventRecorder := setupEndpoints(t, rttStats) server.SetHandshakeConfirmed() now := monotime.Now() data1 := client.Seal(nil, []byte(msg), 1, []byte(ad)) _, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, []byte(ad)) require.NoError(t, err) for i := range firstKeyUpdateInterval { pn := protocol.PacketNumber(i) require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase()) server.Seal(nil, []byte(msg), pn, []byte(ad)) require.NoError(t, server.SetLargestAcked(pn)) } require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase()) require.Equal(t, bothSides(qlog.KeyUpdated{KeyPhase: 1, Trigger: qlog.KeyUpdateLocal}), eventRecorder.Events(), ) eventRecorder.Clear() // The server never received a packet at key phase 1. // Make sure the key phase 0 is still there at a much later point. 
data2 := client.Seal(nil, []byte(msg), 1, []byte(ad)) _, err = server.Open(nil, data2, now.Add(10*rttStats.PTO(true)), 1, protocol.KeyPhaseZero, []byte(ad)) require.NoError(t, err) require.Empty(t, eventRecorder.Events()) } func TestFastKeyUpdatesByPeer(t *testing.T) { const firstKeyUpdateInterval = 5 const keyUpdateInterval = 20 setKeyUpdateIntervals(t, firstKeyUpdateInterval, keyUpdateInterval) client, server, eventRecorder := setupEndpoints(t, utils.NewRTTStats()) server.SetHandshakeConfirmed() var pn protocol.PacketNumber for range firstKeyUpdateInterval { require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase()) server.Seal(nil, []byte(msg), pn, []byte(ad)) pn++ } b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) _, err := server.Open(nil, b, monotime.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) require.NoError(t, err) require.NoError(t, server.SetLargestAcked(0)) require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase()) require.Equal(t, bothSides(qlog.KeyUpdated{KeyPhase: 1, Trigger: qlog.KeyUpdateLocal}), eventRecorder.Events(), ) eventRecorder.Clear() // Send and receive an acknowledgement for a packet in key phase 1. // We are now running a timer to drop the keys with 3 PTO. server.Seal(nil, []byte(msg), pn, []byte(ad)) client.rollKeys() dataKeyPhaseOne := client.Seal(nil, []byte(msg), 2, []byte(ad)) now := monotime.Now() _, err = server.Open(nil, dataKeyPhaseOne, now, 2, protocol.KeyPhaseOne, []byte(ad)) require.NoError(t, err) require.NoError(t, server.SetLargestAcked(pn)) // Now the client sends us a packet in key phase 2, forcing us to update keys before the 3 PTO period is over. // This mean that we need to drop the keys for key phase 0 immediately. 
client.rollKeys() dataKeyPhaseTwo := client.Seal(nil, []byte(msg), 3, []byte(ad)) _, err = server.Open(nil, dataKeyPhaseTwo, now, 3, protocol.KeyPhaseZero, []byte(ad)) require.NoError(t, err) require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase()) require.Equal(t, append( bothSides(qlog.KeyDiscarded{KeyPhase: 0}), bothSides(qlog.KeyUpdated{KeyPhase: 2, Trigger: qlog.KeyUpdateRemote})..., ), eventRecorder.Events(), ) } func TestFastKeyUpdateByUs(t *testing.T) { const firstKeyUpdateInterval = 5 const keyUpdateInterval = 20 setKeyUpdateIntervals(t, firstKeyUpdateInterval, keyUpdateInterval) rttStats := utils.NewRTTStats() rttStats.UpdateRTT(10*time.Millisecond, 0) client, server, eventRecorder := setupEndpoints(t, rttStats) server.SetHandshakeConfirmed() // send so many packets that we initiate the first key update for i := range firstKeyUpdateInterval { pn := protocol.PacketNumber(i) require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase()) server.Seal(nil, []byte(msg), pn, []byte(ad)) } b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) _, err := server.Open(nil, b, monotime.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) require.NoError(t, err) require.NoError(t, server.SetLargestAcked(0)) require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase()) require.Equal(t, bothSides(qlog.KeyUpdated{KeyPhase: 1, Trigger: qlog.KeyUpdateLocal}), eventRecorder.Events(), ) eventRecorder.Clear() // send so many packets that we initiate the next key update for i := keyUpdateInterval; i < 2*keyUpdateInterval; i++ { pn := protocol.PacketNumber(i) require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase()) server.Seal(nil, []byte(msg), pn, []byte(ad)) } client.rollKeys() b = client.Seal(nil, []byte("foobar"), 2, []byte("ad")) now := monotime.Now() _, err = server.Open(nil, b, now, 2, protocol.KeyPhaseOne, []byte("ad")) require.NoError(t, err) require.NoError(t, server.SetLargestAcked(keyUpdateInterval)) require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase()) require.Equal(t, 
append( bothSides(qlog.KeyDiscarded{KeyPhase: 0}), bothSides(qlog.KeyUpdated{KeyPhase: 2, Trigger: qlog.KeyUpdateLocal})..., ), eventRecorder.Events(), ) eventRecorder.Clear() // We haven't received an ACK for a packet sent in key phase 2 yet. // Make sure we canceled the timer to drop the previous key phase. b = client.Seal(nil, []byte("foobar"), 3, []byte("ad")) _, err = server.Open(nil, b, now.Add(10*rttStats.PTO(true)), 3, protocol.KeyPhaseOne, []byte("ad")) require.NoError(t, err) require.Empty(t, eventRecorder.Events()) } func getClientAndServer() (client, server *updatableAEAD) { trafficSecret1 := make([]byte, 16) trafficSecret2 := make([]byte, 16) rand.Read(trafficSecret1) rand.Read(trafficSecret2) cs := cipherSuites[0] rttStats := utils.NewRTTStats() client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, protocol.Version1) server = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, protocol.Version1) client.SetReadKey(cs, trafficSecret2) client.SetWriteKey(cs, trafficSecret1) server.SetReadKey(cs, trafficSecret1) server.SetWriteKey(cs, trafficSecret2) return } func BenchmarkPacketEncryption(b *testing.B) { client, _ := getClientAndServer() const l = 1200 src := make([]byte, l) rand.Read(src) ad := make([]byte, 32) rand.Read(ad) var pn protocol.PacketNumber for b.Loop() { src = client.Seal(src[:0], src[:l], pn, ad) pn++ } } func BenchmarkPacketDecryption(b *testing.B) { client, server := getClientAndServer() const l = 1200 src := make([]byte, l) dst := make([]byte, l) rand.Read(src) ad := make([]byte, 32) rand.Read(ad) src = client.Seal(src[:0], src[:l], 1337, ad) for b.Loop() { if _, err := server.Open(dst[:0], src, 0, 1337, protocol.KeyPhaseZero, ad); err != nil { b.Fatalf("opening failed: %v", err) } } } func BenchmarkRollKeys(b *testing.B) { client, _ := getClientAndServer() for b.Loop() { client.rollKeys() } if int(client.keyPhase) != b.N { b.Fatal("didn't roll keys often enough") } } 
quic-go-0.59.0/internal/mocks/000077500000000000000000000000001513066070600161055ustar00rootroot00000000000000quic-go-0.59.0/internal/mocks/ackhandler/000077500000000000000000000000001513066070600202015ustar00rootroot00000000000000quic-go-0.59.0/internal/mocks/ackhandler/sent_packet_handler.go000066400000000000000000000606331513066070600245350ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/internal/ackhandler (interfaces: SentPacketHandler) // // Generated by this command: // // mockgen -typed -build_flags=-tags=gomock -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/quic-go/quic-go/internal/ackhandler SentPacketHandler // // Package mockackhandler is a generated GoMock package. package mockackhandler import ( reflect "reflect" ackhandler "github.com/quic-go/quic-go/internal/ackhandler" monotime "github.com/quic-go/quic-go/internal/monotime" protocol "github.com/quic-go/quic-go/internal/protocol" wire "github.com/quic-go/quic-go/internal/wire" gomock "go.uber.org/mock/gomock" ) // MockSentPacketHandler is a mock of SentPacketHandler interface. type MockSentPacketHandler struct { ctrl *gomock.Controller recorder *MockSentPacketHandlerMockRecorder isgomock struct{} } // MockSentPacketHandlerMockRecorder is the mock recorder for MockSentPacketHandler. type MockSentPacketHandlerMockRecorder struct { mock *MockSentPacketHandler } // NewMockSentPacketHandler creates a new mock instance. func NewMockSentPacketHandler(ctrl *gomock.Controller) *MockSentPacketHandler { mock := &MockSentPacketHandler{ctrl: ctrl} mock.recorder = &MockSentPacketHandlerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockSentPacketHandler) EXPECT() *MockSentPacketHandlerMockRecorder { return m.recorder } // DropPackets mocks base method. 
func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel, rcvTime monotime.Time) { m.ctrl.T.Helper() m.ctrl.Call(m, "DropPackets", arg0, rcvTime) } // DropPackets indicates an expected call of DropPackets. func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0, rcvTime any) *MockSentPacketHandlerDropPacketsCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0, rcvTime) return &MockSentPacketHandlerDropPacketsCall{Call: call} } // MockSentPacketHandlerDropPacketsCall wrap *gomock.Call type MockSentPacketHandlerDropPacketsCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSentPacketHandlerDropPacketsCall) Return() *MockSentPacketHandlerDropPacketsCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockSentPacketHandlerDropPacketsCall) Do(f func(protocol.EncryptionLevel, monotime.Time)) *MockSentPacketHandlerDropPacketsCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSentPacketHandlerDropPacketsCall) DoAndReturn(f func(protocol.EncryptionLevel, monotime.Time)) *MockSentPacketHandlerDropPacketsCall { c.Call = c.Call.DoAndReturn(f) return c } // ECNMode mocks base method. func (m *MockSentPacketHandler) ECNMode(isShortHeaderPacket bool) protocol.ECN { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ECNMode", isShortHeaderPacket) ret0, _ := ret[0].(protocol.ECN) return ret0 } // ECNMode indicates an expected call of ECNMode. 
func (mr *MockSentPacketHandlerMockRecorder) ECNMode(isShortHeaderPacket any) *MockSentPacketHandlerECNModeCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNMode", reflect.TypeOf((*MockSentPacketHandler)(nil).ECNMode), isShortHeaderPacket) return &MockSentPacketHandlerECNModeCall{Call: call} } // MockSentPacketHandlerECNModeCall wrap *gomock.Call type MockSentPacketHandlerECNModeCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSentPacketHandlerECNModeCall) Return(arg0 protocol.ECN) *MockSentPacketHandlerECNModeCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockSentPacketHandlerECNModeCall) Do(f func(bool) protocol.ECN) *MockSentPacketHandlerECNModeCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSentPacketHandlerECNModeCall) DoAndReturn(f func(bool) protocol.ECN) *MockSentPacketHandlerECNModeCall { c.Call = c.Call.DoAndReturn(f) return c } // GetLossDetectionTimeout mocks base method. func (m *MockSentPacketHandler) GetLossDetectionTimeout() monotime.Time { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetLossDetectionTimeout") ret0, _ := ret[0].(monotime.Time) return ret0 } // GetLossDetectionTimeout indicates an expected call of GetLossDetectionTimeout. 
func (mr *MockSentPacketHandlerMockRecorder) GetLossDetectionTimeout() *MockSentPacketHandlerGetLossDetectionTimeoutCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLossDetectionTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLossDetectionTimeout)) return &MockSentPacketHandlerGetLossDetectionTimeoutCall{Call: call} } // MockSentPacketHandlerGetLossDetectionTimeoutCall wrap *gomock.Call type MockSentPacketHandlerGetLossDetectionTimeoutCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSentPacketHandlerGetLossDetectionTimeoutCall) Return(arg0 monotime.Time) *MockSentPacketHandlerGetLossDetectionTimeoutCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockSentPacketHandlerGetLossDetectionTimeoutCall) Do(f func() monotime.Time) *MockSentPacketHandlerGetLossDetectionTimeoutCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSentPacketHandlerGetLossDetectionTimeoutCall) DoAndReturn(f func() monotime.Time) *MockSentPacketHandlerGetLossDetectionTimeoutCall { c.Call = c.Call.DoAndReturn(f) return c } // MigratedPath mocks base method. func (m *MockSentPacketHandler) MigratedPath(now monotime.Time, initialMaxPacketSize protocol.ByteCount) { m.ctrl.T.Helper() m.ctrl.Call(m, "MigratedPath", now, initialMaxPacketSize) } // MigratedPath indicates an expected call of MigratedPath. 
func (mr *MockSentPacketHandlerMockRecorder) MigratedPath(now, initialMaxPacketSize any) *MockSentPacketHandlerMigratedPathCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MigratedPath", reflect.TypeOf((*MockSentPacketHandler)(nil).MigratedPath), now, initialMaxPacketSize) return &MockSentPacketHandlerMigratedPathCall{Call: call} } // MockSentPacketHandlerMigratedPathCall wrap *gomock.Call type MockSentPacketHandlerMigratedPathCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSentPacketHandlerMigratedPathCall) Return() *MockSentPacketHandlerMigratedPathCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockSentPacketHandlerMigratedPathCall) Do(f func(monotime.Time, protocol.ByteCount)) *MockSentPacketHandlerMigratedPathCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSentPacketHandlerMigratedPathCall) DoAndReturn(f func(monotime.Time, protocol.ByteCount)) *MockSentPacketHandlerMigratedPathCall { c.Call = c.Call.DoAndReturn(f) return c } // OnLossDetectionTimeout mocks base method. func (m *MockSentPacketHandler) OnLossDetectionTimeout(now monotime.Time) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OnLossDetectionTimeout", now) ret0, _ := ret[0].(error) return ret0 } // OnLossDetectionTimeout indicates an expected call of OnLossDetectionTimeout. 
func (mr *MockSentPacketHandlerMockRecorder) OnLossDetectionTimeout(now any) *MockSentPacketHandlerOnLossDetectionTimeoutCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnLossDetectionTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).OnLossDetectionTimeout), now) return &MockSentPacketHandlerOnLossDetectionTimeoutCall{Call: call} } // MockSentPacketHandlerOnLossDetectionTimeoutCall wrap *gomock.Call type MockSentPacketHandlerOnLossDetectionTimeoutCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSentPacketHandlerOnLossDetectionTimeoutCall) Return(arg0 error) *MockSentPacketHandlerOnLossDetectionTimeoutCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockSentPacketHandlerOnLossDetectionTimeoutCall) Do(f func(monotime.Time) error) *MockSentPacketHandlerOnLossDetectionTimeoutCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSentPacketHandlerOnLossDetectionTimeoutCall) DoAndReturn(f func(monotime.Time) error) *MockSentPacketHandlerOnLossDetectionTimeoutCall { c.Call = c.Call.DoAndReturn(f) return c } // PeekPacketNumber mocks base method. func (m *MockSentPacketHandler) PeekPacketNumber(arg0 protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PeekPacketNumber", arg0) ret0, _ := ret[0].(protocol.PacketNumber) ret1, _ := ret[1].(protocol.PacketNumberLen) return ret0, ret1 } // PeekPacketNumber indicates an expected call of PeekPacketNumber. 
func (mr *MockSentPacketHandlerMockRecorder) PeekPacketNumber(arg0 any) *MockSentPacketHandlerPeekPacketNumberCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeekPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PeekPacketNumber), arg0) return &MockSentPacketHandlerPeekPacketNumberCall{Call: call} } // MockSentPacketHandlerPeekPacketNumberCall wrap *gomock.Call type MockSentPacketHandlerPeekPacketNumberCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSentPacketHandlerPeekPacketNumberCall) Return(arg0 protocol.PacketNumber, arg1 protocol.PacketNumberLen) *MockSentPacketHandlerPeekPacketNumberCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockSentPacketHandlerPeekPacketNumberCall) Do(f func(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen)) *MockSentPacketHandlerPeekPacketNumberCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSentPacketHandlerPeekPacketNumberCall) DoAndReturn(f func(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen)) *MockSentPacketHandlerPeekPacketNumberCall { c.Call = c.Call.DoAndReturn(f) return c } // PopPacketNumber mocks base method. func (m *MockSentPacketHandler) PopPacketNumber(arg0 protocol.EncryptionLevel) protocol.PacketNumber { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PopPacketNumber", arg0) ret0, _ := ret[0].(protocol.PacketNumber) return ret0 } // PopPacketNumber indicates an expected call of PopPacketNumber. 
func (mr *MockSentPacketHandlerMockRecorder) PopPacketNumber(arg0 any) *MockSentPacketHandlerPopPacketNumberCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PopPacketNumber), arg0) return &MockSentPacketHandlerPopPacketNumberCall{Call: call} } // MockSentPacketHandlerPopPacketNumberCall wrap *gomock.Call type MockSentPacketHandlerPopPacketNumberCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSentPacketHandlerPopPacketNumberCall) Return(arg0 protocol.PacketNumber) *MockSentPacketHandlerPopPacketNumberCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockSentPacketHandlerPopPacketNumberCall) Do(f func(protocol.EncryptionLevel) protocol.PacketNumber) *MockSentPacketHandlerPopPacketNumberCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSentPacketHandlerPopPacketNumberCall) DoAndReturn(f func(protocol.EncryptionLevel) protocol.PacketNumber) *MockSentPacketHandlerPopPacketNumberCall { c.Call = c.Call.DoAndReturn(f) return c } // QueueProbePacket mocks base method. func (m *MockSentPacketHandler) QueueProbePacket(arg0 protocol.EncryptionLevel) bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "QueueProbePacket", arg0) ret0, _ := ret[0].(bool) return ret0 } // QueueProbePacket indicates an expected call of QueueProbePacket. 
func (mr *MockSentPacketHandlerMockRecorder) QueueProbePacket(arg0 any) *MockSentPacketHandlerQueueProbePacketCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).QueueProbePacket), arg0) return &MockSentPacketHandlerQueueProbePacketCall{Call: call} } // MockSentPacketHandlerQueueProbePacketCall wrap *gomock.Call type MockSentPacketHandlerQueueProbePacketCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSentPacketHandlerQueueProbePacketCall) Return(arg0 bool) *MockSentPacketHandlerQueueProbePacketCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockSentPacketHandlerQueueProbePacketCall) Do(f func(protocol.EncryptionLevel) bool) *MockSentPacketHandlerQueueProbePacketCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSentPacketHandlerQueueProbePacketCall) DoAndReturn(f func(protocol.EncryptionLevel) bool) *MockSentPacketHandlerQueueProbePacketCall { c.Call = c.Call.DoAndReturn(f) return c } // ReceivedAck mocks base method. func (m *MockSentPacketHandler) ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime monotime.Time) (bool, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ReceivedAck", f, encLevel, rcvTime) ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 } // ReceivedAck indicates an expected call of ReceivedAck. 
func (mr *MockSentPacketHandlerMockRecorder) ReceivedAck(f, encLevel, rcvTime any) *MockSentPacketHandlerReceivedAckCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedAck", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedAck), f, encLevel, rcvTime) return &MockSentPacketHandlerReceivedAckCall{Call: call} } // MockSentPacketHandlerReceivedAckCall wrap *gomock.Call type MockSentPacketHandlerReceivedAckCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSentPacketHandlerReceivedAckCall) Return(arg0 bool, arg1 error) *MockSentPacketHandlerReceivedAckCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockSentPacketHandlerReceivedAckCall) Do(f func(*wire.AckFrame, protocol.EncryptionLevel, monotime.Time) (bool, error)) *MockSentPacketHandlerReceivedAckCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSentPacketHandlerReceivedAckCall) DoAndReturn(f func(*wire.AckFrame, protocol.EncryptionLevel, monotime.Time) (bool, error)) *MockSentPacketHandlerReceivedAckCall { c.Call = c.Call.DoAndReturn(f) return c } // ReceivedBytes mocks base method. func (m *MockSentPacketHandler) ReceivedBytes(arg0 protocol.ByteCount, rcvTime monotime.Time) { m.ctrl.T.Helper() m.ctrl.Call(m, "ReceivedBytes", arg0, rcvTime) } // ReceivedBytes indicates an expected call of ReceivedBytes. 
func (mr *MockSentPacketHandlerMockRecorder) ReceivedBytes(arg0, rcvTime any) *MockSentPacketHandlerReceivedBytesCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedBytes", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedBytes), arg0, rcvTime) return &MockSentPacketHandlerReceivedBytesCall{Call: call} } // MockSentPacketHandlerReceivedBytesCall wrap *gomock.Call type MockSentPacketHandlerReceivedBytesCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSentPacketHandlerReceivedBytesCall) Return() *MockSentPacketHandlerReceivedBytesCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockSentPacketHandlerReceivedBytesCall) Do(f func(protocol.ByteCount, monotime.Time)) *MockSentPacketHandlerReceivedBytesCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSentPacketHandlerReceivedBytesCall) DoAndReturn(f func(protocol.ByteCount, monotime.Time)) *MockSentPacketHandlerReceivedBytesCall { c.Call = c.Call.DoAndReturn(f) return c } // ReceivedPacket mocks base method. func (m *MockSentPacketHandler) ReceivedPacket(arg0 protocol.EncryptionLevel, arg1 monotime.Time) { m.ctrl.T.Helper() m.ctrl.Call(m, "ReceivedPacket", arg0, arg1) } // ReceivedPacket indicates an expected call of ReceivedPacket. 
func (mr *MockSentPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1 any) *MockSentPacketHandlerReceivedPacketCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedPacket), arg0, arg1) return &MockSentPacketHandlerReceivedPacketCall{Call: call} } // MockSentPacketHandlerReceivedPacketCall wrap *gomock.Call type MockSentPacketHandlerReceivedPacketCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSentPacketHandlerReceivedPacketCall) Return() *MockSentPacketHandlerReceivedPacketCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockSentPacketHandlerReceivedPacketCall) Do(f func(protocol.EncryptionLevel, monotime.Time)) *MockSentPacketHandlerReceivedPacketCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSentPacketHandlerReceivedPacketCall) DoAndReturn(f func(protocol.EncryptionLevel, monotime.Time)) *MockSentPacketHandlerReceivedPacketCall { c.Call = c.Call.DoAndReturn(f) return c } // ResetForRetry mocks base method. func (m *MockSentPacketHandler) ResetForRetry(rcvTime monotime.Time) { m.ctrl.T.Helper() m.ctrl.Call(m, "ResetForRetry", rcvTime) } // ResetForRetry indicates an expected call of ResetForRetry. 
func (mr *MockSentPacketHandlerMockRecorder) ResetForRetry(rcvTime any) *MockSentPacketHandlerResetForRetryCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetForRetry", reflect.TypeOf((*MockSentPacketHandler)(nil).ResetForRetry), rcvTime) return &MockSentPacketHandlerResetForRetryCall{Call: call} } // MockSentPacketHandlerResetForRetryCall wrap *gomock.Call type MockSentPacketHandlerResetForRetryCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSentPacketHandlerResetForRetryCall) Return() *MockSentPacketHandlerResetForRetryCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockSentPacketHandlerResetForRetryCall) Do(f func(monotime.Time)) *MockSentPacketHandlerResetForRetryCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSentPacketHandlerResetForRetryCall) DoAndReturn(f func(monotime.Time)) *MockSentPacketHandlerResetForRetryCall { c.Call = c.Call.DoAndReturn(f) return c } // SendMode mocks base method. func (m *MockSentPacketHandler) SendMode(now monotime.Time) ackhandler.SendMode { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SendMode", now) ret0, _ := ret[0].(ackhandler.SendMode) return ret0 } // SendMode indicates an expected call of SendMode. 
func (mr *MockSentPacketHandlerMockRecorder) SendMode(now any) *MockSentPacketHandlerSendModeCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMode", reflect.TypeOf((*MockSentPacketHandler)(nil).SendMode), now) return &MockSentPacketHandlerSendModeCall{Call: call} } // MockSentPacketHandlerSendModeCall wrap *gomock.Call type MockSentPacketHandlerSendModeCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSentPacketHandlerSendModeCall) Return(arg0 ackhandler.SendMode) *MockSentPacketHandlerSendModeCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockSentPacketHandlerSendModeCall) Do(f func(monotime.Time) ackhandler.SendMode) *MockSentPacketHandlerSendModeCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSentPacketHandlerSendModeCall) DoAndReturn(f func(monotime.Time) ackhandler.SendMode) *MockSentPacketHandlerSendModeCall { c.Call = c.Call.DoAndReturn(f) return c } // SentPacket mocks base method. func (m *MockSentPacketHandler) SentPacket(t monotime.Time, pn, largestAcked protocol.PacketNumber, streamFrames []ackhandler.StreamFrame, frames []ackhandler.Frame, encLevel protocol.EncryptionLevel, ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket, isPathProbePacket bool) { m.ctrl.T.Helper() m.ctrl.Call(m, "SentPacket", t, pn, largestAcked, streamFrames, frames, encLevel, ecn, size, isPathMTUProbePacket, isPathProbePacket) } // SentPacket indicates an expected call of SentPacket. 
func (mr *MockSentPacketHandlerMockRecorder) SentPacket(t, pn, largestAcked, streamFrames, frames, encLevel, ecn, size, isPathMTUProbePacket, isPathProbePacket any) *MockSentPacketHandlerSentPacketCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), t, pn, largestAcked, streamFrames, frames, encLevel, ecn, size, isPathMTUProbePacket, isPathProbePacket) return &MockSentPacketHandlerSentPacketCall{Call: call} } // MockSentPacketHandlerSentPacketCall wrap *gomock.Call type MockSentPacketHandlerSentPacketCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSentPacketHandlerSentPacketCall) Return() *MockSentPacketHandlerSentPacketCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockSentPacketHandlerSentPacketCall) Do(f func(monotime.Time, protocol.PacketNumber, protocol.PacketNumber, []ackhandler.StreamFrame, []ackhandler.Frame, protocol.EncryptionLevel, protocol.ECN, protocol.ByteCount, bool, bool)) *MockSentPacketHandlerSentPacketCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSentPacketHandlerSentPacketCall) DoAndReturn(f func(monotime.Time, protocol.PacketNumber, protocol.PacketNumber, []ackhandler.StreamFrame, []ackhandler.Frame, protocol.EncryptionLevel, protocol.ECN, protocol.ByteCount, bool, bool)) *MockSentPacketHandlerSentPacketCall { c.Call = c.Call.DoAndReturn(f) return c } // SetMaxDatagramSize mocks base method. func (m *MockSentPacketHandler) SetMaxDatagramSize(count protocol.ByteCount) { m.ctrl.T.Helper() m.ctrl.Call(m, "SetMaxDatagramSize", count) } // SetMaxDatagramSize indicates an expected call of SetMaxDatagramSize. 
func (mr *MockSentPacketHandlerMockRecorder) SetMaxDatagramSize(count any) *MockSentPacketHandlerSetMaxDatagramSizeCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxDatagramSize", reflect.TypeOf((*MockSentPacketHandler)(nil).SetMaxDatagramSize), count) return &MockSentPacketHandlerSetMaxDatagramSizeCall{Call: call} } // MockSentPacketHandlerSetMaxDatagramSizeCall wrap *gomock.Call type MockSentPacketHandlerSetMaxDatagramSizeCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSentPacketHandlerSetMaxDatagramSizeCall) Return() *MockSentPacketHandlerSetMaxDatagramSizeCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockSentPacketHandlerSetMaxDatagramSizeCall) Do(f func(protocol.ByteCount)) *MockSentPacketHandlerSetMaxDatagramSizeCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSentPacketHandlerSetMaxDatagramSizeCall) DoAndReturn(f func(protocol.ByteCount)) *MockSentPacketHandlerSetMaxDatagramSizeCall { c.Call = c.Call.DoAndReturn(f) return c } // TimeUntilSend mocks base method. func (m *MockSentPacketHandler) TimeUntilSend() monotime.Time { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "TimeUntilSend") ret0, _ := ret[0].(monotime.Time) return ret0 } // TimeUntilSend indicates an expected call of TimeUntilSend. 
func (mr *MockSentPacketHandlerMockRecorder) TimeUntilSend() *MockSentPacketHandlerTimeUntilSendCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeUntilSend", reflect.TypeOf((*MockSentPacketHandler)(nil).TimeUntilSend)) return &MockSentPacketHandlerTimeUntilSendCall{Call: call} } // MockSentPacketHandlerTimeUntilSendCall wrap *gomock.Call type MockSentPacketHandlerTimeUntilSendCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSentPacketHandlerTimeUntilSendCall) Return(arg0 monotime.Time) *MockSentPacketHandlerTimeUntilSendCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockSentPacketHandlerTimeUntilSendCall) Do(f func() monotime.Time) *MockSentPacketHandlerTimeUntilSendCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSentPacketHandlerTimeUntilSendCall) DoAndReturn(f func() monotime.Time) *MockSentPacketHandlerTimeUntilSendCall { c.Call = c.Call.DoAndReturn(f) return c } quic-go-0.59.0/internal/mocks/congestion.go000066400000000000000000000473561513066070600206230ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/internal/congestion (interfaces: SendAlgorithmWithDebugInfos) // // Generated by this command: // // mockgen -typed -build_flags=-tags=gomock -package mocks -destination congestion.go github.com/quic-go/quic-go/internal/congestion SendAlgorithmWithDebugInfos // // Package mocks is a generated GoMock package. package mocks import ( reflect "reflect" monotime "github.com/quic-go/quic-go/internal/monotime" protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" ) // MockSendAlgorithmWithDebugInfos is a mock of SendAlgorithmWithDebugInfos interface. 
type MockSendAlgorithmWithDebugInfos struct { ctrl *gomock.Controller recorder *MockSendAlgorithmWithDebugInfosMockRecorder isgomock struct{} } // MockSendAlgorithmWithDebugInfosMockRecorder is the mock recorder for MockSendAlgorithmWithDebugInfos. type MockSendAlgorithmWithDebugInfosMockRecorder struct { mock *MockSendAlgorithmWithDebugInfos } // NewMockSendAlgorithmWithDebugInfos creates a new mock instance. func NewMockSendAlgorithmWithDebugInfos(ctrl *gomock.Controller) *MockSendAlgorithmWithDebugInfos { mock := &MockSendAlgorithmWithDebugInfos{ctrl: ctrl} mock.recorder = &MockSendAlgorithmWithDebugInfosMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockSendAlgorithmWithDebugInfos) EXPECT() *MockSendAlgorithmWithDebugInfosMockRecorder { return m.recorder } // CanSend mocks base method. func (m *MockSendAlgorithmWithDebugInfos) CanSend(bytesInFlight protocol.ByteCount) bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CanSend", bytesInFlight) ret0, _ := ret[0].(bool) return ret0 } // CanSend indicates an expected call of CanSend. 
func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) CanSend(bytesInFlight any) *MockSendAlgorithmWithDebugInfosCanSendCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanSend", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).CanSend), bytesInFlight) return &MockSendAlgorithmWithDebugInfosCanSendCall{Call: call} } // MockSendAlgorithmWithDebugInfosCanSendCall wrap *gomock.Call type MockSendAlgorithmWithDebugInfosCanSendCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSendAlgorithmWithDebugInfosCanSendCall) Return(arg0 bool) *MockSendAlgorithmWithDebugInfosCanSendCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockSendAlgorithmWithDebugInfosCanSendCall) Do(f func(protocol.ByteCount) bool) *MockSendAlgorithmWithDebugInfosCanSendCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSendAlgorithmWithDebugInfosCanSendCall) DoAndReturn(f func(protocol.ByteCount) bool) *MockSendAlgorithmWithDebugInfosCanSendCall { c.Call = c.Call.DoAndReturn(f) return c } // GetCongestionWindow mocks base method. func (m *MockSendAlgorithmWithDebugInfos) GetCongestionWindow() protocol.ByteCount { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetCongestionWindow") ret0, _ := ret[0].(protocol.ByteCount) return ret0 } // GetCongestionWindow indicates an expected call of GetCongestionWindow. 
func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) GetCongestionWindow() *MockSendAlgorithmWithDebugInfosGetCongestionWindowCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCongestionWindow", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).GetCongestionWindow)) return &MockSendAlgorithmWithDebugInfosGetCongestionWindowCall{Call: call} } // MockSendAlgorithmWithDebugInfosGetCongestionWindowCall wrap *gomock.Call type MockSendAlgorithmWithDebugInfosGetCongestionWindowCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSendAlgorithmWithDebugInfosGetCongestionWindowCall) Return(arg0 protocol.ByteCount) *MockSendAlgorithmWithDebugInfosGetCongestionWindowCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockSendAlgorithmWithDebugInfosGetCongestionWindowCall) Do(f func() protocol.ByteCount) *MockSendAlgorithmWithDebugInfosGetCongestionWindowCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSendAlgorithmWithDebugInfosGetCongestionWindowCall) DoAndReturn(f func() protocol.ByteCount) *MockSendAlgorithmWithDebugInfosGetCongestionWindowCall { c.Call = c.Call.DoAndReturn(f) return c } // HasPacingBudget mocks base method. func (m *MockSendAlgorithmWithDebugInfos) HasPacingBudget(now monotime.Time) bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "HasPacingBudget", now) ret0, _ := ret[0].(bool) return ret0 } // HasPacingBudget indicates an expected call of HasPacingBudget. 
func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) HasPacingBudget(now any) *MockSendAlgorithmWithDebugInfosHasPacingBudgetCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasPacingBudget", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).HasPacingBudget), now) return &MockSendAlgorithmWithDebugInfosHasPacingBudgetCall{Call: call} } // MockSendAlgorithmWithDebugInfosHasPacingBudgetCall wrap *gomock.Call type MockSendAlgorithmWithDebugInfosHasPacingBudgetCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSendAlgorithmWithDebugInfosHasPacingBudgetCall) Return(arg0 bool) *MockSendAlgorithmWithDebugInfosHasPacingBudgetCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockSendAlgorithmWithDebugInfosHasPacingBudgetCall) Do(f func(monotime.Time) bool) *MockSendAlgorithmWithDebugInfosHasPacingBudgetCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSendAlgorithmWithDebugInfosHasPacingBudgetCall) DoAndReturn(f func(monotime.Time) bool) *MockSendAlgorithmWithDebugInfosHasPacingBudgetCall { c.Call = c.Call.DoAndReturn(f) return c } // InRecovery mocks base method. func (m *MockSendAlgorithmWithDebugInfos) InRecovery() bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "InRecovery") ret0, _ := ret[0].(bool) return ret0 } // InRecovery indicates an expected call of InRecovery. 
func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) InRecovery() *MockSendAlgorithmWithDebugInfosInRecoveryCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InRecovery", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).InRecovery)) return &MockSendAlgorithmWithDebugInfosInRecoveryCall{Call: call} } // MockSendAlgorithmWithDebugInfosInRecoveryCall wrap *gomock.Call type MockSendAlgorithmWithDebugInfosInRecoveryCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSendAlgorithmWithDebugInfosInRecoveryCall) Return(arg0 bool) *MockSendAlgorithmWithDebugInfosInRecoveryCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockSendAlgorithmWithDebugInfosInRecoveryCall) Do(f func() bool) *MockSendAlgorithmWithDebugInfosInRecoveryCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSendAlgorithmWithDebugInfosInRecoveryCall) DoAndReturn(f func() bool) *MockSendAlgorithmWithDebugInfosInRecoveryCall { c.Call = c.Call.DoAndReturn(f) return c } // InSlowStart mocks base method. func (m *MockSendAlgorithmWithDebugInfos) InSlowStart() bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "InSlowStart") ret0, _ := ret[0].(bool) return ret0 } // InSlowStart indicates an expected call of InSlowStart. 
func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) InSlowStart() *MockSendAlgorithmWithDebugInfosInSlowStartCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InSlowStart", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).InSlowStart)) return &MockSendAlgorithmWithDebugInfosInSlowStartCall{Call: call} } // MockSendAlgorithmWithDebugInfosInSlowStartCall wrap *gomock.Call type MockSendAlgorithmWithDebugInfosInSlowStartCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSendAlgorithmWithDebugInfosInSlowStartCall) Return(arg0 bool) *MockSendAlgorithmWithDebugInfosInSlowStartCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockSendAlgorithmWithDebugInfosInSlowStartCall) Do(f func() bool) *MockSendAlgorithmWithDebugInfosInSlowStartCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSendAlgorithmWithDebugInfosInSlowStartCall) DoAndReturn(f func() bool) *MockSendAlgorithmWithDebugInfosInSlowStartCall { c.Call = c.Call.DoAndReturn(f) return c } // MaybeExitSlowStart mocks base method. func (m *MockSendAlgorithmWithDebugInfos) MaybeExitSlowStart() { m.ctrl.T.Helper() m.ctrl.Call(m, "MaybeExitSlowStart") } // MaybeExitSlowStart indicates an expected call of MaybeExitSlowStart. 
func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) MaybeExitSlowStart() *MockSendAlgorithmWithDebugInfosMaybeExitSlowStartCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeExitSlowStart", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).MaybeExitSlowStart)) return &MockSendAlgorithmWithDebugInfosMaybeExitSlowStartCall{Call: call} } // MockSendAlgorithmWithDebugInfosMaybeExitSlowStartCall wrap *gomock.Call type MockSendAlgorithmWithDebugInfosMaybeExitSlowStartCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSendAlgorithmWithDebugInfosMaybeExitSlowStartCall) Return() *MockSendAlgorithmWithDebugInfosMaybeExitSlowStartCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockSendAlgorithmWithDebugInfosMaybeExitSlowStartCall) Do(f func()) *MockSendAlgorithmWithDebugInfosMaybeExitSlowStartCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSendAlgorithmWithDebugInfosMaybeExitSlowStartCall) DoAndReturn(f func()) *MockSendAlgorithmWithDebugInfosMaybeExitSlowStartCall { c.Call = c.Call.DoAndReturn(f) return c } // OnCongestionEvent mocks base method. func (m *MockSendAlgorithmWithDebugInfos) OnCongestionEvent(number protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) { m.ctrl.T.Helper() m.ctrl.Call(m, "OnCongestionEvent", number, lostBytes, priorInFlight) } // OnCongestionEvent indicates an expected call of OnCongestionEvent. 
func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnCongestionEvent(number, lostBytes, priorInFlight any) *MockSendAlgorithmWithDebugInfosOnCongestionEventCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnCongestionEvent", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnCongestionEvent), number, lostBytes, priorInFlight) return &MockSendAlgorithmWithDebugInfosOnCongestionEventCall{Call: call} } // MockSendAlgorithmWithDebugInfosOnCongestionEventCall wrap *gomock.Call type MockSendAlgorithmWithDebugInfosOnCongestionEventCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSendAlgorithmWithDebugInfosOnCongestionEventCall) Return() *MockSendAlgorithmWithDebugInfosOnCongestionEventCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockSendAlgorithmWithDebugInfosOnCongestionEventCall) Do(f func(protocol.PacketNumber, protocol.ByteCount, protocol.ByteCount)) *MockSendAlgorithmWithDebugInfosOnCongestionEventCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSendAlgorithmWithDebugInfosOnCongestionEventCall) DoAndReturn(f func(protocol.PacketNumber, protocol.ByteCount, protocol.ByteCount)) *MockSendAlgorithmWithDebugInfosOnCongestionEventCall { c.Call = c.Call.DoAndReturn(f) return c } // OnPacketAcked mocks base method. func (m *MockSendAlgorithmWithDebugInfos) OnPacketAcked(number protocol.PacketNumber, ackedBytes, priorInFlight protocol.ByteCount, eventTime monotime.Time) { m.ctrl.T.Helper() m.ctrl.Call(m, "OnPacketAcked", number, ackedBytes, priorInFlight, eventTime) } // OnPacketAcked indicates an expected call of OnPacketAcked. 
func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketAcked(number, ackedBytes, priorInFlight, eventTime any) *MockSendAlgorithmWithDebugInfosOnPacketAckedCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketAcked", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketAcked), number, ackedBytes, priorInFlight, eventTime) return &MockSendAlgorithmWithDebugInfosOnPacketAckedCall{Call: call} } // MockSendAlgorithmWithDebugInfosOnPacketAckedCall wrap *gomock.Call type MockSendAlgorithmWithDebugInfosOnPacketAckedCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSendAlgorithmWithDebugInfosOnPacketAckedCall) Return() *MockSendAlgorithmWithDebugInfosOnPacketAckedCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockSendAlgorithmWithDebugInfosOnPacketAckedCall) Do(f func(protocol.PacketNumber, protocol.ByteCount, protocol.ByteCount, monotime.Time)) *MockSendAlgorithmWithDebugInfosOnPacketAckedCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSendAlgorithmWithDebugInfosOnPacketAckedCall) DoAndReturn(f func(protocol.PacketNumber, protocol.ByteCount, protocol.ByteCount, monotime.Time)) *MockSendAlgorithmWithDebugInfosOnPacketAckedCall { c.Call = c.Call.DoAndReturn(f) return c } // OnPacketSent mocks base method. func (m *MockSendAlgorithmWithDebugInfos) OnPacketSent(sentTime monotime.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) { m.ctrl.T.Helper() m.ctrl.Call(m, "OnPacketSent", sentTime, bytesInFlight, packetNumber, bytes, isRetransmittable) } // OnPacketSent indicates an expected call of OnPacketSent. 
func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketSent(sentTime, bytesInFlight, packetNumber, bytes, isRetransmittable any) *MockSendAlgorithmWithDebugInfosOnPacketSentCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketSent", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketSent), sentTime, bytesInFlight, packetNumber, bytes, isRetransmittable) return &MockSendAlgorithmWithDebugInfosOnPacketSentCall{Call: call} } // MockSendAlgorithmWithDebugInfosOnPacketSentCall wrap *gomock.Call type MockSendAlgorithmWithDebugInfosOnPacketSentCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSendAlgorithmWithDebugInfosOnPacketSentCall) Return() *MockSendAlgorithmWithDebugInfosOnPacketSentCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockSendAlgorithmWithDebugInfosOnPacketSentCall) Do(f func(monotime.Time, protocol.ByteCount, protocol.PacketNumber, protocol.ByteCount, bool)) *MockSendAlgorithmWithDebugInfosOnPacketSentCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSendAlgorithmWithDebugInfosOnPacketSentCall) DoAndReturn(f func(monotime.Time, protocol.ByteCount, protocol.PacketNumber, protocol.ByteCount, bool)) *MockSendAlgorithmWithDebugInfosOnPacketSentCall { c.Call = c.Call.DoAndReturn(f) return c } // OnRetransmissionTimeout mocks base method. func (m *MockSendAlgorithmWithDebugInfos) OnRetransmissionTimeout(packetsRetransmitted bool) { m.ctrl.T.Helper() m.ctrl.Call(m, "OnRetransmissionTimeout", packetsRetransmitted) } // OnRetransmissionTimeout indicates an expected call of OnRetransmissionTimeout. 
func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnRetransmissionTimeout(packetsRetransmitted any) *MockSendAlgorithmWithDebugInfosOnRetransmissionTimeoutCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnRetransmissionTimeout", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnRetransmissionTimeout), packetsRetransmitted) return &MockSendAlgorithmWithDebugInfosOnRetransmissionTimeoutCall{Call: call} } // MockSendAlgorithmWithDebugInfosOnRetransmissionTimeoutCall wrap *gomock.Call type MockSendAlgorithmWithDebugInfosOnRetransmissionTimeoutCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSendAlgorithmWithDebugInfosOnRetransmissionTimeoutCall) Return() *MockSendAlgorithmWithDebugInfosOnRetransmissionTimeoutCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockSendAlgorithmWithDebugInfosOnRetransmissionTimeoutCall) Do(f func(bool)) *MockSendAlgorithmWithDebugInfosOnRetransmissionTimeoutCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSendAlgorithmWithDebugInfosOnRetransmissionTimeoutCall) DoAndReturn(f func(bool)) *MockSendAlgorithmWithDebugInfosOnRetransmissionTimeoutCall { c.Call = c.Call.DoAndReturn(f) return c } // SetMaxDatagramSize mocks base method. func (m *MockSendAlgorithmWithDebugInfos) SetMaxDatagramSize(arg0 protocol.ByteCount) { m.ctrl.T.Helper() m.ctrl.Call(m, "SetMaxDatagramSize", arg0) } // SetMaxDatagramSize indicates an expected call of SetMaxDatagramSize. 
func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) SetMaxDatagramSize(arg0 any) *MockSendAlgorithmWithDebugInfosSetMaxDatagramSizeCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxDatagramSize", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).SetMaxDatagramSize), arg0) return &MockSendAlgorithmWithDebugInfosSetMaxDatagramSizeCall{Call: call} } // MockSendAlgorithmWithDebugInfosSetMaxDatagramSizeCall wrap *gomock.Call type MockSendAlgorithmWithDebugInfosSetMaxDatagramSizeCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSendAlgorithmWithDebugInfosSetMaxDatagramSizeCall) Return() *MockSendAlgorithmWithDebugInfosSetMaxDatagramSizeCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockSendAlgorithmWithDebugInfosSetMaxDatagramSizeCall) Do(f func(protocol.ByteCount)) *MockSendAlgorithmWithDebugInfosSetMaxDatagramSizeCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSendAlgorithmWithDebugInfosSetMaxDatagramSizeCall) DoAndReturn(f func(protocol.ByteCount)) *MockSendAlgorithmWithDebugInfosSetMaxDatagramSizeCall { c.Call = c.Call.DoAndReturn(f) return c } // TimeUntilSend mocks base method. func (m *MockSendAlgorithmWithDebugInfos) TimeUntilSend(bytesInFlight protocol.ByteCount) monotime.Time { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "TimeUntilSend", bytesInFlight) ret0, _ := ret[0].(monotime.Time) return ret0 } // TimeUntilSend indicates an expected call of TimeUntilSend. 
func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) TimeUntilSend(bytesInFlight any) *MockSendAlgorithmWithDebugInfosTimeUntilSendCall {
	mr.mock.ctrl.T.Helper()
	call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeUntilSend", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).TimeUntilSend), bytesInFlight)
	return &MockSendAlgorithmWithDebugInfosTimeUntilSendCall{Call: call}
}

// MockSendAlgorithmWithDebugInfosTimeUntilSendCall embeds the *gomock.Call
// recorded for TimeUntilSend and re-exposes Return, Do and DoAndReturn with
// signatures typed for that method.
type MockSendAlgorithmWithDebugInfosTimeUntilSendCall struct {
	*gomock.Call
}

// Return forwards arg0 to (*gomock.Call).Return and hands back the typed wrapper
// so configuration can be chained.
func (c *MockSendAlgorithmWithDebugInfosTimeUntilSendCall) Return(arg0 monotime.Time) *MockSendAlgorithmWithDebugInfosTimeUntilSendCall {
	c.Call = c.Call.Return(arg0)
	return c
}

// Do forwards f, typed with the TimeUntilSend signature, to (*gomock.Call).Do.
func (c *MockSendAlgorithmWithDebugInfosTimeUntilSendCall) Do(f func(protocol.ByteCount) monotime.Time) *MockSendAlgorithmWithDebugInfosTimeUntilSendCall {
	c.Call = c.Call.Do(f)
	return c
}

// DoAndReturn forwards f, typed with the TimeUntilSend signature, to
// (*gomock.Call).DoAndReturn.
func (c *MockSendAlgorithmWithDebugInfosTimeUntilSendCall) DoAndReturn(f func(protocol.ByteCount) monotime.Time) *MockSendAlgorithmWithDebugInfosTimeUntilSendCall {
	c.Call = c.Call.DoAndReturn(f)
	return c
}
quic-go-0.59.0/internal/mocks/crypto_setup.go000066400000000000000000000617411513066070600212050ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/quic-go/quic-go/internal/handshake (interfaces: CryptoSetup)
//
// Generated by this command:
//
//	mockgen -typed -build_flags=-tags=gomock -package mocks -destination crypto_setup.go github.com/quic-go/quic-go/internal/handshake CryptoSetup
//
// Package mocks is a generated GoMock package.
package mocks

import (
	context "context"
	reflect "reflect"

	handshake "github.com/quic-go/quic-go/internal/handshake"
	protocol "github.com/quic-go/quic-go/internal/protocol"
	gomock "go.uber.org/mock/gomock"
)

// MockCryptoSetup is a gomock-generated mock of the CryptoSetup interface.
type MockCryptoSetup struct { ctrl *gomock.Controller recorder *MockCryptoSetupMockRecorder isgomock struct{} } // MockCryptoSetupMockRecorder is the mock recorder for MockCryptoSetup. type MockCryptoSetupMockRecorder struct { mock *MockCryptoSetup } // NewMockCryptoSetup creates a new mock instance. func NewMockCryptoSetup(ctrl *gomock.Controller) *MockCryptoSetup { mock := &MockCryptoSetup{ctrl: ctrl} mock.recorder = &MockCryptoSetupMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockCryptoSetup) EXPECT() *MockCryptoSetupMockRecorder { return m.recorder } // ChangeConnectionID mocks base method. func (m *MockCryptoSetup) ChangeConnectionID(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() m.ctrl.Call(m, "ChangeConnectionID", arg0) } // ChangeConnectionID indicates an expected call of ChangeConnectionID. func (mr *MockCryptoSetupMockRecorder) ChangeConnectionID(arg0 any) *MockCryptoSetupChangeConnectionIDCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChangeConnectionID", reflect.TypeOf((*MockCryptoSetup)(nil).ChangeConnectionID), arg0) return &MockCryptoSetupChangeConnectionIDCall{Call: call} } // MockCryptoSetupChangeConnectionIDCall wrap *gomock.Call type MockCryptoSetupChangeConnectionIDCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockCryptoSetupChangeConnectionIDCall) Return() *MockCryptoSetupChangeConnectionIDCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockCryptoSetupChangeConnectionIDCall) Do(f func(protocol.ConnectionID)) *MockCryptoSetupChangeConnectionIDCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockCryptoSetupChangeConnectionIDCall) DoAndReturn(f func(protocol.ConnectionID)) *MockCryptoSetupChangeConnectionIDCall { c.Call = c.Call.DoAndReturn(f) return c } // Close mocks base method. 
func (m *MockCryptoSetup) Close() error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Close") ret0, _ := ret[0].(error) return ret0 } // Close indicates an expected call of Close. func (mr *MockCryptoSetupMockRecorder) Close() *MockCryptoSetupCloseCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCryptoSetup)(nil).Close)) return &MockCryptoSetupCloseCall{Call: call} } // MockCryptoSetupCloseCall wrap *gomock.Call type MockCryptoSetupCloseCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockCryptoSetupCloseCall) Return(arg0 error) *MockCryptoSetupCloseCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockCryptoSetupCloseCall) Do(f func() error) *MockCryptoSetupCloseCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockCryptoSetupCloseCall) DoAndReturn(f func() error) *MockCryptoSetupCloseCall { c.Call = c.Call.DoAndReturn(f) return c } // ConnectionState mocks base method. func (m *MockCryptoSetup) ConnectionState() handshake.ConnectionState { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ConnectionState") ret0, _ := ret[0].(handshake.ConnectionState) return ret0 } // ConnectionState indicates an expected call of ConnectionState. 
func (mr *MockCryptoSetupMockRecorder) ConnectionState() *MockCryptoSetupConnectionStateCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState)) return &MockCryptoSetupConnectionStateCall{Call: call} } // MockCryptoSetupConnectionStateCall wrap *gomock.Call type MockCryptoSetupConnectionStateCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockCryptoSetupConnectionStateCall) Return(arg0 handshake.ConnectionState) *MockCryptoSetupConnectionStateCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockCryptoSetupConnectionStateCall) Do(f func() handshake.ConnectionState) *MockCryptoSetupConnectionStateCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockCryptoSetupConnectionStateCall) DoAndReturn(f func() handshake.ConnectionState) *MockCryptoSetupConnectionStateCall { c.Call = c.Call.DoAndReturn(f) return c } // DiscardInitialKeys mocks base method. func (m *MockCryptoSetup) DiscardInitialKeys() { m.ctrl.T.Helper() m.ctrl.Call(m, "DiscardInitialKeys") } // DiscardInitialKeys indicates an expected call of DiscardInitialKeys. 
func (mr *MockCryptoSetupMockRecorder) DiscardInitialKeys() *MockCryptoSetupDiscardInitialKeysCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DiscardInitialKeys", reflect.TypeOf((*MockCryptoSetup)(nil).DiscardInitialKeys)) return &MockCryptoSetupDiscardInitialKeysCall{Call: call} } // MockCryptoSetupDiscardInitialKeysCall wrap *gomock.Call type MockCryptoSetupDiscardInitialKeysCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockCryptoSetupDiscardInitialKeysCall) Return() *MockCryptoSetupDiscardInitialKeysCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockCryptoSetupDiscardInitialKeysCall) Do(f func()) *MockCryptoSetupDiscardInitialKeysCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockCryptoSetupDiscardInitialKeysCall) DoAndReturn(f func()) *MockCryptoSetupDiscardInitialKeysCall { c.Call = c.Call.DoAndReturn(f) return c } // Get0RTTOpener mocks base method. func (m *MockCryptoSetup) Get0RTTOpener() (handshake.LongHeaderOpener, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Get0RTTOpener") ret0, _ := ret[0].(handshake.LongHeaderOpener) ret1, _ := ret[1].(error) return ret0, ret1 } // Get0RTTOpener indicates an expected call of Get0RTTOpener. 
func (mr *MockCryptoSetupMockRecorder) Get0RTTOpener() *MockCryptoSetupGet0RTTOpenerCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTOpener", reflect.TypeOf((*MockCryptoSetup)(nil).Get0RTTOpener)) return &MockCryptoSetupGet0RTTOpenerCall{Call: call} } // MockCryptoSetupGet0RTTOpenerCall wrap *gomock.Call type MockCryptoSetupGet0RTTOpenerCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockCryptoSetupGet0RTTOpenerCall) Return(arg0 handshake.LongHeaderOpener, arg1 error) *MockCryptoSetupGet0RTTOpenerCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockCryptoSetupGet0RTTOpenerCall) Do(f func() (handshake.LongHeaderOpener, error)) *MockCryptoSetupGet0RTTOpenerCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockCryptoSetupGet0RTTOpenerCall) DoAndReturn(f func() (handshake.LongHeaderOpener, error)) *MockCryptoSetupGet0RTTOpenerCall { c.Call = c.Call.DoAndReturn(f) return c } // Get0RTTSealer mocks base method. func (m *MockCryptoSetup) Get0RTTSealer() (handshake.LongHeaderSealer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Get0RTTSealer") ret0, _ := ret[0].(handshake.LongHeaderSealer) ret1, _ := ret[1].(error) return ret0, ret1 } // Get0RTTSealer indicates an expected call of Get0RTTSealer. 
func (mr *MockCryptoSetupMockRecorder) Get0RTTSealer() *MockCryptoSetupGet0RTTSealerCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTSealer", reflect.TypeOf((*MockCryptoSetup)(nil).Get0RTTSealer)) return &MockCryptoSetupGet0RTTSealerCall{Call: call} } // MockCryptoSetupGet0RTTSealerCall wrap *gomock.Call type MockCryptoSetupGet0RTTSealerCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockCryptoSetupGet0RTTSealerCall) Return(arg0 handshake.LongHeaderSealer, arg1 error) *MockCryptoSetupGet0RTTSealerCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockCryptoSetupGet0RTTSealerCall) Do(f func() (handshake.LongHeaderSealer, error)) *MockCryptoSetupGet0RTTSealerCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockCryptoSetupGet0RTTSealerCall) DoAndReturn(f func() (handshake.LongHeaderSealer, error)) *MockCryptoSetupGet0RTTSealerCall { c.Call = c.Call.DoAndReturn(f) return c } // Get1RTTOpener mocks base method. func (m *MockCryptoSetup) Get1RTTOpener() (handshake.ShortHeaderOpener, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Get1RTTOpener") ret0, _ := ret[0].(handshake.ShortHeaderOpener) ret1, _ := ret[1].(error) return ret0, ret1 } // Get1RTTOpener indicates an expected call of Get1RTTOpener. 
func (mr *MockCryptoSetupMockRecorder) Get1RTTOpener() *MockCryptoSetupGet1RTTOpenerCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTOpener", reflect.TypeOf((*MockCryptoSetup)(nil).Get1RTTOpener)) return &MockCryptoSetupGet1RTTOpenerCall{Call: call} } // MockCryptoSetupGet1RTTOpenerCall wrap *gomock.Call type MockCryptoSetupGet1RTTOpenerCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockCryptoSetupGet1RTTOpenerCall) Return(arg0 handshake.ShortHeaderOpener, arg1 error) *MockCryptoSetupGet1RTTOpenerCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockCryptoSetupGet1RTTOpenerCall) Do(f func() (handshake.ShortHeaderOpener, error)) *MockCryptoSetupGet1RTTOpenerCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockCryptoSetupGet1RTTOpenerCall) DoAndReturn(f func() (handshake.ShortHeaderOpener, error)) *MockCryptoSetupGet1RTTOpenerCall { c.Call = c.Call.DoAndReturn(f) return c } // Get1RTTSealer mocks base method. func (m *MockCryptoSetup) Get1RTTSealer() (handshake.ShortHeaderSealer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Get1RTTSealer") ret0, _ := ret[0].(handshake.ShortHeaderSealer) ret1, _ := ret[1].(error) return ret0, ret1 } // Get1RTTSealer indicates an expected call of Get1RTTSealer. 
func (mr *MockCryptoSetupMockRecorder) Get1RTTSealer() *MockCryptoSetupGet1RTTSealerCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTSealer", reflect.TypeOf((*MockCryptoSetup)(nil).Get1RTTSealer)) return &MockCryptoSetupGet1RTTSealerCall{Call: call} } // MockCryptoSetupGet1RTTSealerCall wrap *gomock.Call type MockCryptoSetupGet1RTTSealerCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockCryptoSetupGet1RTTSealerCall) Return(arg0 handshake.ShortHeaderSealer, arg1 error) *MockCryptoSetupGet1RTTSealerCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockCryptoSetupGet1RTTSealerCall) Do(f func() (handshake.ShortHeaderSealer, error)) *MockCryptoSetupGet1RTTSealerCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockCryptoSetupGet1RTTSealerCall) DoAndReturn(f func() (handshake.ShortHeaderSealer, error)) *MockCryptoSetupGet1RTTSealerCall { c.Call = c.Call.DoAndReturn(f) return c } // GetHandshakeOpener mocks base method. func (m *MockCryptoSetup) GetHandshakeOpener() (handshake.LongHeaderOpener, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetHandshakeOpener") ret0, _ := ret[0].(handshake.LongHeaderOpener) ret1, _ := ret[1].(error) return ret0, ret1 } // GetHandshakeOpener indicates an expected call of GetHandshakeOpener. 
func (mr *MockCryptoSetupMockRecorder) GetHandshakeOpener() *MockCryptoSetupGetHandshakeOpenerCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetHandshakeOpener)) return &MockCryptoSetupGetHandshakeOpenerCall{Call: call} } // MockCryptoSetupGetHandshakeOpenerCall wrap *gomock.Call type MockCryptoSetupGetHandshakeOpenerCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockCryptoSetupGetHandshakeOpenerCall) Return(arg0 handshake.LongHeaderOpener, arg1 error) *MockCryptoSetupGetHandshakeOpenerCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockCryptoSetupGetHandshakeOpenerCall) Do(f func() (handshake.LongHeaderOpener, error)) *MockCryptoSetupGetHandshakeOpenerCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockCryptoSetupGetHandshakeOpenerCall) DoAndReturn(f func() (handshake.LongHeaderOpener, error)) *MockCryptoSetupGetHandshakeOpenerCall { c.Call = c.Call.DoAndReturn(f) return c } // GetHandshakeSealer mocks base method. func (m *MockCryptoSetup) GetHandshakeSealer() (handshake.LongHeaderSealer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetHandshakeSealer") ret0, _ := ret[0].(handshake.LongHeaderSealer) ret1, _ := ret[1].(error) return ret0, ret1 } // GetHandshakeSealer indicates an expected call of GetHandshakeSealer. 
func (mr *MockCryptoSetupMockRecorder) GetHandshakeSealer() *MockCryptoSetupGetHandshakeSealerCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetHandshakeSealer)) return &MockCryptoSetupGetHandshakeSealerCall{Call: call} } // MockCryptoSetupGetHandshakeSealerCall wrap *gomock.Call type MockCryptoSetupGetHandshakeSealerCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockCryptoSetupGetHandshakeSealerCall) Return(arg0 handshake.LongHeaderSealer, arg1 error) *MockCryptoSetupGetHandshakeSealerCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockCryptoSetupGetHandshakeSealerCall) Do(f func() (handshake.LongHeaderSealer, error)) *MockCryptoSetupGetHandshakeSealerCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockCryptoSetupGetHandshakeSealerCall) DoAndReturn(f func() (handshake.LongHeaderSealer, error)) *MockCryptoSetupGetHandshakeSealerCall { c.Call = c.Call.DoAndReturn(f) return c } // GetInitialOpener mocks base method. func (m *MockCryptoSetup) GetInitialOpener() (handshake.LongHeaderOpener, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetInitialOpener") ret0, _ := ret[0].(handshake.LongHeaderOpener) ret1, _ := ret[1].(error) return ret0, ret1 } // GetInitialOpener indicates an expected call of GetInitialOpener. 
func (mr *MockCryptoSetupMockRecorder) GetInitialOpener() *MockCryptoSetupGetInitialOpenerCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialOpener)) return &MockCryptoSetupGetInitialOpenerCall{Call: call} } // MockCryptoSetupGetInitialOpenerCall wrap *gomock.Call type MockCryptoSetupGetInitialOpenerCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockCryptoSetupGetInitialOpenerCall) Return(arg0 handshake.LongHeaderOpener, arg1 error) *MockCryptoSetupGetInitialOpenerCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockCryptoSetupGetInitialOpenerCall) Do(f func() (handshake.LongHeaderOpener, error)) *MockCryptoSetupGetInitialOpenerCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockCryptoSetupGetInitialOpenerCall) DoAndReturn(f func() (handshake.LongHeaderOpener, error)) *MockCryptoSetupGetInitialOpenerCall { c.Call = c.Call.DoAndReturn(f) return c } // GetInitialSealer mocks base method. func (m *MockCryptoSetup) GetInitialSealer() (handshake.LongHeaderSealer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetInitialSealer") ret0, _ := ret[0].(handshake.LongHeaderSealer) ret1, _ := ret[1].(error) return ret0, ret1 } // GetInitialSealer indicates an expected call of GetInitialSealer. 
func (mr *MockCryptoSetupMockRecorder) GetInitialSealer() *MockCryptoSetupGetInitialSealerCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialSealer)) return &MockCryptoSetupGetInitialSealerCall{Call: call} } // MockCryptoSetupGetInitialSealerCall wrap *gomock.Call type MockCryptoSetupGetInitialSealerCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockCryptoSetupGetInitialSealerCall) Return(arg0 handshake.LongHeaderSealer, arg1 error) *MockCryptoSetupGetInitialSealerCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockCryptoSetupGetInitialSealerCall) Do(f func() (handshake.LongHeaderSealer, error)) *MockCryptoSetupGetInitialSealerCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockCryptoSetupGetInitialSealerCall) DoAndReturn(f func() (handshake.LongHeaderSealer, error)) *MockCryptoSetupGetInitialSealerCall { c.Call = c.Call.DoAndReturn(f) return c } // GetSessionTicket mocks base method. func (m *MockCryptoSetup) GetSessionTicket() ([]byte, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetSessionTicket") ret0, _ := ret[0].([]byte) ret1, _ := ret[1].(error) return ret0, ret1 } // GetSessionTicket indicates an expected call of GetSessionTicket. 
func (mr *MockCryptoSetupMockRecorder) GetSessionTicket() *MockCryptoSetupGetSessionTicketCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionTicket", reflect.TypeOf((*MockCryptoSetup)(nil).GetSessionTicket)) return &MockCryptoSetupGetSessionTicketCall{Call: call} } // MockCryptoSetupGetSessionTicketCall wrap *gomock.Call type MockCryptoSetupGetSessionTicketCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockCryptoSetupGetSessionTicketCall) Return(arg0 []byte, arg1 error) *MockCryptoSetupGetSessionTicketCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockCryptoSetupGetSessionTicketCall) Do(f func() ([]byte, error)) *MockCryptoSetupGetSessionTicketCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockCryptoSetupGetSessionTicketCall) DoAndReturn(f func() ([]byte, error)) *MockCryptoSetupGetSessionTicketCall { c.Call = c.Call.DoAndReturn(f) return c } // HandleMessage mocks base method. func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // HandleMessage indicates an expected call of HandleMessage. 
func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 any) *MockCryptoSetupHandleMessageCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1) return &MockCryptoSetupHandleMessageCall{Call: call} } // MockCryptoSetupHandleMessageCall wrap *gomock.Call type MockCryptoSetupHandleMessageCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockCryptoSetupHandleMessageCall) Return(arg0 error) *MockCryptoSetupHandleMessageCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockCryptoSetupHandleMessageCall) Do(f func([]byte, protocol.EncryptionLevel) error) *MockCryptoSetupHandleMessageCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockCryptoSetupHandleMessageCall) DoAndReturn(f func([]byte, protocol.EncryptionLevel) error) *MockCryptoSetupHandleMessageCall { c.Call = c.Call.DoAndReturn(f) return c } // NextEvent mocks base method. func (m *MockCryptoSetup) NextEvent() handshake.Event { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "NextEvent") ret0, _ := ret[0].(handshake.Event) return ret0 } // NextEvent indicates an expected call of NextEvent. 
func (mr *MockCryptoSetupMockRecorder) NextEvent() *MockCryptoSetupNextEventCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextEvent", reflect.TypeOf((*MockCryptoSetup)(nil).NextEvent)) return &MockCryptoSetupNextEventCall{Call: call} } // MockCryptoSetupNextEventCall wrap *gomock.Call type MockCryptoSetupNextEventCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockCryptoSetupNextEventCall) Return(arg0 handshake.Event) *MockCryptoSetupNextEventCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockCryptoSetupNextEventCall) Do(f func() handshake.Event) *MockCryptoSetupNextEventCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockCryptoSetupNextEventCall) DoAndReturn(f func() handshake.Event) *MockCryptoSetupNextEventCall { c.Call = c.Call.DoAndReturn(f) return c } // SetHandshakeConfirmed mocks base method. func (m *MockCryptoSetup) SetHandshakeConfirmed() { m.ctrl.T.Helper() m.ctrl.Call(m, "SetHandshakeConfirmed") } // SetHandshakeConfirmed indicates an expected call of SetHandshakeConfirmed. 
func (mr *MockCryptoSetupMockRecorder) SetHandshakeConfirmed() *MockCryptoSetupSetHandshakeConfirmedCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockCryptoSetup)(nil).SetHandshakeConfirmed)) return &MockCryptoSetupSetHandshakeConfirmedCall{Call: call} } // MockCryptoSetupSetHandshakeConfirmedCall wrap *gomock.Call type MockCryptoSetupSetHandshakeConfirmedCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockCryptoSetupSetHandshakeConfirmedCall) Return() *MockCryptoSetupSetHandshakeConfirmedCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockCryptoSetupSetHandshakeConfirmedCall) Do(f func()) *MockCryptoSetupSetHandshakeConfirmedCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockCryptoSetupSetHandshakeConfirmedCall) DoAndReturn(f func()) *MockCryptoSetupSetHandshakeConfirmedCall { c.Call = c.Call.DoAndReturn(f) return c } // SetLargest1RTTAcked mocks base method. func (m *MockCryptoSetup) SetLargest1RTTAcked(arg0 protocol.PacketNumber) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetLargest1RTTAcked", arg0) ret0, _ := ret[0].(error) return ret0 } // SetLargest1RTTAcked indicates an expected call of SetLargest1RTTAcked. 
func (mr *MockCryptoSetupMockRecorder) SetLargest1RTTAcked(arg0 any) *MockCryptoSetupSetLargest1RTTAckedCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLargest1RTTAcked", reflect.TypeOf((*MockCryptoSetup)(nil).SetLargest1RTTAcked), arg0) return &MockCryptoSetupSetLargest1RTTAckedCall{Call: call} } // MockCryptoSetupSetLargest1RTTAckedCall wrap *gomock.Call type MockCryptoSetupSetLargest1RTTAckedCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockCryptoSetupSetLargest1RTTAckedCall) Return(arg0 error) *MockCryptoSetupSetLargest1RTTAckedCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockCryptoSetupSetLargest1RTTAckedCall) Do(f func(protocol.PacketNumber) error) *MockCryptoSetupSetLargest1RTTAckedCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockCryptoSetupSetLargest1RTTAckedCall) DoAndReturn(f func(protocol.PacketNumber) error) *MockCryptoSetupSetLargest1RTTAckedCall { c.Call = c.Call.DoAndReturn(f) return c } // StartHandshake mocks base method. func (m *MockCryptoSetup) StartHandshake(arg0 context.Context) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "StartHandshake", arg0) ret0, _ := ret[0].(error) return ret0 } // StartHandshake indicates an expected call of StartHandshake. 
func (mr *MockCryptoSetupMockRecorder) StartHandshake(arg0 any) *MockCryptoSetupStartHandshakeCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).StartHandshake), arg0) return &MockCryptoSetupStartHandshakeCall{Call: call} } // MockCryptoSetupStartHandshakeCall wrap *gomock.Call type MockCryptoSetupStartHandshakeCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockCryptoSetupStartHandshakeCall) Return(arg0 error) *MockCryptoSetupStartHandshakeCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockCryptoSetupStartHandshakeCall) Do(f func(context.Context) error) *MockCryptoSetupStartHandshakeCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockCryptoSetupStartHandshakeCall) DoAndReturn(f func(context.Context) error) *MockCryptoSetupStartHandshakeCall { c.Call = c.Call.DoAndReturn(f) return c } quic-go-0.59.0/internal/mocks/long_header_opener.go000066400000000000000000000132161513066070600222560ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/internal/handshake (interfaces: LongHeaderOpener) // // Generated by this command: // // mockgen -typed -build_flags=-tags=gomock -package mocks -destination long_header_opener.go github.com/quic-go/quic-go/internal/handshake LongHeaderOpener // // Package mocks is a generated GoMock package. package mocks import ( reflect "reflect" protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" ) // MockLongHeaderOpener is a mock of LongHeaderOpener interface. type MockLongHeaderOpener struct { ctrl *gomock.Controller recorder *MockLongHeaderOpenerMockRecorder isgomock struct{} } // MockLongHeaderOpenerMockRecorder is the mock recorder for MockLongHeaderOpener. 
type MockLongHeaderOpenerMockRecorder struct { mock *MockLongHeaderOpener } // NewMockLongHeaderOpener creates a new mock instance. func NewMockLongHeaderOpener(ctrl *gomock.Controller) *MockLongHeaderOpener { mock := &MockLongHeaderOpener{ctrl: ctrl} mock.recorder = &MockLongHeaderOpenerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockLongHeaderOpener) EXPECT() *MockLongHeaderOpenerMockRecorder { return m.recorder } // DecodePacketNumber mocks base method. func (m *MockLongHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DecodePacketNumber", wirePN, wirePNLen) ret0, _ := ret[0].(protocol.PacketNumber) return ret0 } // DecodePacketNumber indicates an expected call of DecodePacketNumber. func (mr *MockLongHeaderOpenerMockRecorder) DecodePacketNumber(wirePN, wirePNLen any) *MockLongHeaderOpenerDecodePacketNumberCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecodePacketNumber), wirePN, wirePNLen) return &MockLongHeaderOpenerDecodePacketNumberCall{Call: call} } // MockLongHeaderOpenerDecodePacketNumberCall wrap *gomock.Call type MockLongHeaderOpenerDecodePacketNumberCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockLongHeaderOpenerDecodePacketNumberCall) Return(arg0 protocol.PacketNumber) *MockLongHeaderOpenerDecodePacketNumberCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockLongHeaderOpenerDecodePacketNumberCall) Do(f func(protocol.PacketNumber, protocol.PacketNumberLen) protocol.PacketNumber) *MockLongHeaderOpenerDecodePacketNumberCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockLongHeaderOpenerDecodePacketNumberCall) DoAndReturn(f 
func(protocol.PacketNumber, protocol.PacketNumberLen) protocol.PacketNumber) *MockLongHeaderOpenerDecodePacketNumberCall { c.Call = c.Call.DoAndReturn(f) return c } // DecryptHeader mocks base method. func (m *MockLongHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { m.ctrl.T.Helper() m.ctrl.Call(m, "DecryptHeader", sample, firstByte, pnBytes) } // DecryptHeader indicates an expected call of DecryptHeader. func (mr *MockLongHeaderOpenerMockRecorder) DecryptHeader(sample, firstByte, pnBytes any) *MockLongHeaderOpenerDecryptHeaderCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecryptHeader), sample, firstByte, pnBytes) return &MockLongHeaderOpenerDecryptHeaderCall{Call: call} } // MockLongHeaderOpenerDecryptHeaderCall wrap *gomock.Call type MockLongHeaderOpenerDecryptHeaderCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockLongHeaderOpenerDecryptHeaderCall) Return() *MockLongHeaderOpenerDecryptHeaderCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockLongHeaderOpenerDecryptHeaderCall) Do(f func([]byte, *byte, []byte)) *MockLongHeaderOpenerDecryptHeaderCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockLongHeaderOpenerDecryptHeaderCall) DoAndReturn(f func([]byte, *byte, []byte)) *MockLongHeaderOpenerDecryptHeaderCall { c.Call = c.Call.DoAndReturn(f) return c } // Open mocks base method. func (m *MockLongHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Open", dst, src, pn, associatedData) ret0, _ := ret[0].([]byte) ret1, _ := ret[1].(error) return ret0, ret1 } // Open indicates an expected call of Open. 
func (mr *MockLongHeaderOpenerMockRecorder) Open(dst, src, pn, associatedData any) *MockLongHeaderOpenerOpenCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockLongHeaderOpener)(nil).Open), dst, src, pn, associatedData) return &MockLongHeaderOpenerOpenCall{Call: call} } // MockLongHeaderOpenerOpenCall wrap *gomock.Call type MockLongHeaderOpenerOpenCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockLongHeaderOpenerOpenCall) Return(arg0 []byte, arg1 error) *MockLongHeaderOpenerOpenCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockLongHeaderOpenerOpenCall) Do(f func([]byte, []byte, protocol.PacketNumber, []byte) ([]byte, error)) *MockLongHeaderOpenerOpenCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockLongHeaderOpenerOpenCall) DoAndReturn(f func([]byte, []byte, protocol.PacketNumber, []byte) ([]byte, error)) *MockLongHeaderOpenerOpenCall { c.Call = c.Call.DoAndReturn(f) return c } quic-go-0.59.0/internal/mocks/mockgen.go000066400000000000000000000025601513066070600200620ustar00rootroot00000000000000//go:build gomock || generate package mocks //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination short_header_sealer.go github.com/quic-go/quic-go/internal/handshake ShortHeaderSealer" //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination short_header_opener.go github.com/quic-go/quic-go/internal/handshake ShortHeaderOpener" //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination long_header_opener.go github.com/quic-go/quic-go/internal/handshake LongHeaderOpener" //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination crypto_setup.go github.com/quic-go/quic-go/internal/handshake CryptoSetup" 
//go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination stream_flow_controller.go github.com/quic-go/quic-go/internal/flowcontrol StreamFlowController" //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination congestion.go github.com/quic-go/quic-go/internal/congestion SendAlgorithmWithDebugInfos" //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/quic-go/quic-go/internal/ackhandler SentPacketHandler" quic-go-0.59.0/internal/mocks/short_header_opener.go000066400000000000000000000136411513066070600224600ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/internal/handshake (interfaces: ShortHeaderOpener) // // Generated by this command: // // mockgen -typed -build_flags=-tags=gomock -package mocks -destination short_header_opener.go github.com/quic-go/quic-go/internal/handshake ShortHeaderOpener // // Package mocks is a generated GoMock package. package mocks import ( reflect "reflect" monotime "github.com/quic-go/quic-go/internal/monotime" protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" ) // MockShortHeaderOpener is a mock of ShortHeaderOpener interface. type MockShortHeaderOpener struct { ctrl *gomock.Controller recorder *MockShortHeaderOpenerMockRecorder isgomock struct{} } // MockShortHeaderOpenerMockRecorder is the mock recorder for MockShortHeaderOpener. type MockShortHeaderOpenerMockRecorder struct { mock *MockShortHeaderOpener } // NewMockShortHeaderOpener creates a new mock instance. func NewMockShortHeaderOpener(ctrl *gomock.Controller) *MockShortHeaderOpener { mock := &MockShortHeaderOpener{ctrl: ctrl} mock.recorder = &MockShortHeaderOpenerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. 
func (m *MockShortHeaderOpener) EXPECT() *MockShortHeaderOpenerMockRecorder { return m.recorder } // DecodePacketNumber mocks base method. func (m *MockShortHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DecodePacketNumber", wirePN, wirePNLen) ret0, _ := ret[0].(protocol.PacketNumber) return ret0 } // DecodePacketNumber indicates an expected call of DecodePacketNumber. func (mr *MockShortHeaderOpenerMockRecorder) DecodePacketNumber(wirePN, wirePNLen any) *MockShortHeaderOpenerDecodePacketNumberCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecodePacketNumber), wirePN, wirePNLen) return &MockShortHeaderOpenerDecodePacketNumberCall{Call: call} } // MockShortHeaderOpenerDecodePacketNumberCall wrap *gomock.Call type MockShortHeaderOpenerDecodePacketNumberCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockShortHeaderOpenerDecodePacketNumberCall) Return(arg0 protocol.PacketNumber) *MockShortHeaderOpenerDecodePacketNumberCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockShortHeaderOpenerDecodePacketNumberCall) Do(f func(protocol.PacketNumber, protocol.PacketNumberLen) protocol.PacketNumber) *MockShortHeaderOpenerDecodePacketNumberCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockShortHeaderOpenerDecodePacketNumberCall) DoAndReturn(f func(protocol.PacketNumber, protocol.PacketNumberLen) protocol.PacketNumber) *MockShortHeaderOpenerDecodePacketNumberCall { c.Call = c.Call.DoAndReturn(f) return c } // DecryptHeader mocks base method. 
func (m *MockShortHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { m.ctrl.T.Helper() m.ctrl.Call(m, "DecryptHeader", sample, firstByte, pnBytes) } // DecryptHeader indicates an expected call of DecryptHeader. func (mr *MockShortHeaderOpenerMockRecorder) DecryptHeader(sample, firstByte, pnBytes any) *MockShortHeaderOpenerDecryptHeaderCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecryptHeader), sample, firstByte, pnBytes) return &MockShortHeaderOpenerDecryptHeaderCall{Call: call} } // MockShortHeaderOpenerDecryptHeaderCall wrap *gomock.Call type MockShortHeaderOpenerDecryptHeaderCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockShortHeaderOpenerDecryptHeaderCall) Return() *MockShortHeaderOpenerDecryptHeaderCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockShortHeaderOpenerDecryptHeaderCall) Do(f func([]byte, *byte, []byte)) *MockShortHeaderOpenerDecryptHeaderCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockShortHeaderOpenerDecryptHeaderCall) DoAndReturn(f func([]byte, *byte, []byte)) *MockShortHeaderOpenerDecryptHeaderCall { c.Call = c.Call.DoAndReturn(f) return c } // Open mocks base method. func (m *MockShortHeaderOpener) Open(dst, src []byte, rcvTime monotime.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Open", dst, src, rcvTime, pn, kp, associatedData) ret0, _ := ret[0].([]byte) ret1, _ := ret[1].(error) return ret0, ret1 } // Open indicates an expected call of Open. 
func (mr *MockShortHeaderOpenerMockRecorder) Open(dst, src, rcvTime, pn, kp, associatedData any) *MockShortHeaderOpenerOpenCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockShortHeaderOpener)(nil).Open), dst, src, rcvTime, pn, kp, associatedData) return &MockShortHeaderOpenerOpenCall{Call: call} } // MockShortHeaderOpenerOpenCall wrap *gomock.Call type MockShortHeaderOpenerOpenCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockShortHeaderOpenerOpenCall) Return(arg0 []byte, arg1 error) *MockShortHeaderOpenerOpenCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockShortHeaderOpenerOpenCall) Do(f func([]byte, []byte, monotime.Time, protocol.PacketNumber, protocol.KeyPhaseBit, []byte) ([]byte, error)) *MockShortHeaderOpenerOpenCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockShortHeaderOpenerOpenCall) DoAndReturn(f func([]byte, []byte, monotime.Time, protocol.PacketNumber, protocol.KeyPhaseBit, []byte) ([]byte, error)) *MockShortHeaderOpenerOpenCall { c.Call = c.Call.DoAndReturn(f) return c } quic-go-0.59.0/internal/mocks/short_header_sealer.go000066400000000000000000000146651513066070600224520ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/internal/handshake (interfaces: ShortHeaderSealer) // // Generated by this command: // // mockgen -typed -build_flags=-tags=gomock -package mocks -destination short_header_sealer.go github.com/quic-go/quic-go/internal/handshake ShortHeaderSealer // // Package mocks is a generated GoMock package. package mocks import ( reflect "reflect" protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" ) // MockShortHeaderSealer is a mock of ShortHeaderSealer interface. 
type MockShortHeaderSealer struct { ctrl *gomock.Controller recorder *MockShortHeaderSealerMockRecorder isgomock struct{} } // MockShortHeaderSealerMockRecorder is the mock recorder for MockShortHeaderSealer. type MockShortHeaderSealerMockRecorder struct { mock *MockShortHeaderSealer } // NewMockShortHeaderSealer creates a new mock instance. func NewMockShortHeaderSealer(ctrl *gomock.Controller) *MockShortHeaderSealer { mock := &MockShortHeaderSealer{ctrl: ctrl} mock.recorder = &MockShortHeaderSealerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockShortHeaderSealer) EXPECT() *MockShortHeaderSealerMockRecorder { return m.recorder } // EncryptHeader mocks base method. func (m *MockShortHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { m.ctrl.T.Helper() m.ctrl.Call(m, "EncryptHeader", sample, firstByte, pnBytes) } // EncryptHeader indicates an expected call of EncryptHeader. func (mr *MockShortHeaderSealerMockRecorder) EncryptHeader(sample, firstByte, pnBytes any) *MockShortHeaderSealerEncryptHeaderCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptHeader", reflect.TypeOf((*MockShortHeaderSealer)(nil).EncryptHeader), sample, firstByte, pnBytes) return &MockShortHeaderSealerEncryptHeaderCall{Call: call} } // MockShortHeaderSealerEncryptHeaderCall wrap *gomock.Call type MockShortHeaderSealerEncryptHeaderCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockShortHeaderSealerEncryptHeaderCall) Return() *MockShortHeaderSealerEncryptHeaderCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockShortHeaderSealerEncryptHeaderCall) Do(f func([]byte, *byte, []byte)) *MockShortHeaderSealerEncryptHeaderCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockShortHeaderSealerEncryptHeaderCall) DoAndReturn(f func([]byte, *byte, []byte)) 
*MockShortHeaderSealerEncryptHeaderCall { c.Call = c.Call.DoAndReturn(f) return c } // KeyPhase mocks base method. func (m *MockShortHeaderSealer) KeyPhase() protocol.KeyPhaseBit { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "KeyPhase") ret0, _ := ret[0].(protocol.KeyPhaseBit) return ret0 } // KeyPhase indicates an expected call of KeyPhase. func (mr *MockShortHeaderSealerMockRecorder) KeyPhase() *MockShortHeaderSealerKeyPhaseCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyPhase", reflect.TypeOf((*MockShortHeaderSealer)(nil).KeyPhase)) return &MockShortHeaderSealerKeyPhaseCall{Call: call} } // MockShortHeaderSealerKeyPhaseCall wrap *gomock.Call type MockShortHeaderSealerKeyPhaseCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockShortHeaderSealerKeyPhaseCall) Return(arg0 protocol.KeyPhaseBit) *MockShortHeaderSealerKeyPhaseCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockShortHeaderSealerKeyPhaseCall) Do(f func() protocol.KeyPhaseBit) *MockShortHeaderSealerKeyPhaseCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockShortHeaderSealerKeyPhaseCall) DoAndReturn(f func() protocol.KeyPhaseBit) *MockShortHeaderSealerKeyPhaseCall { c.Call = c.Call.DoAndReturn(f) return c } // Overhead mocks base method. func (m *MockShortHeaderSealer) Overhead() int { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Overhead") ret0, _ := ret[0].(int) return ret0 } // Overhead indicates an expected call of Overhead. 
func (mr *MockShortHeaderSealerMockRecorder) Overhead() *MockShortHeaderSealerOverheadCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Overhead", reflect.TypeOf((*MockShortHeaderSealer)(nil).Overhead)) return &MockShortHeaderSealerOverheadCall{Call: call} } // MockShortHeaderSealerOverheadCall wrap *gomock.Call type MockShortHeaderSealerOverheadCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockShortHeaderSealerOverheadCall) Return(arg0 int) *MockShortHeaderSealerOverheadCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockShortHeaderSealerOverheadCall) Do(f func() int) *MockShortHeaderSealerOverheadCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockShortHeaderSealerOverheadCall) DoAndReturn(f func() int) *MockShortHeaderSealerOverheadCall { c.Call = c.Call.DoAndReturn(f) return c } // Seal mocks base method. func (m *MockShortHeaderSealer) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Seal", dst, src, packetNumber, associatedData) ret0, _ := ret[0].([]byte) return ret0 } // Seal indicates an expected call of Seal. 
func (mr *MockShortHeaderSealerMockRecorder) Seal(dst, src, packetNumber, associatedData any) *MockShortHeaderSealerSealCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seal", reflect.TypeOf((*MockShortHeaderSealer)(nil).Seal), dst, src, packetNumber, associatedData) return &MockShortHeaderSealerSealCall{Call: call} } // MockShortHeaderSealerSealCall wrap *gomock.Call type MockShortHeaderSealerSealCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockShortHeaderSealerSealCall) Return(arg0 []byte) *MockShortHeaderSealerSealCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockShortHeaderSealerSealCall) Do(f func([]byte, []byte, protocol.PacketNumber, []byte) []byte) *MockShortHeaderSealerSealCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockShortHeaderSealerSealCall) DoAndReturn(f func([]byte, []byte, protocol.PacketNumber, []byte) []byte) *MockShortHeaderSealerSealCall { c.Call = c.Call.DoAndReturn(f) return c } quic-go-0.59.0/internal/mocks/stream_flow_controller.go000066400000000000000000000307511513066070600232270ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go/internal/flowcontrol (interfaces: StreamFlowController) // // Generated by this command: // // mockgen -typed -build_flags=-tags=gomock -package mocks -destination stream_flow_controller.go github.com/quic-go/quic-go/internal/flowcontrol StreamFlowController // // Package mocks is a generated GoMock package. package mocks import ( reflect "reflect" monotime "github.com/quic-go/quic-go/internal/monotime" protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" ) // MockStreamFlowController is a mock of StreamFlowController interface. 
type MockStreamFlowController struct { ctrl *gomock.Controller recorder *MockStreamFlowControllerMockRecorder isgomock struct{} } // MockStreamFlowControllerMockRecorder is the mock recorder for MockStreamFlowController. type MockStreamFlowControllerMockRecorder struct { mock *MockStreamFlowController } // NewMockStreamFlowController creates a new mock instance. func NewMockStreamFlowController(ctrl *gomock.Controller) *MockStreamFlowController { mock := &MockStreamFlowController{ctrl: ctrl} mock.recorder = &MockStreamFlowControllerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockStreamFlowController) EXPECT() *MockStreamFlowControllerMockRecorder { return m.recorder } // Abandon mocks base method. func (m *MockStreamFlowController) Abandon() { m.ctrl.T.Helper() m.ctrl.Call(m, "Abandon") } // Abandon indicates an expected call of Abandon. func (mr *MockStreamFlowControllerMockRecorder) Abandon() *MockStreamFlowControllerAbandonCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Abandon", reflect.TypeOf((*MockStreamFlowController)(nil).Abandon)) return &MockStreamFlowControllerAbandonCall{Call: call} } // MockStreamFlowControllerAbandonCall wrap *gomock.Call type MockStreamFlowControllerAbandonCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockStreamFlowControllerAbandonCall) Return() *MockStreamFlowControllerAbandonCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockStreamFlowControllerAbandonCall) Do(f func()) *MockStreamFlowControllerAbandonCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockStreamFlowControllerAbandonCall) DoAndReturn(f func()) *MockStreamFlowControllerAbandonCall { c.Call = c.Call.DoAndReturn(f) return c } // AddBytesRead mocks base method. 
func (m *MockStreamFlowController) AddBytesRead(arg0 protocol.ByteCount) (bool, bool) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AddBytesRead", arg0) ret0, _ := ret[0].(bool) ret1, _ := ret[1].(bool) return ret0, ret1 } // AddBytesRead indicates an expected call of AddBytesRead. func (mr *MockStreamFlowControllerMockRecorder) AddBytesRead(arg0 any) *MockStreamFlowControllerAddBytesReadCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesRead), arg0) return &MockStreamFlowControllerAddBytesReadCall{Call: call} } // MockStreamFlowControllerAddBytesReadCall wrap *gomock.Call type MockStreamFlowControllerAddBytesReadCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockStreamFlowControllerAddBytesReadCall) Return(hasStreamWindowUpdate, hasConnWindowUpdate bool) *MockStreamFlowControllerAddBytesReadCall { c.Call = c.Call.Return(hasStreamWindowUpdate, hasConnWindowUpdate) return c } // Do rewrite *gomock.Call.Do func (c *MockStreamFlowControllerAddBytesReadCall) Do(f func(protocol.ByteCount) (bool, bool)) *MockStreamFlowControllerAddBytesReadCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockStreamFlowControllerAddBytesReadCall) DoAndReturn(f func(protocol.ByteCount) (bool, bool)) *MockStreamFlowControllerAddBytesReadCall { c.Call = c.Call.DoAndReturn(f) return c } // AddBytesSent mocks base method. func (m *MockStreamFlowController) AddBytesSent(arg0 protocol.ByteCount) { m.ctrl.T.Helper() m.ctrl.Call(m, "AddBytesSent", arg0) } // AddBytesSent indicates an expected call of AddBytesSent. 
func (mr *MockStreamFlowControllerMockRecorder) AddBytesSent(arg0 any) *MockStreamFlowControllerAddBytesSentCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesSent), arg0) return &MockStreamFlowControllerAddBytesSentCall{Call: call} } // MockStreamFlowControllerAddBytesSentCall wrap *gomock.Call type MockStreamFlowControllerAddBytesSentCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockStreamFlowControllerAddBytesSentCall) Return() *MockStreamFlowControllerAddBytesSentCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockStreamFlowControllerAddBytesSentCall) Do(f func(protocol.ByteCount)) *MockStreamFlowControllerAddBytesSentCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockStreamFlowControllerAddBytesSentCall) DoAndReturn(f func(protocol.ByteCount)) *MockStreamFlowControllerAddBytesSentCall { c.Call = c.Call.DoAndReturn(f) return c } // GetWindowUpdate mocks base method. func (m *MockStreamFlowController) GetWindowUpdate(arg0 monotime.Time) protocol.ByteCount { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetWindowUpdate", arg0) ret0, _ := ret[0].(protocol.ByteCount) return ret0 } // GetWindowUpdate indicates an expected call of GetWindowUpdate. 
func (mr *MockStreamFlowControllerMockRecorder) GetWindowUpdate(arg0 any) *MockStreamFlowControllerGetWindowUpdateCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).GetWindowUpdate), arg0) return &MockStreamFlowControllerGetWindowUpdateCall{Call: call} } // MockStreamFlowControllerGetWindowUpdateCall wrap *gomock.Call type MockStreamFlowControllerGetWindowUpdateCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockStreamFlowControllerGetWindowUpdateCall) Return(arg0 protocol.ByteCount) *MockStreamFlowControllerGetWindowUpdateCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockStreamFlowControllerGetWindowUpdateCall) Do(f func(monotime.Time) protocol.ByteCount) *MockStreamFlowControllerGetWindowUpdateCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockStreamFlowControllerGetWindowUpdateCall) DoAndReturn(f func(monotime.Time) protocol.ByteCount) *MockStreamFlowControllerGetWindowUpdateCall { c.Call = c.Call.DoAndReturn(f) return c } // IsNewlyBlocked mocks base method. func (m *MockStreamFlowController) IsNewlyBlocked() bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "IsNewlyBlocked") ret0, _ := ret[0].(bool) return ret0 } // IsNewlyBlocked indicates an expected call of IsNewlyBlocked. 
func (mr *MockStreamFlowControllerMockRecorder) IsNewlyBlocked() *MockStreamFlowControllerIsNewlyBlockedCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockStreamFlowController)(nil).IsNewlyBlocked)) return &MockStreamFlowControllerIsNewlyBlockedCall{Call: call} } // MockStreamFlowControllerIsNewlyBlockedCall wrap *gomock.Call type MockStreamFlowControllerIsNewlyBlockedCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockStreamFlowControllerIsNewlyBlockedCall) Return(arg0 bool) *MockStreamFlowControllerIsNewlyBlockedCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockStreamFlowControllerIsNewlyBlockedCall) Do(f func() bool) *MockStreamFlowControllerIsNewlyBlockedCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockStreamFlowControllerIsNewlyBlockedCall) DoAndReturn(f func() bool) *MockStreamFlowControllerIsNewlyBlockedCall { c.Call = c.Call.DoAndReturn(f) return c } // SendWindowSize mocks base method. func (m *MockStreamFlowController) SendWindowSize() protocol.ByteCount { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SendWindowSize") ret0, _ := ret[0].(protocol.ByteCount) return ret0 } // SendWindowSize indicates an expected call of SendWindowSize. 
func (mr *MockStreamFlowControllerMockRecorder) SendWindowSize() *MockStreamFlowControllerSendWindowSizeCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWindowSize", reflect.TypeOf((*MockStreamFlowController)(nil).SendWindowSize)) return &MockStreamFlowControllerSendWindowSizeCall{Call: call} } // MockStreamFlowControllerSendWindowSizeCall wrap *gomock.Call type MockStreamFlowControllerSendWindowSizeCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockStreamFlowControllerSendWindowSizeCall) Return(arg0 protocol.ByteCount) *MockStreamFlowControllerSendWindowSizeCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockStreamFlowControllerSendWindowSizeCall) Do(f func() protocol.ByteCount) *MockStreamFlowControllerSendWindowSizeCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockStreamFlowControllerSendWindowSizeCall) DoAndReturn(f func() protocol.ByteCount) *MockStreamFlowControllerSendWindowSizeCall { c.Call = c.Call.DoAndReturn(f) return c } // UpdateHighestReceived mocks base method. func (m *MockStreamFlowController) UpdateHighestReceived(offset protocol.ByteCount, final bool, now monotime.Time) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpdateHighestReceived", offset, final, now) ret0, _ := ret[0].(error) return ret0 } // UpdateHighestReceived indicates an expected call of UpdateHighestReceived. 
func (mr *MockStreamFlowControllerMockRecorder) UpdateHighestReceived(offset, final, now any) *MockStreamFlowControllerUpdateHighestReceivedCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHighestReceived", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateHighestReceived), offset, final, now) return &MockStreamFlowControllerUpdateHighestReceivedCall{Call: call} } // MockStreamFlowControllerUpdateHighestReceivedCall wrap *gomock.Call type MockStreamFlowControllerUpdateHighestReceivedCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockStreamFlowControllerUpdateHighestReceivedCall) Return(arg0 error) *MockStreamFlowControllerUpdateHighestReceivedCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockStreamFlowControllerUpdateHighestReceivedCall) Do(f func(protocol.ByteCount, bool, monotime.Time) error) *MockStreamFlowControllerUpdateHighestReceivedCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockStreamFlowControllerUpdateHighestReceivedCall) DoAndReturn(f func(protocol.ByteCount, bool, monotime.Time) error) *MockStreamFlowControllerUpdateHighestReceivedCall { c.Call = c.Call.DoAndReturn(f) return c } // UpdateSendWindow mocks base method. func (m *MockStreamFlowController) UpdateSendWindow(arg0 protocol.ByteCount) bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpdateSendWindow", arg0) ret0, _ := ret[0].(bool) return ret0 } // UpdateSendWindow indicates an expected call of UpdateSendWindow. 
func (mr *MockStreamFlowControllerMockRecorder) UpdateSendWindow(arg0 any) *MockStreamFlowControllerUpdateSendWindowCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateSendWindow), arg0) return &MockStreamFlowControllerUpdateSendWindowCall{Call: call} } // MockStreamFlowControllerUpdateSendWindowCall wrap *gomock.Call type MockStreamFlowControllerUpdateSendWindowCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockStreamFlowControllerUpdateSendWindowCall) Return(updated bool) *MockStreamFlowControllerUpdateSendWindowCall { c.Call = c.Call.Return(updated) return c } // Do rewrite *gomock.Call.Do func (c *MockStreamFlowControllerUpdateSendWindowCall) Do(f func(protocol.ByteCount) bool) *MockStreamFlowControllerUpdateSendWindowCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockStreamFlowControllerUpdateSendWindowCall) DoAndReturn(f func(protocol.ByteCount) bool) *MockStreamFlowControllerUpdateSendWindowCall { c.Call = c.Call.DoAndReturn(f) return c } quic-go-0.59.0/internal/monotime/000077500000000000000000000000001513066070600166205ustar00rootroot00000000000000quic-go-0.59.0/internal/monotime/time.go000066400000000000000000000052261513066070600201120ustar00rootroot00000000000000// Package monotime provides a monotonic time representation that is useful for // measuring elapsed time. // It is designed as a memory optimized drop-in replacement for time.Time, with // a monotime.Time consuming just 8 bytes instead of 24 bytes. package monotime import ( "time" ) // The absolute value doesn't matter, but it should be in the past, // so that every timestamp obtained with Now() is non-zero, // even on systems with low timer resolutions (e.g. Windows). var start = time.Now().Add(-time.Hour) // A Time represents an instant in monotonic time. 
// Times can be compared using the comparison operators, but the specific // value is implementation-dependent and should not be relied upon. // The zero value of Time doesn't have any specific meaning. type Time int64 // Now returns the current monotonic time. func Now() Time { return Time(time.Since(start).Nanoseconds()) } // Sub returns the duration t-t2. If the result exceeds the maximum (or minimum) // value that can be stored in a Duration, the maximum (or minimum) duration // will be returned. // To compute t-d for a duration d, use t.Add(-d). func (t Time) Sub(t2 Time) time.Duration { return time.Duration(t - t2) } // Add returns the time t+d. func (t Time) Add(d time.Duration) Time { return Time(int64(t) + d.Nanoseconds()) } // After reports whether the time instant t is after t2. func (t Time) After(t2 Time) bool { return t > t2 } // Before reports whether the time instant t is before t2. func (t Time) Before(t2 Time) bool { return t < t2 } // IsZero reports whether t represents the zero time instant. func (t Time) IsZero() bool { return t == 0 } // Equal reports whether t and t2 represent the same time instant. func (t Time) Equal(t2 Time) bool { return t == t2 } // ToTime converts the monotonic time to a time.Time value. // The returned time.Time will have the same instant as the monotonic time, // but may be subject to clock adjustments. func (t Time) ToTime() time.Time { if t.IsZero() { return time.Time{} } return start.Add(time.Duration(t)) } // Since returns the time elapsed since t. It is shorthand for Now().Sub(t). func Since(t Time) time.Duration { return Now().Sub(t) } // Until returns the duration until t. // It is shorthand for t.Sub(Now()). // If t is in the past, the returned duration will be negative. func Until(t Time) time.Duration { return time.Duration(t - Now()) } // FromTime converts a time.Time to a monotonic Time. 
// The conversion is relative to the package's start time and may lose // precision if the time.Time is far from the start time. func FromTime(t time.Time) Time { if t.IsZero() { return 0 } return Time(t.Sub(start).Nanoseconds()) } quic-go-0.59.0/internal/monotime/time_test.go000066400000000000000000000032531513066070600211470ustar00rootroot00000000000000package monotime import ( "testing" "time" "github.com/quic-go/quic-go/internal/synctest" "github.com/stretchr/testify/require" ) func TestTimeRelations(t *testing.T) { t1 := Now() require.Equal(t, t1, t1) require.False(t, t1.IsZero()) t2 := t1.Add(time.Second) require.False(t, t1.Equal(t2)) require.False(t, t2.Equal(t1)) require.True(t, t2.After(t1)) require.False(t, t1.After(t2)) require.False(t, t2.Before(t1)) require.Equal(t, t2.Sub(t1), time.Second) require.Equal(t, t1.Sub(t2), -time.Second) } func TestSince(t *testing.T) { synctest.Test(t, func(t *testing.T) { t1 := Now() time.Sleep(time.Second) require.Equal(t, Since(t1), time.Second) require.Equal(t, Now().Sub(t1), time.Second) time.Sleep(time.Minute) require.Equal(t, Since(t1), time.Minute+time.Second) require.Equal(t, Now().Sub(t1), time.Minute+time.Second) }) } func TestUntil(t *testing.T) { synctest.Test(t, func(t *testing.T) { t1 := Now().Add(time.Minute) require.Equal(t, Until(t1), time.Minute) require.Equal(t, t1.Sub(Now()), time.Minute) time.Sleep(15 * time.Second) require.Equal(t, Until(t1), 45*time.Second) require.Equal(t, t1.Sub(Now()), 45*time.Second) }) } func TestConversions(t *testing.T) { t1 := Now() t1Time := t1.ToTime() require.Equal(t, FromTime(t1Time), t1) require.Zero(t, t1Time.Sub(t1.ToTime())) var zeroTime time.Time require.Zero(t, FromTime(zeroTime)) require.Zero(t, FromTime(zeroTime)) var zero Time require.True(t, zero.ToTime().IsZero()) } func BenchmarkNow(b *testing.B) { b.Run("Now", func(b *testing.B) { for b.Loop() { _ = Now() } }) b.Run("time.Now", func(b *testing.B) { for b.Loop() { _ = time.Now() } }) } 
quic-go-0.59.0/internal/protocol/000077500000000000000000000000001513066070600166325ustar00rootroot00000000000000quic-go-0.59.0/internal/protocol/connection_id.go000066400000000000000000000053511513066070600220000ustar00rootroot00000000000000package protocol import ( "crypto/rand" "encoding/hex" "errors" "io" ) var ErrInvalidConnectionIDLen = errors.New("invalid Connection ID length") // An ArbitraryLenConnectionID is a QUIC Connection ID able to represent Connection IDs according to RFC 8999. // Future QUIC versions might allow connection ID lengths up to 255 bytes, while QUIC v1 // restricts the length to 20 bytes. type ArbitraryLenConnectionID []byte func (c ArbitraryLenConnectionID) Len() int { return len(c) } func (c ArbitraryLenConnectionID) Bytes() []byte { return c } func (c ArbitraryLenConnectionID) String() string { if c.Len() == 0 { return "(empty)" } return hex.EncodeToString(c.Bytes()) } const maxConnectionIDLen = 20 // A ConnectionID in QUIC type ConnectionID struct { b [20]byte l uint8 } // GenerateConnectionID generates a connection ID using cryptographic random func GenerateConnectionID(l int) (ConnectionID, error) { var c ConnectionID c.l = uint8(l) _, err := rand.Read(c.b[:l]) return c, err } // ParseConnectionID interprets b as a Connection ID. // It panics if b is longer than 20 bytes. func ParseConnectionID(b []byte) ConnectionID { if len(b) > maxConnectionIDLen { panic("invalid conn id length") } var c ConnectionID c.l = uint8(len(b)) copy(c.b[:c.l], b) return c } // GenerateConnectionIDForInitial generates a connection ID for the Initial packet. // It uses a length randomly chosen between 8 and 20 bytes. 
func GenerateConnectionIDForInitial() (ConnectionID, error) { r := make([]byte, 1) if _, err := rand.Read(r); err != nil { return ConnectionID{}, err } l := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1) return GenerateConnectionID(l) } // ReadConnectionID reads a connection ID of length len from the given io.Reader. // It returns io.EOF if there are not enough bytes to read. func ReadConnectionID(r io.Reader, l int) (ConnectionID, error) { var c ConnectionID if l == 0 { return c, nil } if l > maxConnectionIDLen { return c, ErrInvalidConnectionIDLen } c.l = uint8(l) _, err := io.ReadFull(r, c.b[:l]) if err == io.ErrUnexpectedEOF { return c, io.EOF } return c, err } // Len returns the length of the connection ID in bytes func (c ConnectionID) Len() int { return int(c.l) } // Bytes returns the byte representation func (c ConnectionID) Bytes() []byte { return c.b[:c.l] } func (c ConnectionID) String() string { if c.Len() == 0 { return "(empty)" } return hex.EncodeToString(c.Bytes()) } type DefaultConnectionIDGenerator struct { ConnLen int } func (d *DefaultConnectionIDGenerator) GenerateConnectionID() (ConnectionID, error) { return GenerateConnectionID(d.ConnLen) } func (d *DefaultConnectionIDGenerator) ConnectionIDLen() int { return d.ConnLen } quic-go-0.59.0/internal/protocol/connection_id_test.go000066400000000000000000000045661513066070600230460ustar00rootroot00000000000000package protocol import ( "bytes" "crypto/rand" "io" "testing" "github.com/stretchr/testify/require" ) func TestGenerateRandomConnectionIDs(t *testing.T) { c1, err := GenerateConnectionID(8) require.NoError(t, err) require.NotZero(t, c1) require.Equal(t, 8, c1.Len()) c2, err := GenerateConnectionID(8) require.NoError(t, err) require.NotEqual(t, c1, c2) require.Equal(t, 8, c2.Len()) } func TestGenerateRandomLengthDestinationConnectionIDs(t *testing.T) { var has8ByteConnID, has20ByteConnID bool for i := 0; i < 1000; i++ { c, err := 
GenerateConnectionIDForInitial() require.NoError(t, err) require.GreaterOrEqual(t, c.Len(), 8) require.LessOrEqual(t, c.Len(), 20) if c.Len() == 8 { has8ByteConnID = true } if c.Len() == 20 { has20ByteConnID = true } } require.True(t, has8ByteConnID) require.True(t, has20ByteConnID) } func TestConnectionID(t *testing.T) { buf := bytes.NewBuffer([]byte{0xde, 0xad, 0xbe, 0xef, 0x42}) c, err := ReadConnectionID(buf, 5) require.NoError(t, err) require.Equal(t, []byte{0xde, 0xad, 0xbe, 0xef, 0x42}, c.Bytes()) require.Equal(t, 5, c.Len()) require.Equal(t, "deadbeef42", c.String()) // too few bytes _, err = ReadConnectionID(buf, 10) require.Equal(t, io.EOF, err) // zero length c2, err := ReadConnectionID(buf, 0) require.NoError(t, err) require.Zero(t, c2.Len()) // connection ID can have a length of a maximum of 20 bytes buf2 := bytes.NewBuffer(make([]byte, 21)) _, err = ReadConnectionID(buf2, 21) require.Equal(t, ErrInvalidConnectionIDLen, err) } func TestConnectionIDZeroValue(t *testing.T) { var c ConnectionID require.Zero(t, c.Len()) require.Empty(t, c.Bytes()) require.Equal(t, "(empty)", (ConnectionID{}).String()) } // The string representation of a connection ID is used in qlog, so it should be fast. 
func BenchmarkConnectionIDStringer(b *testing.B) { c := ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42}) b.ReportAllocs() for b.Loop() { _ = c.String() } } func TestArbitraryLenConnectionID(t *testing.T) { b := make([]byte, 42) rand.Read(b) c := ArbitraryLenConnectionID(b) require.Equal(t, b, c.Bytes()) require.Equal(t, 42, c.Len()) } func TestArbitraryLenConnectionIDStringer(t *testing.T) { require.Equal(t, "(empty)", (ArbitraryLenConnectionID{}).String()) c := ArbitraryLenConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42}) require.Equal(t, "deadbeef42", c.String()) } quic-go-0.59.0/internal/protocol/encryption_level.go000066400000000000000000000030311513066070600225370ustar00rootroot00000000000000package protocol import ( "crypto/tls" "fmt" ) // EncryptionLevel is the encryption level // Default value is Unencrypted type EncryptionLevel uint8 const ( // EncryptionInitial is the Initial encryption level EncryptionInitial EncryptionLevel = 1 + iota // EncryptionHandshake is the Handshake encryption level EncryptionHandshake // Encryption0RTT is the 0-RTT encryption level Encryption0RTT // Encryption1RTT is the 1-RTT encryption level Encryption1RTT ) func (e EncryptionLevel) String() string { switch e { case EncryptionInitial: return "Initial" case EncryptionHandshake: return "Handshake" case Encryption0RTT: return "0-RTT" case Encryption1RTT: return "1-RTT" } return "unknown" } func (e EncryptionLevel) ToTLSEncryptionLevel() tls.QUICEncryptionLevel { switch e { case EncryptionInitial: return tls.QUICEncryptionLevelInitial case EncryptionHandshake: return tls.QUICEncryptionLevelHandshake case Encryption1RTT: return tls.QUICEncryptionLevelApplication case Encryption0RTT: return tls.QUICEncryptionLevelEarly default: panic(fmt.Sprintf("unexpected encryption level: %s", e)) } } func FromTLSEncryptionLevel(e tls.QUICEncryptionLevel) EncryptionLevel { switch e { case tls.QUICEncryptionLevelInitial: return EncryptionInitial case tls.QUICEncryptionLevelHandshake: 
return EncryptionHandshake case tls.QUICEncryptionLevelApplication: return Encryption1RTT case tls.QUICEncryptionLevelEarly: return Encryption0RTT default: panic(fmt.Sprintf("unexpect encryption level: %s", e)) } } quic-go-0.59.0/internal/protocol/encryption_level_test.go000066400000000000000000000023171513066070600236040ustar00rootroot00000000000000package protocol import ( "crypto/tls" "testing" "github.com/stretchr/testify/require" ) func TestEncryptionLevelNonZeroValue(t *testing.T) { require.NotZero(t, EncryptionInitial*EncryptionHandshake*Encryption0RTT*Encryption1RTT) } func TestEncryptionLevelConversion(t *testing.T) { testCases := []struct { quicLevel EncryptionLevel tlsLevel tls.QUICEncryptionLevel }{ {EncryptionInitial, tls.QUICEncryptionLevelInitial}, {EncryptionHandshake, tls.QUICEncryptionLevelHandshake}, {Encryption1RTT, tls.QUICEncryptionLevelApplication}, {Encryption0RTT, tls.QUICEncryptionLevelEarly}, } for _, tc := range testCases { t.Run(tc.quicLevel.String(), func(t *testing.T) { // conversion from QUIC to TLS encryption level require.Equal(t, tc.tlsLevel, tc.quicLevel.ToTLSEncryptionLevel()) // conversion from TLS to QUIC encryption level require.Equal(t, tc.quicLevel, FromTLSEncryptionLevel(tc.tlsLevel)) }) } } func TestEncryptionLevelStringRepresentation(t *testing.T) { require.Equal(t, "Initial", EncryptionInitial.String()) require.Equal(t, "Handshake", EncryptionHandshake.String()) require.Equal(t, "0-RTT", Encryption0RTT.String()) require.Equal(t, "1-RTT", Encryption1RTT.String()) } quic-go-0.59.0/internal/protocol/key_phase.go000066400000000000000000000011621513066070600211310ustar00rootroot00000000000000package protocol // KeyPhase is the key phase type KeyPhase uint64 // Bit determines the key phase bit func (p KeyPhase) Bit() KeyPhaseBit { if p%2 == 0 { return KeyPhaseZero } return KeyPhaseOne } // KeyPhaseBit is the key phase bit type KeyPhaseBit uint8 const ( // KeyPhaseUndefined is an undefined key phase KeyPhaseUndefined 
KeyPhaseBit = iota // KeyPhaseZero is key phase 0 KeyPhaseZero // KeyPhaseOne is key phase 1 KeyPhaseOne ) func (p KeyPhaseBit) String() string { //nolint:exhaustive switch p { case KeyPhaseZero: return "0" case KeyPhaseOne: return "1" default: return "undefined" } } quic-go-0.59.0/internal/protocol/key_phase_test.go000066400000000000000000000012541513066070600221720ustar00rootroot00000000000000package protocol import ( "testing" "github.com/stretchr/testify/require" ) func TestKeyPhaseBitDefaultValue(t *testing.T) { var k KeyPhaseBit require.Equal(t, KeyPhaseUndefined, k) } func TestKeyPhaseStringRepresentation(t *testing.T) { require.Equal(t, "0", KeyPhaseZero.String()) require.Equal(t, "1", KeyPhaseOne.String()) } func TestKeyPhaseToBit(t *testing.T) { require.Equal(t, KeyPhaseZero, KeyPhase(0).Bit()) require.Equal(t, KeyPhaseZero, KeyPhase(2).Bit()) require.Equal(t, KeyPhaseZero, KeyPhase(4).Bit()) require.Equal(t, KeyPhaseOne, KeyPhase(1).Bit()) require.Equal(t, KeyPhaseOne, KeyPhase(3).Bit()) require.Equal(t, KeyPhaseOne, KeyPhase(5).Bit()) } quic-go-0.59.0/internal/protocol/packet_number.go000066400000000000000000000035341513066070600220050ustar00rootroot00000000000000package protocol // A PacketNumber in QUIC type PacketNumber int64 // InvalidPacketNumber is a packet number that is never sent. // In QUIC, 0 is a valid packet number. 
const InvalidPacketNumber PacketNumber = -1 // PacketNumberLen is the length of the packet number in bytes type PacketNumberLen uint8 const ( // PacketNumberLen1 is a packet number length of 1 byte PacketNumberLen1 PacketNumberLen = 1 // PacketNumberLen2 is a packet number length of 2 bytes PacketNumberLen2 PacketNumberLen = 2 // PacketNumberLen3 is a packet number length of 3 bytes PacketNumberLen3 PacketNumberLen = 3 // PacketNumberLen4 is a packet number length of 4 bytes PacketNumberLen4 PacketNumberLen = 4 ) // DecodePacketNumber calculates the packet number based its length and the last seen packet number // This function is taken from https://www.rfc-editor.org/rfc/rfc9000.html#section-a.3. func DecodePacketNumber(length PacketNumberLen, largest PacketNumber, truncated PacketNumber) PacketNumber { expected := largest + 1 win := PacketNumber(1 << (length * 8)) hwin := win / 2 mask := win - 1 candidate := (expected & ^mask) | truncated if candidate <= expected-hwin && candidate < 1<<62-win { return candidate + win } if candidate > expected+hwin && candidate >= win { return candidate - win } return candidate } // PacketNumberLengthForHeader gets the length of the packet number for the public header // it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances func PacketNumberLengthForHeader(pn, largestAcked PacketNumber) PacketNumberLen { var numUnacked PacketNumber if largestAcked == InvalidPacketNumber { numUnacked = pn + 1 } else { numUnacked = pn - largestAcked } if numUnacked < 1<<(16-1) { return PacketNumberLen2 } if numUnacked < 1<<(24-1) { return PacketNumberLen3 } return PacketNumberLen4 } quic-go-0.59.0/internal/protocol/packet_number_test.go000066400000000000000000000046341513066070600230460ustar00rootroot00000000000000package protocol import ( "testing" "github.com/stretchr/testify/require" ) func TestInvalidPacketNumberIsSmallerThanAllValidPacketNumbers(t *testing.T) { require.Less(t, InvalidPacketNumber, 
PacketNumber(0)) } func TestPacketNumberLenHasCorrectValue(t *testing.T) { require.EqualValues(t, 1, PacketNumberLen1) require.EqualValues(t, 2, PacketNumberLen2) require.EqualValues(t, 3, PacketNumberLen3) require.EqualValues(t, 4, PacketNumberLen4) } func TestDecodePacketNumber(t *testing.T) { require.Equal(t, PacketNumber(255), DecodePacketNumber(PacketNumberLen1, 10, 255)) require.Equal(t, PacketNumber(0), DecodePacketNumber(PacketNumberLen1, 10, 0)) require.Equal(t, PacketNumber(256), DecodePacketNumber(PacketNumberLen1, 127, 0)) require.Equal(t, PacketNumber(256), DecodePacketNumber(PacketNumberLen1, 128, 0)) require.Equal(t, PacketNumber(256), DecodePacketNumber(PacketNumberLen1, 256+126, 0)) require.Equal(t, PacketNumber(512), DecodePacketNumber(PacketNumberLen1, 256+127, 0)) require.Equal(t, PacketNumber(0xffff), DecodePacketNumber(PacketNumberLen2, 0xffff, 0xffff)) require.Equal(t, PacketNumber(0xffff), DecodePacketNumber(PacketNumberLen2, 0xffff+1, 0xffff)) // example from https://www.rfc-editor.org/rfc/rfc9000.html#section-a.3 require.Equal(t, PacketNumber(0xa82f9b32), DecodePacketNumber(PacketNumberLen2, 0xa82f30ea, 0x9b32)) } func TestPacketNumberLengthForHeader(t *testing.T) { require.Equal(t, PacketNumberLen2, PacketNumberLengthForHeader(1, InvalidPacketNumber)) require.Equal(t, PacketNumberLen2, PacketNumberLengthForHeader(1<<15-2, InvalidPacketNumber)) require.Equal(t, PacketNumberLen3, PacketNumberLengthForHeader(1<<15-1, InvalidPacketNumber)) require.Equal(t, PacketNumberLen3, PacketNumberLengthForHeader(1<<23-2, InvalidPacketNumber)) require.Equal(t, PacketNumberLen4, PacketNumberLengthForHeader(1<<23-1, InvalidPacketNumber)) require.Equal(t, PacketNumberLen2, PacketNumberLengthForHeader(1<<15+9, 10)) require.Equal(t, PacketNumberLen3, PacketNumberLengthForHeader(1<<15+10, 10)) require.Equal(t, PacketNumberLen3, PacketNumberLengthForHeader(1<<23+99, 100)) require.Equal(t, PacketNumberLen4, PacketNumberLengthForHeader(1<<23+100, 100)) // 
examples from https://www.rfc-editor.org/rfc/rfc9000.html#section-a.2 require.Equal(t, PacketNumberLen2, PacketNumberLengthForHeader(0xac5c02, 0xabe8b3)) require.Equal(t, PacketNumberLen3, PacketNumberLengthForHeader(0xace8fe, 0xabe8b3)) } quic-go-0.59.0/internal/protocol/params.go000066400000000000000000000203021513066070600204410ustar00rootroot00000000000000package protocol import "time" // DesiredReceiveBufferSize is the kernel UDP receive buffer size that we'd like to use. const DesiredReceiveBufferSize = (1 << 20) * 7 // 7 MB // DesiredSendBufferSize is the kernel UDP send buffer size that we'd like to use. const DesiredSendBufferSize = (1 << 20) * 7 // 7 MB // InitialPacketSize is the initial (before Path MTU discovery) maximum packet size used. const InitialPacketSize = 1280 // MaxCongestionWindowPackets is the maximum congestion window in packet. const MaxCongestionWindowPackets = 10000 // MaxUndecryptablePackets limits the number of undecryptable packets that are queued in the connection. 
const MaxUndecryptablePackets = 32 // ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window // This is the value that Chromium is using const ConnectionFlowControlMultiplier = 1.5 // DefaultInitialMaxStreamData is the default initial stream-level flow control window for receiving data const DefaultInitialMaxStreamData = (1 << 10) * 512 // 512 kb // DefaultInitialMaxData is the connection-level flow control window for receiving data const DefaultInitialMaxData = ConnectionFlowControlMultiplier * DefaultInitialMaxStreamData // DefaultMaxReceiveStreamFlowControlWindow is the default maximum stream-level flow control window for receiving data const DefaultMaxReceiveStreamFlowControlWindow = 6 * (1 << 20) // 6 MB // DefaultMaxReceiveConnectionFlowControlWindow is the default connection-level flow control window for receiving data const DefaultMaxReceiveConnectionFlowControlWindow = 15 * (1 << 20) // 15 MB // WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client const WindowUpdateThreshold = 0.25 // DefaultMaxIncomingStreams is the maximum number of streams that a peer may open const DefaultMaxIncomingStreams = 100 // DefaultMaxIncomingUniStreams is the maximum number of unidirectional streams that a peer may open const DefaultMaxIncomingUniStreams = 100 // MaxServerUnprocessedPackets is the max number of packets stored in the server that are not yet processed. const MaxServerUnprocessedPackets = 1024 // MaxConnUnprocessedPackets is the max number of packets stored in each connection that are not yet processed. const MaxConnUnprocessedPackets = 256 // SkipPacketInitialPeriod is the initial period length used for packet number skipping to prevent an Optimistic ACK attack. // Every time a packet number is skipped, the period is doubled, up to SkipPacketMaxPeriod. 
const SkipPacketInitialPeriod PacketNumber = 256 // SkipPacketMaxPeriod is the maximum period length used for packet number skipping. const SkipPacketMaxPeriod PacketNumber = 128 * 1024 // MaxAcceptQueueSize is the maximum number of connections that the server queues for accepting. // If the queue is full, new connection attempts will be rejected. const MaxAcceptQueueSize = 32 // TokenValidity is the duration that a (non-retry) token is considered valid const TokenValidity = 24 * time.Hour // MaxOutstandingSentPackets is maximum number of packets saved for retransmission. // When reached, it imposes a soft limit on sending new packets: // Sending ACKs and retransmission is still allowed, but now new regular packets can be sent. const MaxOutstandingSentPackets = 2 * MaxCongestionWindowPackets // MaxTrackedSentPackets is maximum number of sent packets saved for retransmission. // When reached, no more packets will be sent. // This value *must* be larger than MaxOutstandingSentPackets. const MaxTrackedSentPackets = MaxOutstandingSentPackets * 5 / 4 // MaxNonAckElicitingAcks is the maximum number of packets containing an ACK, // but no ack-eliciting frames, that we send in a row const MaxNonAckElicitingAcks = 19 // MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames // prevents DoS attacks against the streamFrameSorter const MaxStreamFrameSorterGaps = 1000 // MinStreamFrameBufferSize is the minimum data length of a received STREAM frame // that we use the buffer for. This protects against a DoS where an attacker would send us // very small STREAM frames to consume a lot of memory. const MinStreamFrameBufferSize = 128 // MinCoalescedPacketSize is the minimum size of a coalesced packet that we pack. // If a packet has less than this number of bytes, we won't coalesce any more packets onto it. const MinCoalescedPacketSize = 128 // MaxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams. 
// This limits the size of the ClientHello and Certificates that can be received. const MaxCryptoStreamOffset = 16 * (1 << 10) // MinRemoteIdleTimeout is the minimum value that we accept for the remote idle timeout const MinRemoteIdleTimeout = 5 * time.Second // DefaultIdleTimeout is the default idle timeout const DefaultIdleTimeout = 30 * time.Second // DefaultHandshakeIdleTimeout is the default idle timeout used before handshake completion. const DefaultHandshakeIdleTimeout = 5 * time.Second // MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame. // This avoids splitting up STREAM frames into small pieces, which has 2 advantages: // 1. it reduces the framing overhead // 2. it reduces the head-of-line blocking, when a packet is lost const MinStreamFrameSize ByteCount = 128 // MaxPostHandshakeCryptoFrameSize is the maximum size of CRYPTO frames // we send after the handshake completes. const MaxPostHandshakeCryptoFrameSize = 1000 // MaxNumAckRanges is the maximum number of ACK ranges that we send in an ACK frame. // It also serves as a limit for the packet history. // If at any point we keep track of more ranges, old ranges are discarded. // // This value also guarantees that ACK Range Count value in the ACK frame can be encoded // in a single byte varint. const MaxNumAckRanges = 64 // MinPacingDelay is the minimum duration that is used for packet pacing // If the packet packing frequency is higher, multiple packets might be sent at once. // Example: For a packet pacing delay of 200μs, we would send 5 packets at once, wait for 1ms, and so forth. const MinPacingDelay = time.Millisecond // DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections // if no other value is configured. const DefaultConnectionIDLength = 4 // MaxActiveConnectionIDs is the number of connection IDs that we're storing. 
const MaxActiveConnectionIDs = 4 // MaxIssuedConnectionIDs is the maximum number of connection IDs that we're issuing at the same time. const MaxIssuedConnectionIDs = 6 // PacketsPerConnectionID is the number of packets we send using one connection ID. // If the peer provices us with enough new connection IDs, we switch to a new connection ID. const PacketsPerConnectionID = 10000 // AckDelayExponent is the ack delay exponent used when sending ACKs. const AckDelayExponent = 3 // Estimated timer granularity. // The loss detection timer will not be set to a value smaller than granularity. const TimerGranularity = time.Millisecond // MaxAckDelay is the maximum time by which we delay sending ACKs. const MaxAckDelay = 25 * time.Millisecond // MaxAckDelayInclGranularity is the max_ack_delay including the timer granularity. // This is the value that should be advertised to the peer. const MaxAckDelayInclGranularity = MaxAckDelay + TimerGranularity // KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update. const KeyUpdateInterval = 100 * 1000 // Max0RTTQueueingDuration is the maximum time that we store 0-RTT packets in order to wait for the corresponding Initial to be received. const Max0RTTQueueingDuration = 100 * time.Millisecond // Max0RTTQueues is the maximum number of connections that we buffer 0-RTT packets for. const Max0RTTQueues = 32 // Max0RTTQueueLen is the maximum number of 0-RTT packets that we buffer for each connection. // When a new connection is created, all buffered packets are passed to the connection immediately. // To avoid blocking, this value has to be smaller than MaxConnUnprocessedPackets. // To avoid packets being dropped as undecryptable by the connection, this value has to be smaller than MaxUndecryptablePackets. 
const Max0RTTQueueLen = 31
quic-go-0.59.0/internal/protocol/params_test.go000066400000000000000000000005101513066070600214770ustar00rootroot00000000000000
package protocol

import (
	"testing"

	"github.com/stretchr/testify/require"
)

// TestPacketQueueCapacities checks that both connection-level packet queues
// are strictly larger than the per-connection 0-RTT queue.
func TestPacketQueueCapacities(t *testing.T) {
	// Ensure that the session can queue more packets than the 0-RTT queue
	require.Greater(t, MaxConnUnprocessedPackets, Max0RTTQueueLen)
	require.Greater(t, MaxUndecryptablePackets, Max0RTTQueueLen)
}
quic-go-0.59.0/internal/protocol/perspective.go000066400000000000000000000007651513066070600215220ustar00rootroot00000000000000
package protocol

// Perspective determines if we're acting as a server or a client
type Perspective int

// the perspectives
const (
	PerspectiveServer Perspective = 1
	PerspectiveClient Perspective = 2
)

// Opposite returns the perspective of the peer.
// Because the two valid values are 1 and 2, 3 - p swaps them; for any other
// value the result is not a valid Perspective.
func (p Perspective) Opposite() Perspective {
	return 3 - p
}

// String returns "server" or "client"; any other value yields "invalid perspective".
func (p Perspective) String() string {
	switch p {
	case PerspectiveServer:
		return "server"
	case PerspectiveClient:
		return "client"
	default:
		return "invalid perspective"
	}
}
quic-go-0.59.0/internal/protocol/perspective_test.go000066400000000000000000000007501513066070600225530ustar00rootroot00000000000000
package protocol

import (
	"testing"

	"github.com/stretchr/testify/require"
)

// TestPerspectiveOpposite checks that Opposite swaps client and server.
func TestPerspectiveOpposite(t *testing.T) {
	require.Equal(t, PerspectiveServer, PerspectiveClient.Opposite())
	require.Equal(t, PerspectiveClient, PerspectiveServer.Opposite())
}

// TestPerspectiveStringer covers both valid perspectives and the zero value.
func TestPerspectiveStringer(t *testing.T) {
	require.Equal(t, "client", PerspectiveClient.String())
	require.Equal(t, "server", PerspectiveServer.String())
	require.Equal(t, "invalid perspective", Perspective(0).String())
}
quic-go-0.59.0/internal/protocol/protocol.go000066400000000000000000000102711513066070600210230ustar00rootroot00000000000000
package protocol

import (
	"fmt"
	"sync/atomic"
	"time"
)

// The PacketType is the Long Header Type
type PacketType uint8

// Values start at 1 (1 + iota), so the zero value is not a valid PacketType.
const (
	// PacketTypeInitial is the packet type of an Initial packet
	PacketTypeInitial PacketType = 1 + iota
	// PacketTypeRetry is the packet type of a Retry packet
	PacketTypeRetry
	// PacketTypeHandshake is the packet type of a Handshake packet
	PacketTypeHandshake
	// PacketType0RTT is the packet type of a 0-RTT packet
	PacketType0RTT
)

// String returns a human-readable name; unknown values are rendered with their numeric value.
func (t PacketType) String() string {
	switch t {
	case PacketTypeInitial:
		return "Initial"
	case PacketTypeRetry:
		return "Retry"
	case PacketTypeHandshake:
		return "Handshake"
	case PacketType0RTT:
		return "0-RTT Protected"
	default:
		return fmt.Sprintf("unknown packet type: %d", t)
	}
}

// ECN is an ECN (Explicit Congestion Notification) codepoint.
// ECNUnsupported (the zero value) has no header-bit representation:
// ToHeaderBits panics for it.
type ECN uint8

// The trailing comments give the two header bits of each codepoint.
const (
	ECNUnsupported ECN = iota
	ECNNon             // 00
	ECT1               // 01
	ECT0               // 10
	ECNCE              // 11
)

// ParseECNHeaderBits converts the two ECN bits (as read from the IP header)
// into an ECN value. It panics for any input greater than 0b11.
func ParseECNHeaderBits(bits byte) ECN {
	switch bits {
	case 0:
		return ECNNon
	case 0b00000010:
		return ECT0
	case 0b00000001:
		return ECT1
	case 0b00000011:
		return ECNCE
	default:
		panic("invalid ECN bits")
	}
}

// ToHeaderBits is the inverse of ParseECNHeaderBits. It panics for
// ECNUnsupported and for any value outside the four codepoints.
func (e ECN) ToHeaderBits() byte {
	//nolint:exhaustive // There are only 4 values.
	switch e {
	case ECNNon:
		return 0
	case ECT0:
		return 0b00000010
	case ECT1:
		return 0b00000001
	case ECNCE:
		return 0b00000011
	default:
		panic("ECN unsupported")
	}
}

// String returns the conventional name of the codepoint.
func (e ECN) String() string {
	switch e {
	case ECNUnsupported:
		return "ECN unsupported"
	case ECNNon:
		return "Not-ECT"
	case ECT1:
		return "ECT(1)"
	case ECT0:
		return "ECT(0)"
	case ECNCE:
		return "CE"
	default:
		return fmt.Sprintf("invalid ECN value: %d", e)
	}
}

// A ByteCount in QUIC
type ByteCount int64

// AtomicByteCount is a defined type whose underlying type is atomic.Int64.
// NOTE(review): a defined type does not carry over atomic.Int64's methods, so
// callers presumably convert to *atomic.Int64 before use; confirm at use sites.
type AtomicByteCount atomic.Int64

// MaxByteCount is the maximum value of a ByteCount
const MaxByteCount = ByteCount(1<<62 - 1)

// InvalidByteCount is an invalid byte count
const InvalidByteCount ByteCount = -1

// A StatelessResetToken is a stateless reset token.
type StatelessResetToken [16]byte

// MaxPacketBufferSize maximum packet size of any QUIC packet, based on
// ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header,
// UDP adds an additional 8 bytes. This is a total overhead of 48 bytes.
// Ethernet's max packet size is 1500 bytes, 1500 - 48 = 1452.
const MaxPacketBufferSize = 1452

// Packet size limits.
const (
	// MaxLargePacketBufferSize is the packet buffer size used with GSO.
	MaxLargePacketBufferSize = 20 * 1024

	// MinInitialPacketSize is the minimum size that an Initial packet is required to have.
	MinInitialPacketSize = 1200

	// MinUnknownVersionPacketSize is the minimum size that a packet carrying an
	// unknown version must have in order to trigger a Version Negotiation packet.
	MinUnknownVersionPacketSize = MinInitialPacketSize
)

// Stateless reset sizes.
const (
	// MinStatelessResetSize is the minimum size of the stateless reset packets we send.
	// It adds up the first byte, the maximum connection ID length, the maximum
	// packet number length, the minimum payload length and the 16-byte token.
	MinStatelessResetSize = 1 /* first byte */ + 20 /* max. conn ID length */ + 4 /* max. packet number length */ + 1 /* min. payload length */ + 16 /* token */

	// MinReceivedStatelessResetSize is the minimum size of a stateless reset we
	// receive, as specified in section 10.3 of RFC 9000.
	MinReceivedStatelessResetSize = 5 + 16
)

// MinConnectionIDLenInitial is the minimum length of the destination
// connection ID of an Initial packet.
const MinConnectionIDLenInitial = 8

// Transport parameter defaults and limits.
const (
	// DefaultAckDelayExponent is the default ack delay exponent.
	DefaultAckDelayExponent = 3

	// DefaultActiveConnectionIDLimit is the default active connection ID limit.
	DefaultActiveConnectionIDLimit = 2

	// MaxAckDelayExponent is the largest permitted ack delay exponent.
	MaxAckDelayExponent = 20

	// DefaultMaxAckDelay is the default max_ack_delay.
	DefaultMaxAckDelay = 25 * time.Millisecond

	// MaxMaxAckDelay is the largest permitted max_ack_delay.
	MaxMaxAckDelay = (1<<14 - 1) * time.Millisecond
)

// MaxConnIDLen is the maximum length of a connection ID.
const MaxConnIDLen = 20

// InvalidPacketLimitAES is the maximum number of packets that may fail to
// decrypt when using AEAD_AES_128_GCM or AEAD_AES_256_GCM.
const InvalidPacketLimitAES = 1 << 52

// InvalidPacketLimitChaCha is the maximum number of packets that we can fail to decrypt when using AEAD_CHACHA20_POLY1305.
const InvalidPacketLimitChaCha = 1 << 36
quic-go-0.59.0/internal/protocol/protocol_test.go000066400000000000000000000024601513066070600220630ustar00rootroot00000000000000
package protocol

import (
	"testing"

	"github.com/stretchr/testify/require"
)

// TestLongHeaderPacketTypeStringer covers every named PacketType and an unknown value.
func TestLongHeaderPacketTypeStringer(t *testing.T) {
	require.Equal(t, "Initial", PacketTypeInitial.String())
	require.Equal(t, "Retry", PacketTypeRetry.String())
	require.Equal(t, "Handshake", PacketTypeHandshake.String())
	require.Equal(t, "0-RTT Protected", PacketType0RTT.String())
	require.Equal(t, "unknown packet type: 10", PacketType(10).String())
}

// TestECNFromIPHeader checks all four valid bit patterns and the panic on invalid input.
func TestECNFromIPHeader(t *testing.T) {
	require.Equal(t, ECNNon, ParseECNHeaderBits(0))
	require.Equal(t, ECT0, ParseECNHeaderBits(0b00000010))
	require.Equal(t, ECT1, ParseECNHeaderBits(0b00000001))
	require.Equal(t, ECNCE, ParseECNHeaderBits(0b00000011))
	require.Panics(t, func() { ParseECNHeaderBits(0b1010101) })
}

// TestECNConversionToIPHeaderBits checks that ToHeaderBits and ParseECNHeaderBits round-trip.
func TestECNConversionToIPHeaderBits(t *testing.T) {
	for _, v := range [...]ECN{ECNNon, ECT0, ECT1, ECNCE} {
		require.Equal(t, v, ParseECNHeaderBits(v.ToHeaderBits()))
	}
	require.Panics(t, func() { ECN(42).ToHeaderBits() })
}

// TestECNStringer covers every named ECN value and an out-of-range value.
func TestECNStringer(t *testing.T) {
	require.Equal(t, "ECN unsupported", ECNUnsupported.String())
	require.Equal(t, "Not-ECT", ECNNon.String())
	require.Equal(t, "ECT(0)", ECT0.String())
	require.Equal(t, "ECT(1)", ECT1.String())
	require.Equal(t, "CE", ECNCE.String())
	require.Equal(t, "invalid ECN value: 42", ECN(42).String())
}
quic-go-0.59.0/internal/protocol/stream.go000066400000000000000000000060241513066070600204560ustar00rootroot00000000000000
package protocol

import "github.com/quic-go/quic-go/quicvarint"

// StreamType encodes if this is a unidirectional or bidirectional stream
type StreamType uint8

const (
	// StreamTypeUni is a unidirectional stream
	StreamTypeUni StreamType = iota
	// StreamTypeBidi is a bidirectional stream
	StreamTypeBidi
)

// InvalidStreamID is a stream ID that is invalid.
// The first valid stream ID in QUIC is 0.
const InvalidStreamID StreamID = -1

// StreamNum is the stream number
type StreamNum int64

// Stream number limits.
const (
	// InvalidStreamNum is an invalid stream number.
	InvalidStreamNum = -1

	// MaxStreamCount is the largest stream count that can be sent in
	// MAX_STREAMS frames and as the stream count in the transport parameters.
	MaxStreamCount StreamNum = 1 << 60
)

// MaxStreamID is the maximum stream ID.
const MaxStreamID StreamID = quicvarint.Max

// IDs of the first stream of each kind, named from the point of view of
// the endpoint that opens the stream.
const (
	// FirstOutgoingBidiStreamClient is the first bidirectional stream opened by the client.
	FirstOutgoingBidiStreamClient StreamID = 0
	// FirstOutgoingBidiStreamServer is the first bidirectional stream opened by the server.
	FirstOutgoingBidiStreamServer StreamID = 1
	// FirstOutgoingUniStreamClient is the first unidirectional stream opened by the client.
	FirstOutgoingUniStreamClient StreamID = 2
	// FirstOutgoingUniStreamServer is the first unidirectional stream opened by the server.
	FirstOutgoingUniStreamServer StreamID = 3
)

// The same IDs, named from the point of view of the endpoint that accepts the stream.
const (
	// FirstIncomingBidiStreamServer is the first bidirectional stream accepted by the server.
	FirstIncomingBidiStreamServer = FirstOutgoingBidiStreamClient
	// FirstIncomingBidiStreamClient is the first bidirectional stream accepted by the client.
	FirstIncomingBidiStreamClient = FirstOutgoingBidiStreamServer
	// FirstIncomingUniStreamServer is the first unidirectional stream accepted by the server.
	FirstIncomingUniStreamServer = FirstOutgoingUniStreamClient
	// FirstIncomingUniStreamClient is the first unidirectional stream accepted by the client.
	FirstIncomingUniStreamClient = FirstOutgoingUniStreamServer
)

// StreamID calculates the stream ID.
func (s StreamNum) StreamID(stype StreamType, pers Perspective) StreamID {
	// Stream number 0 does not correspond to any stream.
	if s == 0 {
		return InvalidStreamID
	}
	// first is the ID of stream number 1 of the requested kind. Unknown
	// stype/pers combinations leave it at 0.
	var first StreamID
	switch stype {
	case StreamTypeBidi:
		switch pers {
		case PerspectiveClient:
			first = 0
		case PerspectiveServer:
			first = 1
		}
	case StreamTypeUni:
		switch pers {
		case PerspectiveClient:
			first = 2
		case PerspectiveServer:
			first = 3
		}
	}
	// Streams of the same kind are spaced 4 IDs apart.
	return first + 4*StreamID(s-1)
}

// A StreamID in QUIC
type StreamID int64

// InitiatedBy says if the stream was initiated by the client or by the server.
// The lowest bit of the ID decides: even IDs are client-initiated.
func (s StreamID) InitiatedBy() Perspective {
	if s%2 == 0 {
		return PerspectiveClient
	}
	return PerspectiveServer
}

// Type says if this is a unidirectional or bidirectional stream.
// The second-lowest bit of the ID decides: s%4 of 2 or 3 means unidirectional.
func (s StreamID) Type() StreamType {
	if s%4 >= 2 {
		return StreamTypeUni
	}
	return StreamTypeBidi
}

// StreamNum returns how many streams in total are below this
// Example: for stream 9 it returns 3 (i.e. streams 1, 5 and 9)
func (s StreamID) StreamNum() StreamNum {
	return StreamNum(s/4) + 1
}
quic-go-0.59.0/internal/protocol/stream_test.go000066400000000000000000000054611513066070600215210ustar00rootroot00000000000000
package protocol

import (
	"testing"

	"github.com/stretchr/testify/require"
)

// TestInvalidStreamIDSmallerThanAllValidStreamIDs checks that InvalidStreamID sorts below stream 0.
func TestInvalidStreamIDSmallerThanAllValidStreamIDs(t *testing.T) {
	require.Less(t, InvalidStreamID, StreamID(0))
}

// TestStreamIDInitiatedBy checks that the low bit selects the initiator.
func TestStreamIDInitiatedBy(t *testing.T) {
	require.Equal(t, PerspectiveClient, StreamID(4).InitiatedBy())
	require.Equal(t, PerspectiveServer, StreamID(5).InitiatedBy())
	require.Equal(t, PerspectiveClient, StreamID(6).InitiatedBy())
	require.Equal(t, PerspectiveServer, StreamID(7).InitiatedBy())
}

// TestStreamIDType checks that the second bit selects the direction.
func TestStreamIDType(t *testing.T) {
	require.Equal(t, StreamTypeBidi, StreamID(4).Type())
	require.Equal(t, StreamTypeBidi, StreamID(5).Type())
	require.Equal(t, StreamTypeUni, StreamID(6).Type())
	require.Equal(t, StreamTypeUni, StreamID(7).Type())
}

// TestStreamIDStreamNum checks the ID-to-number mapping for the first and third stream of each kind.
func TestStreamIDStreamNum(t *testing.T) {
	require.Equal(t, StreamNum(1), StreamID(0).StreamNum())
	require.Equal(t, StreamNum(1), StreamID(1).StreamNum())
	require.Equal(t, StreamNum(1), StreamID(2).StreamNum())
	require.Equal(t, StreamNum(1), StreamID(3).StreamNum())
	require.Equal(t, StreamNum(3), StreamID(8).StreamNum())
	require.Equal(t, StreamNum(3), StreamID(9).StreamNum())
	require.Equal(t, StreamNum(3), StreamID(10).StreamNum())
	require.Equal(t, StreamNum(3), StreamID(11).StreamNum())
}

// TestStreamIDNumToStreamID checks the number-to-ID mapping, including the invalid stream number 0.
func TestStreamIDNumToStreamID(t *testing.T) {
	// 1st stream
	require.Equal(t, StreamID(0), StreamNum(1).StreamID(StreamTypeBidi, PerspectiveClient))
	require.Equal(t, StreamID(1), StreamNum(1).StreamID(StreamTypeBidi, PerspectiveServer))
	require.Equal(t, StreamID(2), StreamNum(1).StreamID(StreamTypeUni, PerspectiveClient))
	require.Equal(t, StreamID(3), StreamNum(1).StreamID(StreamTypeUni, PerspectiveServer))

	// 100th stream
	require.Equal(t, StreamID(396), StreamNum(100).StreamID(StreamTypeBidi, PerspectiveClient))
	require.Equal(t, StreamID(397), StreamNum(100).StreamID(StreamTypeBidi, PerspectiveServer))
	require.Equal(t, StreamID(398), StreamNum(100).StreamID(StreamTypeUni, PerspectiveClient))
	require.Equal(t, StreamID(399), StreamNum(100).StreamID(StreamTypeUni, PerspectiveServer))

	// 0 is not a valid stream number
	require.Equal(t, InvalidStreamID, StreamNum(0).StreamID(StreamTypeBidi, PerspectiveClient))
	require.Equal(t, InvalidStreamID, StreamNum(0).StreamID(StreamTypeBidi, PerspectiveServer))
	require.Equal(t, InvalidStreamID, StreamNum(0).StreamID(StreamTypeUni, PerspectiveClient))
	require.Equal(t, InvalidStreamID, StreamNum(0).StreamID(StreamTypeUni, PerspectiveServer))
}

// TestMaxStreamCountValue checks that MaxStreamCount is the largest stream number whose ID still fits below 2^62.
func TestMaxStreamCountValue(t *testing.T) {
	const maxStreamID = StreamID(1<<62 - 1)
	for _, dir := range []StreamType{StreamTypeUni, StreamTypeBidi} {
		for _, pers := range []Perspective{PerspectiveClient, PerspectiveServer} {
			require.LessOrEqual(t, MaxStreamCount.StreamID(dir, pers), maxStreamID)
			require.Greater(t, (MaxStreamCount+1).StreamID(dir, pers), maxStreamID)
		}
	}
}
quic-go-0.59.0/internal/protocol/version.go000066400000000000000000000062031513066070600206470ustar00rootroot00000000000000
package protocol

import (
	"crypto/rand"
	"encoding/binary"
	"fmt"
	"math"
	mrand "math/rand/v2"
	"slices"
	"sync"
)

// Version is a QUIC version number (a uint32).
type Version uint32

// gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions
const (
	gquicVersion0   = 0x51303030
	maxGquicVersion = 0x51303439
)

// The version numbers, making grepping easier
const (
	VersionUnknown Version = math.MaxUint32
	versionDraft29 Version = 0xff00001d // draft-29 used to be a widely deployed version
	Version1       Version = 0x1
	Version2       Version = 0x6b3343cf
)

// SupportedVersions lists the versions that the server supports
// must be in sorted descending order
// NOTE(review): "descending" refers to preference order, not numeric value:
// Version2 is numerically larger than Version1 but is listed second.
var SupportedVersions = []Version{Version1, Version2}

// IsValidVersion says if the version is known to quic-go
func IsValidVersion(v Version) bool {
	return v == Version1 || IsSupportedVersion(SupportedVersions, v)
}

// String returns a short name for known versions, "gQUIC <n>" for the gQUIC
// range, and the hex value otherwise.
func (vn Version) String() string {
	switch vn {
	case VersionUnknown:
		return "unknown"
	case versionDraft29:
		return "draft-29"
	case Version1:
		return "v1"
	case Version2:
		return "v2"
	default:
		if vn.isGQUIC() {
			return fmt.Sprintf("gQUIC %d", vn.toGQUICVersion())
		}
		return fmt.Sprintf("%#x", uint32(vn))
	}
}

// isGQUIC reports whether vn is within (gquicVersion0, maxGquicVersion].
func (vn Version) isGQUIC() bool {
	return vn > gquicVersion0 && vn <= maxGquicVersion
}

// toGQUICVersion decodes the trailing ASCII digits of a "Q0xy" version tag,
// e.g. 0x51303133 ("Q013") -> 10*0x103/0x100 + 3 = 13.
func (vn Version) toGQUICVersion() int {
	return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10)
}

// IsSupportedVersion reports whether v is contained in the supported slice.
func IsSupportedVersion(supported []Version, v Version) bool {
	return slices.Contains(supported, v)
}

// ChooseSupportedVersion finds the best version in the overlap of ours and theirs
// ours is a slice of versions that we support, sorted by our preference (descending)
// theirs is a slice of versions offered by the peer. The order does not matter.
// The bool returned indicates if a matching version was found.
func ChooseSupportedVersion(ours, theirs []Version) (Version, bool) { for _, ourVer := range ours { if slices.Contains(theirs, ourVer) { return ourVer, true } } return 0, false } var ( versionNegotiationMx sync.Mutex versionNegotiationRand mrand.Rand ) func init() { var seed [16]byte rand.Read(seed[:]) versionNegotiationRand = *mrand.New(mrand.NewPCG( binary.BigEndian.Uint64(seed[:8]), binary.BigEndian.Uint64(seed[8:]), )) } // generateReservedVersion generates a reserved version (v & 0x0f0f0f0f == 0x0a0a0a0a) func generateReservedVersion() Version { var b [4]byte binary.BigEndian.PutUint32(b[:], versionNegotiationRand.Uint32()) return Version((binary.BigEndian.Uint32(b[:]) | 0x0a0a0a0a) & 0xfafafafa) } // GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position. // It doesn't modify the supported slice. func GetGreasedVersions(supported []Version) []Version { versionNegotiationMx.Lock() defer versionNegotiationMx.Unlock() randPos := versionNegotiationRand.IntN(len(supported) + 1) greased := make([]Version, len(supported)+1) copy(greased, supported[:randPos]) greased[randPos] = generateReservedVersion() copy(greased[randPos+1:], supported[randPos:]) return greased } quic-go-0.59.0/internal/protocol/version_test.go000066400000000000000000000102231513066070600217030ustar00rootroot00000000000000package protocol import ( "slices" "testing" "github.com/stretchr/testify/require" ) func TestValidVersion(t *testing.T) { require.False(t, IsValidVersion(VersionUnknown)) require.False(t, IsValidVersion(versionDraft29)) require.True(t, IsValidVersion(Version1)) require.True(t, IsValidVersion(Version2)) require.False(t, IsValidVersion(1234)) } func TestVersionStringRepresentation(t *testing.T) { require.Equal(t, "unknown", VersionUnknown.String()) require.Equal(t, "draft-29", versionDraft29.String()) require.Equal(t, "v1", Version1.String()) require.Equal(t, "v2", Version2.String()) // check with unsupported version numbers from 
the wiki require.Equal(t, "gQUIC 9", Version(0x51303039).String()) require.Equal(t, "gQUIC 13", Version(0x51303133).String()) require.Equal(t, "gQUIC 25", Version(0x51303235).String()) require.Equal(t, "gQUIC 48", Version(0x51303438).String()) require.Equal(t, "0x1234567", Version(0x01234567).String()) } func TestRecognizesSupportedVersions(t *testing.T) { require.False(t, IsSupportedVersion(SupportedVersions, 0)) require.False(t, IsSupportedVersion(SupportedVersions, maxGquicVersion)) require.True(t, IsSupportedVersion(SupportedVersions, SupportedVersions[0])) require.True(t, IsSupportedVersion(SupportedVersions, SupportedVersions[len(SupportedVersions)-1])) } func TestVersionSelection(t *testing.T) { tests := []struct { name string supportedVersions []Version otherVersions []Version expectedVersion Version expectedOK bool }{ { name: "finds matching version", supportedVersions: []Version{1, 2, 3}, otherVersions: []Version{6, 5, 4, 3}, expectedVersion: 3, expectedOK: true, }, { name: "picks preferred version", supportedVersions: []Version{2, 1, 3}, otherVersions: []Version{3, 6, 1, 8, 2, 10}, expectedVersion: 2, expectedOK: true, }, { name: "no matching version", supportedVersions: []Version{1}, otherVersions: []Version{2}, expectedOK: false, }, { name: "empty supported versions", supportedVersions: []Version{}, otherVersions: []Version{1, 2}, expectedOK: false, }, { name: "empty other versions", supportedVersions: []Version{102, 101}, otherVersions: []Version{}, expectedOK: false, }, { name: "both empty", supportedVersions: []Version{}, otherVersions: []Version{}, expectedOK: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ver, ok := ChooseSupportedVersion(tt.supportedVersions, tt.otherVersions) require.Equal(t, tt.expectedOK, ok) if tt.expectedOK { require.Equal(t, tt.expectedVersion, ver) } }) } } func isReservedVersion(v Version) bool { return v&0x0f0f0f0f == 0x0a0a0a0a } func TestVersionGreasing(t *testing.T) { // adding to an empty 
slice greased := GetGreasedVersions([]Version{}) require.Len(t, greased, 1) require.True(t, isReservedVersion(greased[0])) // make sure that the greased versions are distinct, // allowing for a small number of duplicates var versions []Version for range 25 { versions = GetGreasedVersions(versions) } slices.Sort(versions) var numDuplicates int for i, v := range versions { require.True(t, isReservedVersion(v)) if i > 0 && versions[i-1] == v { numDuplicates++ } } require.LessOrEqual(t, numDuplicates, 3) // adding it somewhere in a slice of supported versions supported := []Version{10, 18, 29} for _, v := range supported { require.False(t, isReservedVersion(v)) } var greasedVersionFirst, greasedVersionLast, greasedVersionMiddle int for range 100 { greased := GetGreasedVersions(supported) require.Len(t, greased, 4) var j int for i, v := range greased { if isReservedVersion(v) { if i == 0 { greasedVersionFirst++ } if i == len(greased)-1 { greasedVersionLast++ } greasedVersionMiddle++ continue } require.Equal(t, supported[j], v) j++ } } require.NotZero(t, greasedVersionFirst) require.NotZero(t, greasedVersionLast) require.NotZero(t, greasedVersionMiddle) } quic-go-0.59.0/internal/qerr/000077500000000000000000000000001513066070600157425ustar00rootroot00000000000000quic-go-0.59.0/internal/qerr/error_codes.go000066400000000000000000000047171513066070600206100ustar00rootroot00000000000000package qerr import ( "crypto/tls" "fmt" ) // TransportErrorCode is a QUIC transport error. 
type TransportErrorCode uint64 // The error codes defined by QUIC const ( NoError TransportErrorCode = 0x0 InternalError TransportErrorCode = 0x1 ConnectionRefused TransportErrorCode = 0x2 FlowControlError TransportErrorCode = 0x3 StreamLimitError TransportErrorCode = 0x4 StreamStateError TransportErrorCode = 0x5 FinalSizeError TransportErrorCode = 0x6 FrameEncodingError TransportErrorCode = 0x7 TransportParameterError TransportErrorCode = 0x8 ConnectionIDLimitError TransportErrorCode = 0x9 ProtocolViolation TransportErrorCode = 0xa InvalidToken TransportErrorCode = 0xb ApplicationErrorErrorCode TransportErrorCode = 0xc CryptoBufferExceeded TransportErrorCode = 0xd KeyUpdateError TransportErrorCode = 0xe AEADLimitReached TransportErrorCode = 0xf NoViablePathError TransportErrorCode = 0x10 ) func (e TransportErrorCode) IsCryptoError() bool { return e >= 0x100 && e < 0x200 } // Message is a description of the error. // It only returns a non-empty string for crypto errors. func (e TransportErrorCode) Message() string { if !e.IsCryptoError() { return "" } return tls.AlertError(e - 0x100).Error() } func (e TransportErrorCode) String() string { switch e { case NoError: return "NO_ERROR" case InternalError: return "INTERNAL_ERROR" case ConnectionRefused: return "CONNECTION_REFUSED" case FlowControlError: return "FLOW_CONTROL_ERROR" case StreamLimitError: return "STREAM_LIMIT_ERROR" case StreamStateError: return "STREAM_STATE_ERROR" case FinalSizeError: return "FINAL_SIZE_ERROR" case FrameEncodingError: return "FRAME_ENCODING_ERROR" case TransportParameterError: return "TRANSPORT_PARAMETER_ERROR" case ConnectionIDLimitError: return "CONNECTION_ID_LIMIT_ERROR" case ProtocolViolation: return "PROTOCOL_VIOLATION" case InvalidToken: return "INVALID_TOKEN" case ApplicationErrorErrorCode: return "APPLICATION_ERROR" case CryptoBufferExceeded: return "CRYPTO_BUFFER_EXCEEDED" case KeyUpdateError: return "KEY_UPDATE_ERROR" case AEADLimitReached: return "AEAD_LIMIT_REACHED" case 
NoViablePathError: return "NO_VIABLE_PATH" default: if e.IsCryptoError() { return fmt.Sprintf("CRYPTO_ERROR %#x", uint16(e)) } return fmt.Sprintf("unknown error code: %#x", uint16(e)) } } quic-go-0.59.0/internal/qerr/errorcodes_test.go000066400000000000000000000024351513066070600215030ustar00rootroot00000000000000package qerr import ( "go/ast" "go/parser" "go/token" "path" "runtime" "strconv" "testing" "github.com/stretchr/testify/require" ) func TestTransportErrorCodeStringer(t *testing.T) { _, thisfile, _, ok := runtime.Caller(0) require.True(t, ok, "Failed to get current frame") filename := path.Join(path.Dir(thisfile), "error_codes.go") fileAst, err := parser.ParseFile(token.NewFileSet(), filename, nil, 0) require.NoError(t, err) constSpecs := fileAst.Decls[2].(*ast.GenDecl).Specs require.Greater(t, len(constSpecs), 4, "Expected more than 4 constants") for _, c := range constSpecs { valString := c.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value val, err := strconv.ParseInt(valString, 0, 64) require.NoError(t, err) require.NotEqual(t, "unknown error code", TransportErrorCode(val).String()) } // test that there's a string representation for unknown error codes require.Equal(t, "unknown error code: 0x1337", TransportErrorCode(0x1337).String()) } func TestIsCryptoError(t *testing.T) { for i := 0; i < 0x100; i++ { require.False(t, TransportErrorCode(i).IsCryptoError()) } for i := 0x100; i < 0x200; i++ { require.True(t, TransportErrorCode(i).IsCryptoError()) } for i := 0x200; i < 0x300; i++ { require.False(t, TransportErrorCode(i).IsCryptoError()) } } quic-go-0.59.0/internal/qerr/errors.go000066400000000000000000000073771513066070600176230ustar00rootroot00000000000000package qerr import ( "fmt" "net" "github.com/quic-go/quic-go/internal/protocol" ) var ( ErrHandshakeTimeout = &HandshakeTimeoutError{} ErrIdleTimeout = &IdleTimeoutError{} ) type TransportError struct { Remote bool FrameType uint64 ErrorCode TransportErrorCode ErrorMessage string error error // only 
set for local errors, sometimes } var _ error = &TransportError{} // NewLocalCryptoError create a new TransportError instance for a crypto error func NewLocalCryptoError(tlsAlert uint8, err error) *TransportError { return &TransportError{ ErrorCode: 0x100 + TransportErrorCode(tlsAlert), error: err, } } func (e *TransportError) Error() string { str := fmt.Sprintf("%s (%s)", e.ErrorCode.String(), getRole(e.Remote)) if e.FrameType != 0 { str += fmt.Sprintf(" (frame type: %#x)", e.FrameType) } msg := e.ErrorMessage if len(msg) == 0 && e.error != nil { msg = e.error.Error() } if len(msg) == 0 { msg = e.ErrorCode.Message() } if len(msg) == 0 { return str } return str + ": " + msg } func (e *TransportError) Unwrap() []error { return []error{net.ErrClosed, e.error} } func (e *TransportError) Is(target error) bool { t, ok := target.(*TransportError) return ok && e.ErrorCode == t.ErrorCode && e.FrameType == t.FrameType && e.Remote == t.Remote } // An ApplicationErrorCode is an application-defined error code. type ApplicationErrorCode uint64 // A StreamErrorCode is an error code used to cancel streams. 
type StreamErrorCode uint64 type ApplicationError struct { Remote bool ErrorCode ApplicationErrorCode ErrorMessage string } var _ error = &ApplicationError{} func (e *ApplicationError) Error() string { if len(e.ErrorMessage) == 0 { return fmt.Sprintf("Application error %#x (%s)", e.ErrorCode, getRole(e.Remote)) } return fmt.Sprintf("Application error %#x (%s): %s", e.ErrorCode, getRole(e.Remote), e.ErrorMessage) } func (e *ApplicationError) Unwrap() error { return net.ErrClosed } func (e *ApplicationError) Is(target error) bool { t, ok := target.(*ApplicationError) return ok && e.ErrorCode == t.ErrorCode && e.Remote == t.Remote } type IdleTimeoutError struct{} var _ error = &IdleTimeoutError{} func (e *IdleTimeoutError) Timeout() bool { return true } func (e *IdleTimeoutError) Temporary() bool { return false } func (e *IdleTimeoutError) Error() string { return "timeout: no recent network activity" } func (e *IdleTimeoutError) Unwrap() error { return net.ErrClosed } type HandshakeTimeoutError struct{} var _ error = &HandshakeTimeoutError{} func (e *HandshakeTimeoutError) Timeout() bool { return true } func (e *HandshakeTimeoutError) Temporary() bool { return false } func (e *HandshakeTimeoutError) Error() string { return "timeout: handshake did not complete in time" } func (e *HandshakeTimeoutError) Unwrap() error { return net.ErrClosed } // A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version. type VersionNegotiationError struct { Ours []protocol.Version Theirs []protocol.Version } func (e *VersionNegotiationError) Error() string { return fmt.Sprintf("no compatible QUIC version found (we support %s, server offered %s)", e.Ours, e.Theirs) } func (e *VersionNegotiationError) Unwrap() error { return net.ErrClosed } // A StatelessResetError occurs when we receive a stateless reset. 
type StatelessResetError struct{} var _ net.Error = &StatelessResetError{} func (e *StatelessResetError) Error() string { return "received a stateless reset" } func (e *StatelessResetError) Unwrap() error { return net.ErrClosed } func (e *StatelessResetError) Timeout() bool { return false } func (e *StatelessResetError) Temporary() bool { return true } func getRole(remote bool) string { if remote { return "remote" } return "local" } quic-go-0.59.0/internal/qerr/errors_test.go000066400000000000000000000131561513066070600206520ustar00rootroot00000000000000package qerr import ( "errors" "fmt" "net" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestTransportError(t *testing.T) { require.True(t, errors.Is(&TransportError{}, net.ErrClosed)) require.True(t, errors.Is( &TransportError{Remote: true, ErrorCode: 1, FrameType: 2}, &TransportError{Remote: true, ErrorCode: 1, FrameType: 2}, )) require.True(t, errors.Is(&TransportError{ErrorCode: 0x101}, &TransportError{ErrorCode: 0x101})) require.False(t, errors.Is(&TransportError{}, &TransportError{ErrorCode: 0x101})) require.False(t, errors.Is(&TransportError{}, &TransportError{FrameType: 0x1337})) require.False(t, errors.Is(&TransportError{Remote: false}, &TransportError{Remote: true})) } func TestTransportErrorStringer(t *testing.T) { t.Run("with error message", func(t *testing.T) { err := &TransportError{ ErrorCode: FlowControlError, ErrorMessage: "foobar", } require.Equal(t, "FLOW_CONTROL_ERROR (local): foobar", err.Error()) }) t.Run("without error message", func(t *testing.T) { err := &TransportError{ErrorCode: FlowControlError} require.Equal(t, "FLOW_CONTROL_ERROR (local)", err.Error()) }) t.Run("with frame type", func(t *testing.T) { err := &TransportError{ Remote: true, ErrorCode: FlowControlError, FrameType: 0x1337, } require.Equal(t, "FLOW_CONTROL_ERROR (remote) (frame type: 0x1337)", err.Error()) }) t.Run("with frame type 
and error message", func(t *testing.T) { err := &TransportError{ ErrorCode: FlowControlError, FrameType: 0x1337, ErrorMessage: "foobar", } require.Equal(t, "FLOW_CONTROL_ERROR (local) (frame type: 0x1337): foobar", err.Error()) }) } type myError int var _ error = myError(0) func (e myError) Error() string { return fmt.Sprintf("my error %d", e) } func TestCryptoError(t *testing.T) { var myErr myError err := NewLocalCryptoError(0x42, myError(1337)) require.True(t, errors.As(err, &myErr)) require.Equal(t, myError(1337), myErr) err = NewLocalCryptoError(0x42, assert.AnError) require.True(t, errors.Is(err, assert.AnError)) require.True(t, errors.Is( NewLocalCryptoError(0x42, assert.AnError), NewLocalCryptoError(0x42, assert.AnError), )) require.False(t, errors.Is( NewLocalCryptoError(0x42, assert.AnError), NewLocalCryptoError(0x43, assert.AnError), )) } func TestCryptoErrorStringer(t *testing.T) { t.Run("with error message", func(t *testing.T) { myErr := myError(1337) err := NewLocalCryptoError(0x42, myErr) require.Equal(t, "CRYPTO_ERROR 0x142 (local): my error 1337", err.Error()) }) t.Run("without error message", func(t *testing.T) { err := NewLocalCryptoError(0x2a, nil) require.Equal(t, "CRYPTO_ERROR 0x12a (local): tls: bad certificate", err.Error()) }) } func TestApplicationError(t *testing.T) { require.True(t, errors.Is(&ApplicationError{}, net.ErrClosed)) require.True(t, errors.Is( &ApplicationError{ErrorCode: 1, Remote: true}, &ApplicationError{ErrorCode: 1, Remote: true}, )) require.True(t, errors.Is(&ApplicationError{ErrorCode: 0x101}, &ApplicationError{ErrorCode: 0x101})) require.False(t, errors.Is(&ApplicationError{}, &ApplicationError{ErrorCode: 0x101})) require.False(t, errors.Is(&ApplicationError{Remote: false}, &ApplicationError{Remote: true})) } func TestApplicationErrorStringer(t *testing.T) { t.Run("with error message", func(t *testing.T) { err := &ApplicationError{ ErrorCode: 0x42, ErrorMessage: "foobar", } require.Equal(t, "Application error 0x42 
(local): foobar", err.Error()) }) t.Run("without error message", func(t *testing.T) { err := &ApplicationError{ ErrorCode: 0x42, Remote: true, } require.Equal(t, "Application error 0x42 (remote)", err.Error()) }) } func TestHandshakeTimeoutError(t *testing.T) { require.True(t, errors.Is(&HandshakeTimeoutError{}, &HandshakeTimeoutError{})) require.False(t, errors.Is(&HandshakeTimeoutError{}, &IdleTimeoutError{})) //nolint:staticcheck // SA1021: we need to assign to an interface here var err error err = &HandshakeTimeoutError{} nerr, ok := err.(net.Error) require.True(t, ok) require.True(t, nerr.Timeout()) require.Equal(t, "timeout: handshake did not complete in time", err.Error()) require.True(t, errors.Is(&HandshakeTimeoutError{}, net.ErrClosed)) } func TestIdleTimeoutError(t *testing.T) { require.True(t, errors.Is(&IdleTimeoutError{}, &IdleTimeoutError{})) require.False(t, errors.Is(&IdleTimeoutError{}, &HandshakeTimeoutError{})) //nolint:staticcheck // SA1021: we need to assign to an interface here var err error err = &IdleTimeoutError{} nerr, ok := err.(net.Error) require.True(t, ok) require.True(t, nerr.Timeout()) require.Equal(t, "timeout: no recent network activity", err.Error()) require.True(t, errors.Is(&IdleTimeoutError{}, net.ErrClosed)) } func TestVersionNegotiationErrorString(t *testing.T) { err := &VersionNegotiationError{ Ours: []protocol.Version{2, 3}, Theirs: []protocol.Version{4, 5, 6}, } require.Equal(t, "no compatible QUIC version found (we support [0x2 0x3], server offered [0x4 0x5 0x6])", err.Error()) require.True(t, errors.Is(&VersionNegotiationError{}, net.ErrClosed)) } func TestStatelessResetError(t *testing.T) { require.Equal(t, "received a stateless reset", (&StatelessResetError{}).Error()) require.True(t, errors.Is(&StatelessResetError{}, &StatelessResetError{})) //nolint:staticcheck // SA1021: we need to assign to an interface here var err error err = &StatelessResetError{} nerr, ok := err.(net.Error) require.True(t, ok) require.False(t, 
nerr.Timeout()) require.True(t, errors.Is(&StatelessResetError{}, net.ErrClosed)) } quic-go-0.59.0/internal/qtls/000077500000000000000000000000001513066070600157545ustar00rootroot00000000000000quic-go-0.59.0/internal/qtls/cipher_suite.go000066400000000000000000000034051513066070600207700ustar00rootroot00000000000000package qtls import ( "crypto/tls" "fmt" "unsafe" ) //go:linkname cipherSuitesTLS13 crypto/tls.cipherSuitesTLS13 var cipherSuitesTLS13 []unsafe.Pointer //go:linkname defaultCipherSuitesTLS13 crypto/tls.defaultCipherSuitesTLS13 var defaultCipherSuitesTLS13 []uint16 //go:linkname defaultCipherSuitesTLS13NoAES crypto/tls.defaultCipherSuitesTLS13NoAES var defaultCipherSuitesTLS13NoAES []uint16 var cipherSuitesModified bool // SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls // such that it only contains the cipher suite with the chosen id. // The reset function returned resets them back to the original value. func SetCipherSuite(id uint16) (reset func()) { if cipherSuitesModified { panic("cipher suites modified multiple times without resetting") } cipherSuitesModified = true origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...) origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...) origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...) // The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls. 
switch id { case tls.TLS_AES_128_GCM_SHA256: cipherSuitesTLS13 = cipherSuitesTLS13[:1] case tls.TLS_CHACHA20_POLY1305_SHA256: cipherSuitesTLS13 = cipherSuitesTLS13[1:2] case tls.TLS_AES_256_GCM_SHA384: cipherSuitesTLS13 = cipherSuitesTLS13[2:] default: panic(fmt.Sprintf("unexpected cipher suite: %d", id)) } defaultCipherSuitesTLS13 = []uint16{id} defaultCipherSuitesTLS13NoAES = []uint16{id} return func() { cipherSuitesTLS13 = origCipherSuitesTLS13 defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13 defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES cipherSuitesModified = false } } quic-go-0.59.0/internal/qtls/cipher_suite_test.go000066400000000000000000000025241513066070600220300ustar00rootroot00000000000000package qtls import ( "crypto/tls" "fmt" "net" "testing" "github.com/quic-go/quic-go/internal/testdata" "github.com/stretchr/testify/require" ) func TestCipherSuiteSelection(t *testing.T) { t.Run("TLS_AES_128_GCM_SHA256", func(t *testing.T) { testCipherSuiteSelection(t, tls.TLS_AES_128_GCM_SHA256) }) t.Run("TLS_CHACHA20_POLY1305_SHA256", func(t *testing.T) { testCipherSuiteSelection(t, tls.TLS_CHACHA20_POLY1305_SHA256) }) t.Run("TLS_AES_256_GCM_SHA384", func(t *testing.T) { testCipherSuiteSelection(t, tls.TLS_AES_256_GCM_SHA384) }) } func testCipherSuiteSelection(t *testing.T, cs uint16) { reset := SetCipherSuite(cs) defer reset() ln, err := tls.Listen("tcp4", "localhost:0", testdata.GetTLSConfig()) require.NoError(t, err) defer ln.Close() done := make(chan struct{}) go func() { defer close(done) conn, err := ln.Accept() require.NoError(t, err) _, err = conn.Read(make([]byte, 10)) require.NoError(t, err) require.Equal(t, cs, conn.(*tls.Conn).ConnectionState().CipherSuite) }() conn, err := tls.Dial( "tcp4", fmt.Sprintf("localhost:%d", ln.Addr().(*net.TCPAddr).Port), &tls.Config{RootCAs: testdata.GetRootCA()}, ) require.NoError(t, err) _, err = conn.Write([]byte("foobar")) require.NoError(t, err) require.Equal(t, cs, 
conn.ConnectionState().CipherSuite) require.NoError(t, conn.Close()) <-done } quic-go-0.59.0/internal/synctest/000077500000000000000000000000001513066070600166455ustar00rootroot00000000000000quic-go-0.59.0/internal/synctest/synctest_go124.go000066400000000000000000000004351513066070600217660ustar00rootroot00000000000000//go:build go1.24 && !go1.25 package synctest import ( "testing" "testing/synctest" ) func Test(t *testing.T, f func(t *testing.T)) { synctest.Run(func() { f(t) }) } func Wait() { //nolint:govet // the CI configuration sets the GOEXPERIMENT=synctest flag synctest.Wait() } quic-go-0.59.0/internal/synctest/synctest_go125.go000066400000000000000000000002711513066070600217650ustar00rootroot00000000000000//go:build go1.25 package synctest import ( "testing" "testing/synctest" ) func Test(t *testing.T, f func(t *testing.T)) { synctest.Test(t, f) } func Wait() { synctest.Wait() } quic-go-0.59.0/internal/testdata/000077500000000000000000000000001513066070600166025ustar00rootroot00000000000000quic-go-0.59.0/internal/testdata/ca.pem000066400000000000000000000020051513066070600176650ustar00rootroot00000000000000-----BEGIN CERTIFICATE----- MIICzDCCAbQCCQDA+rLymNnfJzANBgkqhkiG9w0BAQsFADAoMSYwJAYDVQQKDB1x dWljLWdvIENlcnRpZmljYXRlIEF1dGhvcml0eTAeFw0yMDA4MTgwOTIxMzVaFw0z MDA4MTYwOTIxMzVaMCgxJjAkBgNVBAoMHXF1aWMtZ28gQ2VydGlmaWNhdGUgQXV0 aG9yaXR5MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA1OcsYrVaSDfh iDppl6oteVspOY3yFb96T9Y/biaGPJAkBO9VGKcqwOUPmUeiWpedRAUB9LE7Srs6 qBX4mnl90Icjp8jbIs5cPgIWLkIu8Qm549RghFzB3bn+EmCQSe4cxvyDMN3ndClp 3YMXpZgXWgJGiPOylVi/OwHDdWDBorw4hvry+6yDtpQo2TuI2A/xtxXPT7BgsEJD WGffdgZOYXChcFA0c1XVLIYlu2w2JhxS8c2TUF6uSDlmcoONNKVoiNCuu1Z9MorS Qmg7a2G7dSPu123KcTcSQFcmJrt+1G81gOBtHB69kacD8xDmgksj09h/ODPL/gIU 1ZcU2ci1/QIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQB0Tb1JbLXp/BvWovSAhO/j wG7UEaUA1rCtkDB+fV2HS9bxCbV5eErdg8AMHKgB51ygUrq95vm/baZmUILr84XK uTEoxxrw5S9Z7SrhtbOpKCumoSeTsCPjDvCcwFExHv4XHFk+CPqZwbMHueVIMT0+ 
nGWss/KecCPdJLdnUgMRz0tIuXzkoRuOiUiZfUeyBNVNbDFSrLigYshTeAPGaYjX CypoHxkeS93nWfOMUu8FTYLYkvGMU5i076zDoFGKJiEtbjSiNW+Hei7u2aSEuCzp qyTKzYPWYffAq3MM2MKJgZdL04e9GEGeuce/qhM1o3q77aI/XJImwEDdut2LDec1 -----END CERTIFICATE----- quic-go-0.59.0/internal/testdata/cert.go000066400000000000000000000023041513066070600200650ustar00rootroot00000000000000package testdata import ( "crypto/tls" "crypto/x509" "os" "path" "runtime" ) var certPath string func init() { _, filename, _, ok := runtime.Caller(0) if !ok { panic("Failed to get current frame") } certPath = path.Dir(filename) } // GetCertificatePaths returns the paths to certificate and key func GetCertificatePaths() (string, string) { return path.Join(certPath, "cert.pem"), path.Join(certPath, "priv.key") } // GetTLSConfig returns a tls config for quic.clemente.io func GetTLSConfig() *tls.Config { cert, err := tls.LoadX509KeyPair(GetCertificatePaths()) if err != nil { panic(err) } return &tls.Config{ MinVersion: tls.VersionTLS13, Certificates: []tls.Certificate{cert}, } } // AddRootCA adds the root CA certificate to a cert pool func AddRootCA(certPool *x509.CertPool) { caCertPath := path.Join(certPath, "ca.pem") caCertRaw, err := os.ReadFile(caCertPath) if err != nil { panic(err) } if ok := certPool.AppendCertsFromPEM(caCertRaw); !ok { panic("Could not add root ceritificate to pool.") } } // GetRootCA returns an x509.CertPool containing (only) the CA certificate func GetRootCA() *x509.CertPool { pool := x509.NewCertPool() AddRootCA(pool) return pool } quic-go-0.59.0/internal/testdata/cert.pem000066400000000000000000000020221513066070600202360ustar00rootroot00000000000000-----BEGIN CERTIFICATE----- MIIC1TCCAb2gAwIBAgIJAK2fcqC0BVA7MA0GCSqGSIb3DQEBCwUAMCgxJjAkBgNV BAoMHXF1aWMtZ28gQ2VydGlmaWNhdGUgQXV0aG9yaXR5MB4XDTIwMDgxODA5MjEz NVoXDTMwMDgxNjA5MjEzNVowEjEQMA4GA1UECgwHcXVpYy1nbzCCASIwDQYJKoZI hvcNAQEBBQADggEPADCCAQoCggEBAN/YwrigSXdJCL/bdBGhb0UpqtU8H+krV870 +w1yCSykLImH8x3qHZEXt9sr/vgjcJoV6Z15RZmnbEqnAx84sIClIBoIgnk0VPxu 
WF+/U/dElbftCfYcfJAddhRckdmGB+yb3Wogb32UJ+q3my++h6NjHsYb+OwpJPnQ meXjOE7Kkf+bXfFywHF3R8kzVdh5JUFYeKbxYmYgxRps1YTsbCrZCrSy1CbQ9FJw Wg5C8t+7yvVFmOeWPECypBCz2xS2mu+kycMNIjIWMl0SL7oVM5cBkRKPeVIG/KcM i5+/4lRSLoPh0Txh2TKBWfpzLbIOdPU8/O7cAukIGWx0XsfHUQMCAwEAAaMYMBYw FAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBCwUAA4IBAQAyxxvebdMz shp5pt1SxMOSXbo8sTa1cpaf2rTmb4nxjXs6KPBEn53hSBz9bhe5wXE4f94SHadf 636rLh3d75KgrLUwO9Yq0HfCxMo1jUV/Ug++XwcHCI9vk58Tk/H4hqEM6C8RrdTj fYeuegQ0/oNLJ4uTw2P2A8TJbL6FC2dcICEAvUGZUcVyZ8m8tHXNRYYh6MZ7ubCh hinvL+AA5fY6EVlc5G/P4DN6fYxGn1cFNbiL4uZP4+W3dOmP+NV0YV9ihTyMzz0R vSoOZ9FeVkyw8EhMb3LoyXYKazvJy2VQST1ltzAGit9RiM1Gv4vuna74WsFzrn1U A/TbaR0ih/qG -----END CERTIFICATE----- quic-go-0.59.0/internal/testdata/cert_test.go000066400000000000000000000011241513066070600211230ustar00rootroot00000000000000package testdata import ( "crypto/tls" "io" "testing" "github.com/stretchr/testify/require" ) func TestCertificates(t *testing.T) { ln, err := tls.Listen("tcp", "localhost:4433", GetTLSConfig()) require.NoError(t, err) go func() { conn, err := ln.Accept() require.NoError(t, err) defer conn.Close() _, err = conn.Write([]byte("foobar")) require.NoError(t, err) }() conn, err := tls.Dial("tcp", "localhost:4433", &tls.Config{RootCAs: GetRootCA()}) require.NoError(t, err) data, err := io.ReadAll(conn) require.NoError(t, err) require.Equal(t, "foobar", string(data)) } quic-go-0.59.0/internal/testdata/generate_key.sh000077500000000000000000000012441513066070600216040ustar00rootroot00000000000000#!/bin/bash set -e echo "Generating CA key and certificate:" openssl req -x509 -sha256 -nodes -days 3650 -newkey rsa:2048 \ -keyout ca.key -out ca.pem \ -subj "/O=quic-go Certificate Authority/" echo "Generating CSR" openssl req -out cert.csr -new -newkey rsa:2048 -nodes -keyout priv.key \ -subj "/O=quic-go/" echo "Sign certificate:" openssl x509 -req -sha256 -days 3650 -in cert.csr -out cert.pem \ -CA ca.pem -CAkey ca.key -CAcreateserial \ -extfile <(printf 
"subjectAltName=DNS:localhost") # debug output the certificate openssl x509 -noout -text -in cert.pem # we don't need the CA key, the serial number and the CSR any more rm ca.key cert.csr ca.srl quic-go-0.59.0/internal/testdata/priv.key000066400000000000000000000032501513066070600202740ustar00rootroot00000000000000-----BEGIN PRIVATE KEY----- MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDf2MK4oEl3SQi/ 23QRoW9FKarVPB/pK1fO9PsNcgkspCyJh/Md6h2RF7fbK/74I3CaFemdeUWZp2xK pwMfOLCApSAaCIJ5NFT8blhfv1P3RJW37Qn2HHyQHXYUXJHZhgfsm91qIG99lCfq t5svvoejYx7GG/jsKST50Jnl4zhOypH/m13xcsBxd0fJM1XYeSVBWHim8WJmIMUa bNWE7Gwq2Qq0stQm0PRScFoOQvLfu8r1RZjnljxAsqQQs9sUtprvpMnDDSIyFjJd Ei+6FTOXAZESj3lSBvynDIufv+JUUi6D4dE8YdkygVn6cy2yDnT1PPzu3ALpCBls dF7Hx1EDAgMBAAECggEBAMm+mLDBdbUWk9YmuZNyRdC13wvT5obF05vo26OglXgw dxt09b6OVBuCnuff3SpS9pdJDIYq2HnFlSorH/sxopIvQKF17fHDIp1n7ipNTCXd IHrmHkY8Il/YzaVIUQMVc2rih0mw9greTqOS20DKnYC6QvAWIeDmrDaitTGl+ge3 hm7e2lsgZi13R6fTNwQs9geEQSGzP2k7bFceHQFDChOYiQraR5+VZZ8S8AMGjk47 AUa5EsKeUe6O9t2xuDSFxzYz5eadOAiErKGDos5KXXr3VQgFcC8uPEFFjcJ/yl+8 tOe4iLeVwGSDJhTAThdR2deJOjaDcarWM7ixmxA3DAECgYEA/WVwmY4gWKwv49IJ Jnh1Gu93P772GqliMNpukdjTI+joQxfl4jRSt2hk4b1KRwyT9aaKfvdz0HFlXo/r 9NVSAYT3/3vbcw61bfvPhhtz44qRAAKua6b5cUM6XqxVt1hqdP8lrf/blvA5ln+u O51S8+wpxZMuqKz/29zdWSG6tAMCgYEA4iWXMXX9dZajI6abVkWwuosvOakXdLk4 tUy7zd+JPF7hmUzzj2gtg4hXoiQPAOi+GY3TX+1Nza3s1LD7iWaXSKeOWvvligw9 Q/wVTNW2P1+tdhScJf9QudzW69xOm5HNBgx9uWV2cHfjC12vg5aTH0k5axvaq15H 9WBXlH5q3wECgYBYoYGYBDFmMpvxmMagkSOMz1OrlVSpkLOKmOxx0SBRACc1SIec 7mY8RqR6nOX9IfYixyTMMittLiyhvb9vfKnZZDQGRcFFZlCpbplws+t+HDqJgWaW uumm5zfkY2z7204pLBF24fZhvha2gGRl76pTLTiTJd79Gr3HnmJByd1vFwKBgHL7 vfYuEeM55lT4Hz8sTAFtR2O/7+cvTgAQteSlZbfGXlp939DonUulhTkxsFc7/3wq unCpzcdoSWSTYDGqcf1FBIKKVVltg7EPeR0KBJIQabgCHqrLOBZojPZ7m5RJ+765 lysuxZvFuTFMPzNe2gssRf+JuBMt6tR+WclsxZYBAoGAEEFs1ppDil1xlP5rdH7T d3TSw/u4eU/X8Ei1zi25hdRUiV76fP9fBELYFmSrPBhugYv91vtSv/LmD4zLfLv/ yzwAD9j1lGbgM8Of8klCkk+XSJ88ryUwnMTJ5loQJW8t4L+zLv5Le7Ca9SAT0kJ1 jT0GzDymgLMGp8RPdBkpk+w= 
-----END PRIVATE KEY----- quic-go-0.59.0/internal/utils/000077500000000000000000000000001513066070600161315ustar00rootroot00000000000000quic-go-0.59.0/internal/utils/buffered_write_closer.go000066400000000000000000000007321513066070600230250ustar00rootroot00000000000000package utils import ( "bufio" "io" ) type bufferedWriteCloser struct { *bufio.Writer io.Closer } // NewBufferedWriteCloser creates an io.WriteCloser from a bufio.Writer and an io.Closer func NewBufferedWriteCloser(writer *bufio.Writer, closer io.Closer) io.WriteCloser { return &bufferedWriteCloser{ Writer: writer, Closer: closer, } } func (h bufferedWriteCloser) Close() error { if err := h.Flush(); err != nil { return err } return h.Closer.Close() } quic-go-0.59.0/internal/utils/buffered_write_closer_test.go000066400000000000000000000007611513066070600240660ustar00rootroot00000000000000package utils import ( "bufio" "bytes" "testing" "github.com/stretchr/testify/require" ) type nopCloser struct{} func (nopCloser) Close() error { return nil } func TestBufferedWriteCloserFlushBeforeClosing(t *testing.T) { buf := &bytes.Buffer{} w := bufio.NewWriter(buf) wc := NewBufferedWriteCloser(w, &nopCloser{}) _, err := wc.Write([]byte("foobar")) require.NoError(t, err) require.Zero(t, buf.Len()) require.NoError(t, wc.Close()) require.Equal(t, "foobar", buf.String()) } quic-go-0.59.0/internal/utils/connstats.go000066400000000000000000000006021513066070600204720ustar00rootroot00000000000000package utils import "sync/atomic" // ConnectionStats stores stats for the connection. 
See the public // ConnectionStats struct in connection.go for more information type ConnectionStats struct { BytesSent atomic.Uint64 PacketsSent atomic.Uint64 BytesReceived atomic.Uint64 PacketsReceived atomic.Uint64 BytesLost atomic.Uint64 PacketsLost atomic.Uint64 } quic-go-0.59.0/internal/utils/linkedlist/000077500000000000000000000000001513066070600202735ustar00rootroot00000000000000quic-go-0.59.0/internal/utils/linkedlist/README.md000066400000000000000000000004471513066070600215570ustar00rootroot00000000000000# Usage This is the Go standard library implementation of a linked list (https://golang.org/src/container/list/list.go), with the following modifications: * it uses Go generics * it allows passing in a `sync.Pool` (via the `NewWithPool` constructor) to reduce allocations of `Element` structs quic-go-0.59.0/internal/utils/linkedlist/linkedlist.go000066400000000000000000000157071513066070600227760ustar00rootroot00000000000000// Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // Package list implements a doubly linked list. // // To iterate over a list (where l is a *List[T]): // // for e := l.Front(); e != nil; e = e.Next() { // // do something with e.Value // } package list import "sync" func NewPool[T any]() *sync.Pool { return &sync.Pool{New: func() any { return &Element[T]{} }} } // Element is an element of a linked list. type Element[T any] struct { // Next and previous pointers in the doubly-linked list of elements. // To simplify the implementation, internally a list l is implemented // as a ring, such that &l.root is both the next element of the last // list element (l.Back()) and the previous element of the first list // element (l.Front()). next, prev *Element[T] // The list to which this element belongs. list *List[T] // The value stored with this element. Value T } // Next returns the next list element or nil. 
func (e *Element[T]) Next() *Element[T] { if p := e.next; e.list != nil && p != &e.list.root { return p } return nil } // Prev returns the previous list element or nil. func (e *Element[T]) Prev() *Element[T] { if p := e.prev; e.list != nil && p != &e.list.root { return p } return nil } func (e *Element[T]) List() *List[T] { return e.list } // List represents a doubly linked list. // The zero value for List is an empty list ready to use. type List[T any] struct { root Element[T] // sentinel list element, only &root, root.prev, and root.next are used len int // current list length excluding (this) sentinel element pool *sync.Pool } // Init initializes or clears list l. func (l *List[T]) Init() *List[T] { l.root.next = &l.root l.root.prev = &l.root l.len = 0 return l } // New returns an initialized list. func New[T any]() *List[T] { return new(List[T]).Init() } // NewWithPool returns an initialized list, using a sync.Pool for list elements. func NewWithPool[T any](pool *sync.Pool) *List[T] { l := &List[T]{pool: pool} return l.Init() } // Len returns the number of elements of list l. // The complexity is O(1). func (l *List[T]) Len() int { return l.len } // Front returns the first element of list l or nil if the list is empty. func (l *List[T]) Front() *Element[T] { if l.len == 0 { return nil } return l.root.next } // Back returns the last element of list l or nil if the list is empty. func (l *List[T]) Back() *Element[T] { if l.len == 0 { return nil } return l.root.prev } // lazyInit lazily initializes a zero List value. func (l *List[T]) lazyInit() { if l.root.next == nil { l.Init() } } // insert inserts e after at, increments l.len, and returns e. func (l *List[T]) insert(e, at *Element[T]) *Element[T] { e.prev = at e.next = at.next e.prev.next = e e.next.prev = e e.list = l l.len++ return e } // insertValue is a convenience wrapper for insert(&Element{Value: v}, at). 
func (l *List[T]) insertValue(v T, at *Element[T]) *Element[T] { var e *Element[T] if l.pool != nil { e = l.pool.Get().(*Element[T]) } else { e = &Element[T]{} } e.Value = v return l.insert(e, at) } // remove removes e from its list, decrements l.len func (l *List[T]) remove(e *Element[T]) { e.prev.next = e.next e.next.prev = e.prev e.next = nil // avoid memory leaks e.prev = nil // avoid memory leaks e.list = nil if l.pool != nil { l.pool.Put(e) } l.len-- } // move moves e to next to at. func (l *List[T]) move(e, at *Element[T]) { if e == at { return } e.prev.next = e.next e.next.prev = e.prev e.prev = at e.next = at.next e.prev.next = e e.next.prev = e } // Remove removes e from l if e is an element of list l. // It returns the element value e.Value. // The element must not be nil. func (l *List[T]) Remove(e *Element[T]) T { v := e.Value if e.list == l { // if e.list == l, l must have been initialized when e was inserted // in l or l == nil (e is a zero Element) and l.remove will crash l.remove(e) } return v } // PushFront inserts a new element e with value v at the front of list l and returns e. func (l *List[T]) PushFront(v T) *Element[T] { l.lazyInit() return l.insertValue(v, &l.root) } // PushBack inserts a new element e with value v at the back of list l and returns e. func (l *List[T]) PushBack(v T) *Element[T] { l.lazyInit() return l.insertValue(v, l.root.prev) } // InsertBefore inserts a new element e with value v immediately before mark and returns e. // If mark is not an element of l, the list is not modified. // The mark must not be nil. func (l *List[T]) InsertBefore(v T, mark *Element[T]) *Element[T] { if mark.list != l { return nil } // see comment in List.Remove about initialization of l return l.insertValue(v, mark.prev) } // InsertAfter inserts a new element e with value v immediately after mark and returns e. // If mark is not an element of l, the list is not modified. // The mark must not be nil. 
func (l *List[T]) InsertAfter(v T, mark *Element[T]) *Element[T] { if mark.list != l { return nil } // see comment in List.Remove about initialization of l return l.insertValue(v, mark) } // MoveToFront moves element e to the front of list l. // If e is not an element of l, the list is not modified. // The element must not be nil. func (l *List[T]) MoveToFront(e *Element[T]) { if e.list != l || l.root.next == e { return } // see comment in List.Remove about initialization of l l.move(e, &l.root) } // MoveToBack moves element e to the back of list l. // If e is not an element of l, the list is not modified. // The element must not be nil. func (l *List[T]) MoveToBack(e *Element[T]) { if e.list != l || l.root.prev == e { return } // see comment in List.Remove about initialization of l l.move(e, l.root.prev) } // MoveBefore moves element e to its new position before mark. // If e or mark is not an element of l, or e == mark, the list is not modified. // The element and mark must not be nil. func (l *List[T]) MoveBefore(e, mark *Element[T]) { if e.list != l || e == mark || mark.list != l { return } l.move(e, mark.prev) } // MoveAfter moves element e to its new position after mark. // If e or mark is not an element of l, or e == mark, the list is not modified. // The element and mark must not be nil. func (l *List[T]) MoveAfter(e, mark *Element[T]) { if e.list != l || e == mark || mark.list != l { return } l.move(e, mark) } // PushBackList inserts a copy of another list at the back of list l. // The lists l and other may be the same. They must not be nil. func (l *List[T]) PushBackList(other *List[T]) { l.lazyInit() for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { l.insertValue(e.Value, l.root.prev) } } // PushFrontList inserts a copy of another list at the front of list l. // The lists l and other may be the same. They must not be nil. 
func (l *List[T]) PushFrontList(other *List[T]) { l.lazyInit() for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { l.insertValue(e.Value, &l.root) } } quic-go-0.59.0/internal/utils/log.go000066400000000000000000000053331513066070600172450ustar00rootroot00000000000000package utils import ( "fmt" "log" "os" "strings" "time" ) // LogLevel of quic-go type LogLevel uint8 const ( // LogLevelNothing disables LogLevelNothing LogLevel = iota // LogLevelError enables err logs LogLevelError // LogLevelInfo enables info logs (e.g. packets) LogLevelInfo // LogLevelDebug enables debug logs (e.g. packet contents) LogLevelDebug ) const logEnv = "QUIC_GO_LOG_LEVEL" // A Logger logs. type Logger interface { SetLogLevel(LogLevel) SetLogTimeFormat(format string) WithPrefix(prefix string) Logger Debug() bool Errorf(format string, args ...any) Infof(format string, args ...any) Debugf(format string, args ...any) } // DefaultLogger is used by quic-go for logging. var DefaultLogger Logger type defaultLogger struct { prefix string logLevel LogLevel timeFormat string } var _ Logger = &defaultLogger{} // SetLogLevel sets the log level func (l *defaultLogger) SetLogLevel(level LogLevel) { l.logLevel = level } // SetLogTimeFormat sets the format of the timestamp // an empty string disables the logging of timestamps func (l *defaultLogger) SetLogTimeFormat(format string) { log.SetFlags(0) // disable timestamp logging done by the log package l.timeFormat = format } // Debugf logs something func (l *defaultLogger) Debugf(format string, args ...any) { if l.logLevel == LogLevelDebug { l.logMessage(format, args...) } } // Infof logs something func (l *defaultLogger) Infof(format string, args ...any) { if l.logLevel >= LogLevelInfo { l.logMessage(format, args...) } } // Errorf logs something func (l *defaultLogger) Errorf(format string, args ...any) { if l.logLevel >= LogLevelError { l.logMessage(format, args...) 
} } func (l *defaultLogger) logMessage(format string, args ...any) { var pre string if len(l.timeFormat) > 0 { pre = time.Now().Format(l.timeFormat) + " " } if len(l.prefix) > 0 { pre += l.prefix + " " } log.Printf(pre+format, args...) } func (l *defaultLogger) WithPrefix(prefix string) Logger { if len(l.prefix) > 0 { prefix = l.prefix + " " + prefix } return &defaultLogger{ logLevel: l.logLevel, timeFormat: l.timeFormat, prefix: prefix, } } // Debug returns true if the log level is LogLevelDebug func (l *defaultLogger) Debug() bool { return l.logLevel == LogLevelDebug } func init() { DefaultLogger = &defaultLogger{} DefaultLogger.SetLogLevel(readLoggingEnv()) } func readLoggingEnv() LogLevel { switch strings.ToLower(os.Getenv(logEnv)) { case "": return LogLevelNothing case "debug": return LogLevelDebug case "info": return LogLevelInfo case "error": return LogLevelError default: fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/quic-go/quic-go/wiki/Logging") return LogLevelNothing } } quic-go-0.59.0/internal/utils/log_test.go000066400000000000000000000073751513066070600203140ustar00rootroot00000000000000package utils import ( "bytes" "log" "os" "testing" "time" "github.com/stretchr/testify/require" ) func TestLogLevelNothing(t *testing.T) { b := &bytes.Buffer{} log.SetOutput(b) defer log.SetOutput(os.Stdout) defer DefaultLogger.SetLogLevel(LogLevelNothing) DefaultLogger.SetLogLevel(LogLevelNothing) DefaultLogger.Debugf("debug") DefaultLogger.Infof("info") DefaultLogger.Errorf("err") require.Empty(t, b.String()) } func TestLogLevelError(t *testing.T) { b := &bytes.Buffer{} log.SetOutput(b) defer log.SetOutput(os.Stdout) defer DefaultLogger.SetLogLevel(LogLevelNothing) DefaultLogger.SetLogLevel(LogLevelError) DefaultLogger.Debugf("debug") DefaultLogger.Infof("info") DefaultLogger.Errorf("err") require.Contains(t, b.String(), "err\n") require.NotContains(t, b.String(), "info") require.NotContains(t, b.String(), "debug") } func 
TestLogLevelInfo(t *testing.T) { b := &bytes.Buffer{} log.SetOutput(b) defer log.SetOutput(os.Stdout) defer DefaultLogger.SetLogLevel(LogLevelNothing) DefaultLogger.SetLogLevel(LogLevelInfo) DefaultLogger.Debugf("debug") DefaultLogger.Infof("info") DefaultLogger.Errorf("err") require.Contains(t, b.String(), "err\n") require.Contains(t, b.String(), "info\n") require.NotContains(t, b.String(), "debug") } func TestLogLevelDebug(t *testing.T) { b := &bytes.Buffer{} log.SetOutput(b) defer log.SetOutput(os.Stdout) defer DefaultLogger.SetLogLevel(LogLevelNothing) require.False(t, DefaultLogger.Debug()) DefaultLogger.SetLogLevel(LogLevelDebug) require.True(t, DefaultLogger.Debug()) DefaultLogger.Debugf("debug") DefaultLogger.Infof("info") DefaultLogger.Errorf("err") require.Contains(t, b.String(), "err\n") require.Contains(t, b.String(), "info\n") require.Contains(t, b.String(), "debug\n") } func TestNoTimestampWithEmptyFormat(t *testing.T) { b := &bytes.Buffer{} log.SetOutput(b) defer log.SetOutput(os.Stdout) defer DefaultLogger.SetLogLevel(LogLevelNothing) DefaultLogger.SetLogLevel(LogLevelDebug) DefaultLogger.SetLogTimeFormat("") DefaultLogger.Debugf("debug") require.Equal(t, "debug\n", b.String()) } func TestAddTimestamp(t *testing.T) { b := &bytes.Buffer{} log.SetOutput(b) defer log.SetOutput(os.Stdout) defer DefaultLogger.SetLogLevel(LogLevelNothing) format := "Jan 2, 2006" DefaultLogger.SetLogTimeFormat(format) DefaultLogger.SetLogLevel(LogLevelInfo) DefaultLogger.Infof("info") timestamp := b.String()[:b.Len()-6] parsedTime, err := time.ParseInLocation(format, timestamp, time.Local) require.NoError(t, err) require.WithinDuration(t, time.Now(), parsedTime, 25*time.Hour) } func TestLogAddPrefixes(t *testing.T) { b := &bytes.Buffer{} log.SetOutput(b) defer log.SetOutput(os.Stdout) defer DefaultLogger.SetLogLevel(LogLevelNothing) DefaultLogger.SetLogLevel(LogLevelDebug) // single prefix prefixLogger := DefaultLogger.WithPrefix("prefix") prefixLogger.Debugf("debug1") 
require.Contains(t, b.String(), "prefix") require.Contains(t, b.String(), "debug1") // multiple prefixes b.Reset() prefixLogger1 := DefaultLogger.WithPrefix("prefix1") prefixLogger2 := prefixLogger1.WithPrefix("prefix2") prefixLogger2.Debugf("debug2") require.Contains(t, b.String(), "prefix1") require.Contains(t, b.String(), "prefix2") require.Contains(t, b.String(), "debug2") } func TestLogLevelFromEnv(t *testing.T) { testCases := []struct { envValue string expected LogLevel }{ {"DEBUG", LogLevelDebug}, {"debug", LogLevelDebug}, {"INFO", LogLevelInfo}, {"ERROR", LogLevelError}, } for _, tc := range testCases { t.Setenv(logEnv, tc.envValue) require.Equal(t, tc.expected, readLoggingEnv()) } // invalid values t.Setenv(logEnv, "") require.Equal(t, LogLevelNothing, readLoggingEnv()) t.Setenv(logEnv, "asdf") require.Equal(t, LogLevelNothing, readLoggingEnv()) } quic-go-0.59.0/internal/utils/rand.go000066400000000000000000000011571513066070600174100ustar00rootroot00000000000000package utils import ( "crypto/rand" "encoding/binary" ) // Rand is a wrapper around crypto/rand that adds some convenience functions known from math/rand. 
type Rand struct { buf [4]byte } func (r *Rand) Int31() int32 { rand.Read(r.buf[:]) return int32(binary.BigEndian.Uint32(r.buf[:]) & ^uint32(1<<31)) } // copied from the standard library math/rand implementation of Int63n func (r *Rand) Int31n(n int32) int32 { if n&(n-1) == 0 { // n is power of two, can mask return r.Int31() & (n - 1) } max := int32((1 << 31) - 1 - (1<<31)%uint32(n)) v := r.Int31() for v > max { v = r.Int31() } return v % n } quic-go-0.59.0/internal/utils/rand_test.go000066400000000000000000000010631513066070600204430ustar00rootroot00000000000000package utils import ( "testing" "github.com/stretchr/testify/require" ) func TestRandomNumbers(t *testing.T) { const ( num = 1000 max = 12345678 ) var values [num]int32 var r Rand for i := 0; i < num; i++ { v := r.Int31n(max) require.GreaterOrEqual(t, v, int32(0)) require.Less(t, v, int32(max)) values[i] = v } var sum uint64 for _, n := range values { sum += uint64(n) } average := float64(sum) / num expectedAverage := float64(max) / 2 tolerance := float64(max) / 25 require.InDelta(t, expectedAverage, average, tolerance) } quic-go-0.59.0/internal/utils/ringbuffer/000077500000000000000000000000001513066070600202625ustar00rootroot00000000000000quic-go-0.59.0/internal/utils/ringbuffer/ringbuffer.go000066400000000000000000000045461513066070600227530ustar00rootroot00000000000000package ringbuffer // A RingBuffer is a ring buffer. // It acts as a heap that doesn't cause any allocations. type RingBuffer[T any] struct { ring []T headPos, tailPos int full bool } // Init preallocates a buffer with a certain size. func (r *RingBuffer[T]) Init(size int) { r.ring = make([]T, size) } // Len returns the number of elements in the ring buffer. func (r *RingBuffer[T]) Len() int { if r.full { return len(r.ring) } if r.tailPos >= r.headPos { return r.tailPos - r.headPos } return r.tailPos - r.headPos + len(r.ring) } // Empty says if the ring buffer is empty. 
func (r *RingBuffer[T]) Empty() bool { return !r.full && r.headPos == r.tailPos } // PushBack adds a new element. // If the ring buffer is full, its capacity is increased first. func (r *RingBuffer[T]) PushBack(t T) { if r.full || len(r.ring) == 0 { r.grow() } r.ring[r.tailPos] = t r.tailPos++ if r.tailPos == len(r.ring) { r.tailPos = 0 } if r.tailPos == r.headPos { r.full = true } } // PopFront returns the next element. // It must not be called when the buffer is empty, that means that // callers might need to check if there are elements in the buffer first. func (r *RingBuffer[T]) PopFront() T { if r.Empty() { panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: pop from an empty queue") } r.full = false t := r.ring[r.headPos] r.ring[r.headPos] = *new(T) r.headPos++ if r.headPos == len(r.ring) { r.headPos = 0 } return t } // PeekFront returns the next element. // It must not be called when the buffer is empty, that means that // callers might need to check if there are elements in the buffer first. func (r *RingBuffer[T]) PeekFront() T { if r.Empty() { panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: peek from an empty queue") } return r.ring[r.headPos] } // Grow the maximum size of the queue. // This method assume the queue is full. func (r *RingBuffer[T]) grow() { oldRing := r.ring newSize := len(oldRing) * 2 if newSize == 0 { newSize = 1 } r.ring = make([]T, newSize) headLen := copy(r.ring, oldRing[r.headPos:]) copy(r.ring[headLen:], oldRing[:r.headPos]) r.headPos, r.tailPos, r.full = 0, len(oldRing), false } // Clear removes all elements. 
func (r *RingBuffer[T]) Clear() { var zeroValue T for i := range r.ring { r.ring[i] = zeroValue } r.headPos, r.tailPos, r.full = 0, 0, false } quic-go-0.59.0/internal/utils/ringbuffer/ringbuffer_bench_test.go000066400000000000000000000002631513066070600251410ustar00rootroot00000000000000package ringbuffer import "testing" func BenchmarkRingBuffer(b *testing.B) { r := RingBuffer[int]{} var val int for b.Loop() { r.PushBack(val) r.PopFront() val++ } } quic-go-0.59.0/internal/utils/ringbuffer/ringbuffer_test.go000066400000000000000000000021141513066070600237770ustar00rootroot00000000000000package ringbuffer import ( "testing" "github.com/stretchr/testify/require" ) func TestPushPeekPop(t *testing.T) { r := RingBuffer[int]{} require.Equal(t, 0, len(r.ring)) require.Panics(t, func() { r.PopFront() }) r.PushBack(1) r.PushBack(2) r.PushBack(3) require.Equal(t, 1, r.PeekFront()) require.Equal(t, 1, r.PeekFront()) require.Equal(t, 1, r.PopFront()) require.Equal(t, 2, r.PeekFront()) require.Equal(t, 2, r.PopFront()) r.PushBack(4) r.PushBack(5) require.Equal(t, 3, r.Len()) r.PushBack(6) require.Equal(t, 4, r.Len()) require.Equal(t, 3, r.PopFront()) require.Equal(t, 4, r.PopFront()) require.Equal(t, 5, r.PopFront()) require.Equal(t, 6, r.PopFront()) } func TestPanicOnEmptyBuffer(t *testing.T) { r := RingBuffer[string]{} require.True(t, r.Empty()) require.Zero(t, r.Len()) require.Panics(t, func() { r.PeekFront() }) require.Panics(t, func() { r.PopFront() }) } func TestClear(t *testing.T) { r := RingBuffer[int]{} r.Init(2) r.PushBack(1) r.PushBack(2) require.True(t, r.full) r.Clear() require.False(t, r.full) require.Equal(t, 0, r.Len()) } quic-go-0.59.0/internal/utils/rtt_stats.go000066400000000000000000000115051513066070600205110ustar00rootroot00000000000000package utils import ( "sync/atomic" "time" "github.com/quic-go/quic-go/internal/protocol" ) const ( rttAlpha = 0.125 oneMinusAlpha = 1 - rttAlpha rttBeta = 0.25 oneMinusBeta = 1 - rttBeta ) // The default RTT used before an 
RTT sample is taken const DefaultInitialRTT = 100 * time.Millisecond // RTTStats provides round-trip statistics type RTTStats struct { hasMeasurement bool minRTT atomic.Int64 // nanoseconds latestRTT atomic.Int64 // nanoseconds smoothedRTT atomic.Int64 // nanoseconds meanDeviation atomic.Int64 // nanoseconds maxAckDelay atomic.Int64 // nanoseconds } func NewRTTStats() *RTTStats { var rttStats RTTStats rttStats.minRTT.Store(DefaultInitialRTT.Nanoseconds()) rttStats.latestRTT.Store(DefaultInitialRTT.Nanoseconds()) rttStats.smoothedRTT.Store(DefaultInitialRTT.Nanoseconds()) return &rttStats } // MinRTT Returns the minRTT for the entire connection. // May return Zero if no valid updates have occurred. func (r *RTTStats) MinRTT() time.Duration { return time.Duration(r.minRTT.Load()) } // LatestRTT returns the most recent rtt measurement. // May return Zero if no valid updates have occurred. func (r *RTTStats) LatestRTT() time.Duration { return time.Duration(r.latestRTT.Load()) } // SmoothedRTT returns the smoothed RTT for the connection. // May return Zero if no valid updates have occurred. func (r *RTTStats) SmoothedRTT() time.Duration { return time.Duration(r.smoothedRTT.Load()) } // MeanDeviation gets the mean deviation func (r *RTTStats) MeanDeviation() time.Duration { return time.Duration(r.meanDeviation.Load()) } // MaxAckDelay gets the max_ack_delay advertised by the peer func (r *RTTStats) MaxAckDelay() time.Duration { return time.Duration(r.maxAckDelay.Load()) } // PTO gets the probe timeout duration. func (r *RTTStats) PTO(includeMaxAckDelay bool) time.Duration { if !r.hasMeasurement { return 2 * DefaultInitialRTT } pto := r.SmoothedRTT() + max(4*r.MeanDeviation(), protocol.TimerGranularity) if includeMaxAckDelay { pto += r.MaxAckDelay() } return pto } // UpdateRTT updates the RTT based on a new sample. func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration) { if sendDelta <= 0 { return } // Update r.minRTT first. 
r.minRTT does not use an rttSample corrected for // ackDelay but the raw observed sendDelta, since poor clock granularity at // the client may cause a high ackDelay to result in underestimation of the // r.minRTT. minRTT := time.Duration(r.minRTT.Load()) if !r.hasMeasurement || minRTT > sendDelta { minRTT = sendDelta r.minRTT.Store(sendDelta.Nanoseconds()) } // Correct for ackDelay if information received from the peer results in a // an RTT sample at least as large as minRTT. Otherwise, only use the // sendDelta. sample := sendDelta if sample-minRTT >= ackDelay { sample -= ackDelay } r.latestRTT.Store(sample.Nanoseconds()) // First time call. if !r.hasMeasurement { r.hasMeasurement = true r.smoothedRTT.Store(sample.Nanoseconds()) r.meanDeviation.Store(sample.Nanoseconds() / 2) } else { smoothedRTT := r.SmoothedRTT() meanDev := time.Duration(oneMinusBeta*float32(r.MeanDeviation()/time.Microsecond)+rttBeta*float32((smoothedRTT-sample).Abs()/time.Microsecond)) * time.Microsecond newSmoothedRTT := time.Duration((float32(smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond r.meanDeviation.Store(meanDev.Nanoseconds()) r.smoothedRTT.Store(newSmoothedRTT.Nanoseconds()) } } func (r *RTTStats) HasMeasurement() bool { return r.hasMeasurement } // SetMaxAckDelay sets the max_ack_delay func (r *RTTStats) SetMaxAckDelay(mad time.Duration) { r.maxAckDelay.Store(int64(mad)) } // SetInitialRTT sets the initial RTT. // It is used during handshake when restoring the RTT stats from the token. func (r *RTTStats) SetInitialRTT(t time.Duration) { // On the server side, by the time we get to process the session ticket, // we might already have obtained an RTT measurement. // This can happen if we received the ClientHello in multiple pieces, and one of those pieces was lost. // Discard the restored value. A fresh measurement is always better. 
if r.hasMeasurement { return } r.smoothedRTT.Store(int64(t)) r.latestRTT.Store(int64(t)) } func (r *RTTStats) ResetForPathMigration() { r.hasMeasurement = false r.minRTT.Store(DefaultInitialRTT.Nanoseconds()) r.latestRTT.Store(DefaultInitialRTT.Nanoseconds()) r.smoothedRTT.Store(DefaultInitialRTT.Nanoseconds()) r.meanDeviation.Store(0) // max_ack_delay remains valid } func (r *RTTStats) Clone() *RTTStats { out := &RTTStats{} out.hasMeasurement = r.hasMeasurement out.minRTT.Store(r.minRTT.Load()) out.latestRTT.Store(r.latestRTT.Load()) out.smoothedRTT.Store(r.smoothedRTT.Load()) out.meanDeviation.Store(r.meanDeviation.Load()) out.maxAckDelay.Store(r.maxAckDelay.Load()) return out } quic-go-0.59.0/internal/utils/rtt_stats_test.go000066400000000000000000000124401513066070600215470ustar00rootroot00000000000000package utils import ( "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestRTTStatsDefaults(t *testing.T) { rttStats := NewRTTStats() require.False(t, rttStats.HasMeasurement()) require.Equal(t, DefaultInitialRTT, rttStats.MinRTT()) require.Equal(t, DefaultInitialRTT, rttStats.SmoothedRTT()) } func TestRTTStatsSmoothedRTT(t *testing.T) { rttStats := NewRTTStats() require.False(t, rttStats.HasMeasurement()) // verify that ack_delay is ignored in the first measurement rttStats.UpdateRTT(300*time.Millisecond, 100*time.Millisecond) require.True(t, rttStats.HasMeasurement()) require.Equal(t, 300*time.Millisecond, rttStats.LatestRTT()) require.Equal(t, 300*time.Millisecond, rttStats.SmoothedRTT()) // verify that smoothed RTT includes max ack delay if it's reasonable rttStats.UpdateRTT(350*time.Millisecond, 50*time.Millisecond) require.Equal(t, 300*time.Millisecond, rttStats.LatestRTT()) require.Equal(t, 300*time.Millisecond, rttStats.SmoothedRTT()) // verify that large erroneous ack_delay does not change smoothed RTT rttStats.UpdateRTT(200*time.Millisecond, 300*time.Millisecond) require.Equal(t, 
200*time.Millisecond, rttStats.LatestRTT()) require.Equal(t, 287500*time.Microsecond, rttStats.SmoothedRTT()) } func TestRTTStatsMinRTT(t *testing.T) { rttStats := NewRTTStats() rttStats.UpdateRTT(200*time.Millisecond, 0) require.Equal(t, 200*time.Millisecond, rttStats.MinRTT()) rttStats.UpdateRTT(10*time.Millisecond, 0) require.Equal(t, 10*time.Millisecond, rttStats.MinRTT()) rttStats.UpdateRTT(50*time.Millisecond, 0) require.Equal(t, 10*time.Millisecond, rttStats.MinRTT()) rttStats.UpdateRTT(50*time.Millisecond, 0) require.Equal(t, 10*time.Millisecond, rttStats.MinRTT()) rttStats.UpdateRTT(50*time.Millisecond, 0) require.Equal(t, 10*time.Millisecond, rttStats.MinRTT()) // verify that ack_delay does not go into recording of MinRTT rttStats.UpdateRTT(7*time.Millisecond, 2*time.Millisecond) require.Equal(t, 7*time.Millisecond, rttStats.MinRTT()) } func TestRTTStatsMaxAckDelay(t *testing.T) { rttStats := NewRTTStats() rttStats.SetMaxAckDelay(42 * time.Minute) require.Equal(t, 42*time.Minute, rttStats.MaxAckDelay()) } func TestRTTStatsComputePTO(t *testing.T) { const ( maxAckDelay = 42 * time.Minute rtt = time.Second ) rttStats := NewRTTStats() rttStats.SetMaxAckDelay(maxAckDelay) rttStats.UpdateRTT(rtt, 0) require.Equal(t, rtt, rttStats.SmoothedRTT()) require.Equal(t, rtt/2, rttStats.MeanDeviation()) require.Equal(t, rtt+4*(rtt/2), rttStats.PTO(false)) require.Equal(t, rtt+4*(rtt/2)+maxAckDelay, rttStats.PTO(true)) } func TestRTTStatsPTOWithShortRTT(t *testing.T) { const rtt = time.Microsecond rttStats := NewRTTStats() rttStats.UpdateRTT(rtt, 0) require.Equal(t, rtt+protocol.TimerGranularity, rttStats.PTO(true)) } func TestRTTStatsUpdateWithBadSendDeltas(t *testing.T) { rttStats := NewRTTStats() const initialRtt = 10 * time.Millisecond rttStats.UpdateRTT(initialRtt, 0) require.Equal(t, initialRtt, rttStats.MinRTT()) require.Equal(t, initialRtt, rttStats.SmoothedRTT()) badSendDeltas := []time.Duration{ 0, -1000 * time.Microsecond, } for _, badSendDelta := range 
badSendDeltas { rttStats.UpdateRTT(badSendDelta, 0) require.Equal(t, initialRtt, rttStats.MinRTT()) require.Equal(t, initialRtt, rttStats.SmoothedRTT()) } } func TestRTTStatsRestore(t *testing.T) { rttStats := NewRTTStats() rttStats.SetInitialRTT(10 * time.Second) require.Equal(t, 10*time.Second, rttStats.LatestRTT()) require.Equal(t, 10*time.Second, rttStats.SmoothedRTT()) require.Zero(t, rttStats.MeanDeviation()) // update the RTT and make sure that the initial value is immediately forgotten rttStats.UpdateRTT(200*time.Millisecond, 0) require.Equal(t, 200*time.Millisecond, rttStats.LatestRTT()) require.Equal(t, 200*time.Millisecond, rttStats.SmoothedRTT()) require.Equal(t, 100*time.Millisecond, rttStats.MeanDeviation()) } func TestRTTMeasurementAfterRestore(t *testing.T) { rttStats := NewRTTStats() const rtt = 10 * time.Millisecond rttStats.UpdateRTT(rtt, 0) require.Equal(t, rtt, rttStats.LatestRTT()) require.Equal(t, rtt, rttStats.SmoothedRTT()) rttStats.SetInitialRTT(time.Minute) require.Equal(t, rtt, rttStats.LatestRTT()) require.Equal(t, rtt, rttStats.SmoothedRTT()) } func TestRTTStatsResetForPathMigration(t *testing.T) { rttStats := NewRTTStats() rttStats.SetMaxAckDelay(42 * time.Millisecond) rttStats.UpdateRTT(time.Second, 0) rttStats.UpdateRTT(10*time.Second, 0) require.True(t, rttStats.HasMeasurement()) require.Equal(t, time.Second, rttStats.MinRTT()) require.Equal(t, 10*time.Second, rttStats.LatestRTT()) require.NotZero(t, rttStats.SmoothedRTT()) rttStats.ResetForPathMigration() require.False(t, rttStats.HasMeasurement()) require.Equal(t, DefaultInitialRTT, rttStats.MinRTT()) require.Equal(t, DefaultInitialRTT, rttStats.LatestRTT()) require.Equal(t, DefaultInitialRTT, rttStats.SmoothedRTT()) require.Equal(t, 2*DefaultInitialRTT, rttStats.PTO(false)) // make sure that max_ack_delay was not reset require.Equal(t, 42*time.Millisecond, rttStats.MaxAckDelay()) rttStats.UpdateRTT(10*time.Millisecond, 0) require.True(t, rttStats.HasMeasurement()) 
require.Equal(t, 10*time.Millisecond, rttStats.SmoothedRTT()) require.Equal(t, 10*time.Millisecond, rttStats.LatestRTT()) } quic-go-0.59.0/internal/wire/000077500000000000000000000000001513066070600157375ustar00rootroot00000000000000quic-go-0.59.0/internal/wire/ack_frame.go000066400000000000000000000203511513066070600201770ustar00rootroot00000000000000package wire import ( "errors" "math" "sort" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" ) var errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges") // An AckFrame is an ACK frame type AckFrame struct { AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last DelayTime time.Duration ECT0, ECT1, ECNCE uint64 } // parseAckFrame reads an ACK frame func parseAckFrame(frame *AckFrame, b []byte, typ FrameType, ackDelayExponent uint8, _ protocol.Version) (int, error) { startLen := len(b) ecn := typ == FrameTypeAckECN la, l, err := quicvarint.Parse(b) if err != nil { return 0, replaceUnexpectedEOF(err) } b = b[l:] largestAcked := protocol.PacketNumber(la) delay, l, err := quicvarint.Parse(b) if err != nil { return 0, replaceUnexpectedEOF(err) } b = b[l:] delayTime := time.Duration(delay*1< largestAcked { return 0, errors.New("invalid first ACK range") } smallest := largestAcked - ackBlock frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked}) // read all the other ACK ranges for range numBlocks { g, l, err := quicvarint.Parse(b) if err != nil { return 0, replaceUnexpectedEOF(err) } b = b[l:] gap := protocol.PacketNumber(g) if smallest < gap+2 { return 0, errInvalidAckRanges } largest := smallest - gap - 2 ab, l, err := quicvarint.Parse(b) if err != nil { return 0, replaceUnexpectedEOF(err) } b = b[l:] ackBlock := protocol.PacketNumber(ab) if ackBlock > largest { return 0, errInvalidAckRanges } smallest = largest - ackBlock frame.AckRanges = 
append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest}) } if !frame.validateAckRanges() { return 0, errInvalidAckRanges } if ecn { ect0, l, err := quicvarint.Parse(b) if err != nil { return 0, replaceUnexpectedEOF(err) } b = b[l:] frame.ECT0 = ect0 ect1, l, err := quicvarint.Parse(b) if err != nil { return 0, replaceUnexpectedEOF(err) } b = b[l:] frame.ECT1 = ect1 ecnce, l, err := quicvarint.Parse(b) if err != nil { return 0, replaceUnexpectedEOF(err) } b = b[l:] frame.ECNCE = ecnce } return startLen - len(b), nil } // Append appends an ACK frame. func (f *AckFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 if hasECN { b = append(b, byte(FrameTypeAckECN)) } else { b = append(b, byte(FrameTypeAck)) } b = quicvarint.Append(b, uint64(f.LargestAcked())) b = quicvarint.Append(b, encodeAckDelay(f.DelayTime)) numRanges := min(len(f.AckRanges), protocol.MaxNumAckRanges) b = quicvarint.Append(b, uint64(numRanges-1)) // write the first range _, firstRange := f.encodeAckRange(0) b = quicvarint.Append(b, firstRange) // write all the other range for i := 1; i < numRanges; i++ { gap, len := f.encodeAckRange(i) b = quicvarint.Append(b, gap) b = quicvarint.Append(b, len) } if hasECN { b = quicvarint.Append(b, f.ECT0) b = quicvarint.Append(b, f.ECT1) b = quicvarint.Append(b, f.ECNCE) } return b, nil } // Length of a written frame func (f *AckFrame) Length(_ protocol.Version) protocol.ByteCount { largestAcked := f.AckRanges[0].Largest // The number of ACK ranges is limited to 64, which guarantees that the // ACK Range Count value can be encoded in a single byte varint. 
length := 1 + quicvarint.Len(uint64(largestAcked)) + quicvarint.Len(encodeAckDelay(f.DelayTime)) + 1 lowestInFirstRange := f.AckRanges[0].Smallest length += quicvarint.Len(uint64(largestAcked - lowestInFirstRange)) for i := 1; i < min(len(f.AckRanges), protocol.MaxNumAckRanges); i++ { gap, len := f.encodeAckRange(i) length += quicvarint.Len(gap) length += quicvarint.Len(len) } if f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 { length += quicvarint.Len(f.ECT0) + quicvarint.Len(f.ECT1) + quicvarint.Len(f.ECNCE) } return protocol.ByteCount(length) } // Truncate truncates the ACK frame to fit into maxSize, // and to at most 64 ACK ranges. // maxSize must be large enough to fit at least one ACK range. func (f *AckFrame) Truncate(maxSize protocol.ByteCount, _ protocol.Version) { f.AckRanges = f.AckRanges[:f.numEncodableAckRanges(maxSize)] } // gets the number of ACK ranges that can be encoded // such that the resulting frame is smaller than maxSize func (f *AckFrame) numEncodableAckRanges(maxSize protocol.ByteCount) int { // Fast path: Most ACK frames are relatively small, and we don't need to calculate the exact length. // We just assume the worst case scenario: every varint is encoded to 8 bytes. // If the result is still smaller than the maximum ACK frame size, the actual ACK frame will definitely fit. length := 1 + 8 /* largest acked */ + 8 /* delay */ + 1 /* ack range count */ + 8 /* first range */ if f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 { length += 8 + 8 + 8 } numRanges := min(len(f.AckRanges), protocol.MaxNumAckRanges) length += 2 * 8 * (numRanges - 1) if protocol.ByteCount(length) <= maxSize { return numRanges } // Slow path: Calculate the exact length of the ACK frame. 
length = 1 + quicvarint.Len(uint64(f.LargestAcked())) + quicvarint.Len(encodeAckDelay(f.DelayTime)) + 1 _, firstRange := f.encodeAckRange(0) length += quicvarint.Len(firstRange) if f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 { length += quicvarint.Len(f.ECT0) + quicvarint.Len(f.ECT1) + quicvarint.Len(f.ECNCE) } for i := 1; i < numRanges; i++ { gap, l := f.encodeAckRange(i) rangeLen := quicvarint.Len(gap) + quicvarint.Len(l) if protocol.ByteCount(length+rangeLen) > maxSize { // Writing range i would exceed the maximum size, // so encode one range less than that. return i } length += rangeLen } return numRanges } func (f *AckFrame) encodeAckRange(i int) (gap, length uint64) { if i == 0 { return 0, uint64(f.AckRanges[0].Largest - f.AckRanges[0].Smallest) } return uint64(f.AckRanges[i-1].Smallest - f.AckRanges[i].Largest - 2), uint64(f.AckRanges[i].Largest - f.AckRanges[i].Smallest) } // HasMissingRanges returns if this frame reports any missing packets func (f *AckFrame) HasMissingRanges() bool { return len(f.AckRanges) > 1 } func (f *AckFrame) validateAckRanges() bool { if len(f.AckRanges) == 0 { return false } // check the validity of every single ACK range for _, ackRange := range f.AckRanges { if ackRange.Smallest > ackRange.Largest { return false } } // check the consistency for ACK with multiple ACK ranges for i, ackRange := range f.AckRanges { if i == 0 { continue } lastAckRange := f.AckRanges[i-1] if lastAckRange.Smallest <= ackRange.Smallest { return false } if lastAckRange.Smallest <= ackRange.Largest+1 { return false } } return true } // LargestAcked is the largest acked packet number func (f *AckFrame) LargestAcked() protocol.PacketNumber { return f.AckRanges[0].Largest } // LowestAcked is the lowest acked packet number func (f *AckFrame) LowestAcked() protocol.PacketNumber { return f.AckRanges[len(f.AckRanges)-1].Smallest } // AcksPacket determines if this ACK frame acks a certain packet number func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool { if 
p < f.LowestAcked() || p > f.LargestAcked() { return false } i := sort.Search(len(f.AckRanges), func(i int) bool { return p >= f.AckRanges[i].Smallest }) // i will always be < len(f.AckRanges), since we checked above that p is not bigger than the largest acked return p <= f.AckRanges[i].Largest } func (f *AckFrame) Reset() { f.DelayTime = 0 f.ECT0 = 0 f.ECT1 = 0 f.ECNCE = 0 for _, r := range f.AckRanges { r.Largest = 0 r.Smallest = 0 } f.AckRanges = f.AckRanges[:0] } func encodeAckDelay(delay time.Duration) uint64 { return uint64(delay.Nanoseconds() / (1000 * (1 << protocol.AckDelayExponent))) } quic-go-0.59.0/internal/wire/ack_frame_test.go000066400000000000000000000461231513066070600212430ustar00rootroot00000000000000package wire import ( "io" "math" "slices" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" "github.com/stretchr/testify/require" ) func TestParseACKWithoutRanges(t *testing.T) { data := encodeVarInt(100) // largest acked data = append(data, encodeVarInt(0)...) // delay data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(10)...) // first ack block var frame AckFrame n, err := parseAckFrame(&frame, data, FrameTypeAck, protocol.AckDelayExponent, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), n) require.Equal(t, protocol.PacketNumber(100), frame.LargestAcked()) require.Equal(t, protocol.PacketNumber(90), frame.LowestAcked()) require.False(t, frame.HasMissingRanges()) } func TestParseACKSinglePacket(t *testing.T) { data := encodeVarInt(55) // largest acked data = append(data, encodeVarInt(0)...) // delay data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(0)...) 
// first ack block var frame AckFrame n, err := parseAckFrame(&frame, data, FrameTypeAck, protocol.AckDelayExponent, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), n) require.Equal(t, protocol.PacketNumber(55), frame.LargestAcked()) require.Equal(t, protocol.PacketNumber(55), frame.LowestAcked()) require.False(t, frame.HasMissingRanges()) } func TestParseACKAllPacketsFrom0ToLargest(t *testing.T) { data := encodeVarInt(20) // largest acked data = append(data, encodeVarInt(0)...) // delay data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(20)...) // first ack block var frame AckFrame n, err := parseAckFrame(&frame, data, FrameTypeAck, protocol.AckDelayExponent, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), n) require.Equal(t, protocol.PacketNumber(20), frame.LargestAcked()) require.Equal(t, protocol.PacketNumber(0), frame.LowestAcked()) require.False(t, frame.HasMissingRanges()) } func TestParseACKRejectFirstBlockLargerThanLargestAcked(t *testing.T) { data := encodeVarInt(20) // largest acked data = append(data, encodeVarInt(0)...) // delay data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(21)...) // first ack block var frame AckFrame _, err := parseAckFrame(&frame, data, FrameTypeAck, protocol.AckDelayExponent, protocol.Version1) require.EqualError(t, err, "invalid first ACK range") } func TestParseACKWithSingleBlock(t *testing.T) { data := encodeVarInt(1000) // largest acked data = append(data, encodeVarInt(0)...) // delay data = append(data, encodeVarInt(1)...) // num blocks data = append(data, encodeVarInt(100)...) // first ack block data = append(data, encodeVarInt(98)...) // gap data = append(data, encodeVarInt(50)...) 
// ack block var frame AckFrame n, err := parseAckFrame(&frame, data, FrameTypeAck, protocol.AckDelayExponent, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), n) require.Equal(t, protocol.PacketNumber(1000), frame.LargestAcked()) require.Equal(t, protocol.PacketNumber(750), frame.LowestAcked()) require.True(t, frame.HasMissingRanges()) require.Equal(t, []AckRange{ {Largest: 1000, Smallest: 900}, {Largest: 800, Smallest: 750}, }, frame.AckRanges) } func TestParseACKWithMultipleBlocks(t *testing.T) { data := encodeVarInt(100) // largest acked data = append(data, encodeVarInt(0)...) // delay data = append(data, encodeVarInt(2)...) // num blocks data = append(data, encodeVarInt(0)...) // first ack block data = append(data, encodeVarInt(0)...) // gap data = append(data, encodeVarInt(0)...) // ack block data = append(data, encodeVarInt(1)...) // gap data = append(data, encodeVarInt(1)...) // ack block var frame AckFrame n, err := parseAckFrame(&frame, data, FrameTypeAck, protocol.AckDelayExponent, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), n) require.Equal(t, protocol.PacketNumber(100), frame.LargestAcked()) require.Equal(t, protocol.PacketNumber(94), frame.LowestAcked()) require.True(t, frame.HasMissingRanges()) require.Equal(t, []AckRange{ {Largest: 100, Smallest: 100}, {Largest: 98, Smallest: 98}, {Largest: 95, Smallest: 94}, }, frame.AckRanges) } func TestParseACKUseAckDelayExponent(t *testing.T) { const delayTime = 1 << 10 * time.Millisecond f := &AckFrame{ AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, DelayTime: delayTime, } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) for i := uint8(0); i < 8; i++ { typ, l, err := quicvarint.Parse(b) require.NoError(t, err) var frame AckFrame n, err := parseAckFrame(&frame, b[l:], FrameType(typ), protocol.AckDelayExponent+i, protocol.Version1) require.NoError(t, err) require.Equal(t, len(b[l:]), n) require.Equal(t, delayTime*(1< len(b) { return nil, 0, 
io.EOF } reasonPhrase := make([]byte, reasonPhraseLen) copy(reasonPhrase, b) f.ReasonPhrase = string(reasonPhrase) return f, startLen - len(b) + int(reasonPhraseLen), nil } // Length of a written frame func (f *ConnectionCloseFrame) Length(protocol.Version) protocol.ByteCount { length := 1 + protocol.ByteCount(quicvarint.Len(f.ErrorCode)+quicvarint.Len(uint64(len(f.ReasonPhrase)))) + protocol.ByteCount(len(f.ReasonPhrase)) if !f.IsApplicationError { length += protocol.ByteCount(quicvarint.Len(f.FrameType)) // for the frame type } return length } func (f *ConnectionCloseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { if f.IsApplicationError { b = append(b, byte(FrameTypeApplicationClose)) } else { b = append(b, byte(FrameTypeConnectionClose)) } b = quicvarint.Append(b, f.ErrorCode) if !f.IsApplicationError { b = quicvarint.Append(b, f.FrameType) } b = quicvarint.Append(b, uint64(len(f.ReasonPhrase))) b = append(b, []byte(f.ReasonPhrase)...) return b, nil } quic-go-0.59.0/internal/wire/connection_close_frame_test.go000066400000000000000000000114131513066070600240230ustar00rootroot00000000000000package wire import ( "io" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestParseConnectionCloseTransportError(t *testing.T) { reason := "No recent network activity." data := encodeVarInt(0x19) data = append(data, encodeVarInt(0x1337)...) // frame type data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length data = append(data, []byte(reason)...) frame, l, err := parseConnectionCloseFrame(data, FrameTypeConnectionClose, protocol.Version1) require.NoError(t, err) require.False(t, frame.IsApplicationError) require.EqualValues(t, 0x19, frame.ErrorCode) require.EqualValues(t, 0x1337, frame.FrameType) require.Equal(t, reason, frame.ReasonPhrase) require.Equal(t, len(data), l) } func TestParseConnectionCloseWithApplicationError(t *testing.T) { reason := "The application messed things up." 
data := encodeVarInt(0xcafe) data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length data = append(data, reason...) frame, l, err := parseConnectionCloseFrame(data, FrameTypeApplicationClose, protocol.Version1) require.NoError(t, err) require.True(t, frame.IsApplicationError) require.EqualValues(t, 0xcafe, frame.ErrorCode) require.Equal(t, reason, frame.ReasonPhrase) require.Equal(t, len(data), l) } func TestParseConnectionCloseLongReasonPhrase(t *testing.T) { data := encodeVarInt(0xcafe) data = append(data, encodeVarInt(0x42)...) // frame type data = append(data, encodeVarInt(0xffff)...) // reason phrase length _, _, err := parseConnectionCloseFrame(data, FrameTypeConnectionClose, protocol.Version1) require.Equal(t, io.EOF, err) } func TestParseConnectionCloseErrorsOnEOFs(t *testing.T) { reason := "No recent network activity." data := encodeVarInt(0x19) data = append(data, encodeVarInt(0x1337)...) // frame type data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length data = append(data, []byte(reason)...) _, l, err := parseConnectionCloseFrame(data, FrameTypeConnectionClose, protocol.Version1) require.Equal(t, len(data), l) require.NoError(t, err) for i := range data { _, _, err = parseConnectionCloseFrame(data[:i], FrameTypeConnectionClose, protocol.Version1) require.Equal(t, io.EOF, err) } } func TestParseConnectionCloseNoReasonPhrase(t *testing.T) { data := encodeVarInt(0xcafe) data = append(data, encodeVarInt(0x42)...) // frame type data = append(data, encodeVarInt(0)...) 
frame, l, err := parseConnectionCloseFrame(data, FrameTypeConnectionClose, protocol.Version1) require.NoError(t, err) require.Empty(t, frame.ReasonPhrase) require.Equal(t, len(data), l) } func TestWriteConnectionCloseNoReasonPhrase(t *testing.T) { frame := &ConnectionCloseFrame{ ErrorCode: 0xbeef, FrameType: 0x12345, } b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{byte(FrameTypeConnectionClose)} expected = append(expected, encodeVarInt(0xbeef)...) expected = append(expected, encodeVarInt(0x12345)...) // frame type expected = append(expected, encodeVarInt(0)...) // reason phrase length require.Equal(t, expected, b) } func TestWriteConnectionCloseWithReasonPhrase(t *testing.T) { frame := &ConnectionCloseFrame{ ErrorCode: 0xdead, ReasonPhrase: "foobar", } b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{byte(FrameTypeConnectionClose)} expected = append(expected, encodeVarInt(0xdead)...) expected = append(expected, encodeVarInt(0)...) // frame type expected = append(expected, encodeVarInt(6)...) // reason phrase length expected = append(expected, []byte("foobar")...) require.Equal(t, expected, b) } func TestWriteConnectionCloseWithApplicationError(t *testing.T) { frame := &ConnectionCloseFrame{ IsApplicationError: true, ErrorCode: 0xdead, ReasonPhrase: "foobar", } b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{byte(FrameTypeApplicationClose)} expected = append(expected, encodeVarInt(0xdead)...) expected = append(expected, encodeVarInt(6)...) // reason phrase length expected = append(expected, []byte("foobar")...) 
require.Equal(t, expected, b) } func TestWriteConnectionCloseTransportError(t *testing.T) { f := &ConnectionCloseFrame{ ErrorCode: 0xcafe, FrameType: 0xdeadbeef, ReasonPhrase: "foobar", } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) require.Len(t, b, int(f.Length(protocol.Version1))) } func TestWriteConnectionCloseLength(t *testing.T) { f := &ConnectionCloseFrame{ IsApplicationError: true, ErrorCode: 0xcafe, ReasonPhrase: "foobar", } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) require.Len(t, b, int(f.Length(protocol.Version1))) } quic-go-0.59.0/internal/wire/crypto_frame.go000066400000000000000000000051661513066070600207700ustar00rootroot00000000000000package wire import ( "io" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" ) // A CryptoFrame is a CRYPTO frame type CryptoFrame struct { Offset protocol.ByteCount Data []byte } func parseCryptoFrame(b []byte, _ protocol.Version) (*CryptoFrame, int, error) { startLen := len(b) frame := &CryptoFrame{} offset, l, err := quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } b = b[l:] frame.Offset = protocol.ByteCount(offset) dataLen, l, err := quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } b = b[l:] if dataLen > uint64(len(b)) { return nil, 0, io.EOF } if dataLen != 0 { frame.Data = make([]byte, dataLen) copy(frame.Data, b) } return frame, startLen - len(b) + int(dataLen), nil } func (f *CryptoFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, byte(FrameTypeCrypto)) b = quicvarint.Append(b, uint64(f.Offset)) b = quicvarint.Append(b, uint64(len(f.Data))) b = append(b, f.Data...) 
return b, nil } // Length of a written frame func (f *CryptoFrame) Length(_ protocol.Version) protocol.ByteCount { return protocol.ByteCount(1 + quicvarint.Len(uint64(f.Offset)) + quicvarint.Len(uint64(len(f.Data))) + len(f.Data)) } // MaxDataLen returns the maximum data length func (f *CryptoFrame) MaxDataLen(maxSize protocol.ByteCount) protocol.ByteCount { // pretend that the data size will be 1 bytes // if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards headerLen := protocol.ByteCount(1 + quicvarint.Len(uint64(f.Offset)) + 1) if headerLen > maxSize { return 0 } maxDataLen := maxSize - headerLen if quicvarint.Len(uint64(maxDataLen)) != 1 { maxDataLen-- } return maxDataLen } // MaybeSplitOffFrame splits a frame such that it is not bigger than n bytes. // It returns if the frame was actually split. // The frame might not be split if: // * the size is large enough to fit the whole frame // * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil. func (f *CryptoFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.Version) (*CryptoFrame, bool /* was splitting required */) { if f.Length(version) <= maxSize { return nil, false } n := f.MaxDataLen(maxSize) if n == 0 { return nil, true } newLen := protocol.ByteCount(len(f.Data)) - n new := &CryptoFrame{} new.Offset = f.Offset new.Data = make([]byte, newLen) // swap the data slices new.Data, f.Data = f.Data, new.Data copy(f.Data, new.Data[n:]) new.Data = new.Data[:n] f.Offset += n return new, true } quic-go-0.59.0/internal/wire/crypto_frame_test.go000066400000000000000000000073651513066070600220320ustar00rootroot00000000000000package wire import ( "io" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestParseCryptoFrame(t *testing.T) { data := encodeVarInt(0xdecafbad) // offset data = append(data, encodeVarInt(6)...) 
// length data = append(data, []byte("foobar")...) frame, l, err := parseCryptoFrame(data, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.ByteCount(0xdecafbad), frame.Offset) require.Equal(t, []byte("foobar"), frame.Data) require.Equal(t, len(data), l) } func TestParseCryptoFrameErrorsOnEOFs(t *testing.T) { data := encodeVarInt(0xdecafbad) // offset data = append(data, encodeVarInt(6)...) // data length data = append(data, []byte("foobar")...) _, l, err := parseCryptoFrame(data, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), l) for i := range data { _, _, err := parseCryptoFrame(data[:i], protocol.Version1) require.Equal(t, io.EOF, err) } } func TestWriteCryptoFrame(t *testing.T) { f := &CryptoFrame{ Offset: 0x123456, Data: []byte("foobar"), } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{byte(FrameTypeCrypto)} expected = append(expected, encodeVarInt(0x123456)...) // offset expected = append(expected, encodeVarInt(6)...) // length expected = append(expected, []byte("foobar")...) require.Equal(t, expected, b) require.Equal(t, int(f.Length(protocol.Version1)), len(b)) } func TestCryptoFrameMaxDataLength(t *testing.T) { const maxSize = 3000 data := make([]byte, maxSize) f := &CryptoFrame{ Offset: 0xdeadbeef, } var frameOneByteTooSmallCounter int for i := 1; i < maxSize; i++ { f.Data = nil maxDataLen := f.MaxDataLen(protocol.ByteCount(i)) if maxDataLen == 0 { // 0 means that no valid CRYPTO frame can be written // check that writing a minimal size CRYPTO frame (i.e. 
with 1 byte data) is actually larger than the desired size f.Data = []byte{0} b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) require.Greater(t, len(b), i) continue } f.Data = data[:int(maxDataLen)] b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) // There's *one* pathological case, where a data length of x can be encoded into 1 byte // but a data lengths of x+1 needs 2 bytes // In that case, it's impossible to create a STREAM frame of the desired size if len(b) == i-1 { frameOneByteTooSmallCounter++ continue } require.Equal(t, i, len(b)) } require.Equal(t, 1, frameOneByteTooSmallCounter) } func TestCryptoFrameSplitting(t *testing.T) { f := &CryptoFrame{ Offset: 0x1337, Data: []byte("foobar"), } hdrLen := f.Length(protocol.Version1) - 6 new, needsSplit := f.MaybeSplitOffFrame(hdrLen+3, protocol.Version1) require.True(t, needsSplit) require.Equal(t, []byte("foo"), new.Data) require.Equal(t, protocol.ByteCount(0x1337), new.Offset) require.Equal(t, []byte("bar"), f.Data) require.Equal(t, protocol.ByteCount(0x1337+3), f.Offset) } func TestCryptoFrameNoSplitWhenEnoughSpace(t *testing.T) { f := &CryptoFrame{ Offset: 0x1337, Data: []byte("foobar"), } splitFrame, needsSplit := f.MaybeSplitOffFrame(f.Length(protocol.Version1), protocol.Version1) require.False(t, needsSplit) require.Nil(t, splitFrame) } func TestCryptoFrameNoSplitWhenSizeTooSmall(t *testing.T) { f := &CryptoFrame{ Offset: 0x1337, Data: []byte("foobar"), } length := f.Length(protocol.Version1) - 6 for i := protocol.ByteCount(0); i <= length; i++ { splitFrame, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) require.True(t, needsSplit) require.Nil(t, splitFrame) } splitFrame, needsSplit := f.MaybeSplitOffFrame(length+1, protocol.Version1) require.True(t, needsSplit) require.NotNil(t, splitFrame) } quic-go-0.59.0/internal/wire/data_blocked_frame.go000066400000000000000000000015471513066070600220430ustar00rootroot00000000000000package wire import ( 
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" ) // A DataBlockedFrame is a DATA_BLOCKED frame type DataBlockedFrame struct { MaximumData protocol.ByteCount } func parseDataBlockedFrame(b []byte, _ protocol.Version) (*DataBlockedFrame, int, error) { offset, l, err := quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } return &DataBlockedFrame{MaximumData: protocol.ByteCount(offset)}, l, nil } func (f *DataBlockedFrame) Append(b []byte, version protocol.Version) ([]byte, error) { b = append(b, byte(FrameTypeDataBlocked)) return quicvarint.Append(b, uint64(f.MaximumData)), nil } // Length of a written frame func (f *DataBlockedFrame) Length(version protocol.Version) protocol.ByteCount { return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.MaximumData))) } quic-go-0.59.0/internal/wire/data_blocked_frame_test.go000066400000000000000000000022561513066070600231000ustar00rootroot00000000000000package wire import ( "io" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" "github.com/stretchr/testify/require" ) func TestParseDataBlocked(t *testing.T) { data := encodeVarInt(0x12345678) frame, l, err := parseDataBlockedFrame(data, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.ByteCount(0x12345678), frame.MaximumData) require.Equal(t, len(data), l) } func TestParseDataBlockedErrorsOnEOFs(t *testing.T) { data := encodeVarInt(0x12345678) _, l, err := parseDataBlockedFrame(data, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), l) for i := range data { _, _, err := parseDataBlockedFrame(data[:i], protocol.Version1) require.Equal(t, io.EOF, err) } } func TestWriteDataBlocked(t *testing.T) { frame := DataBlockedFrame{MaximumData: 0xdeadbeef} b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{byte(FrameTypeDataBlocked)} expected = append(expected, encodeVarInt(0xdeadbeef)...) 
require.Equal(t, expected, b) require.Equal(t, protocol.ByteCount(1+quicvarint.Len(uint64(frame.MaximumData))), frame.Length(protocol.Version1)) } quic-go-0.59.0/internal/wire/datagram_frame.go000066400000000000000000000043301513066070600212200ustar00rootroot00000000000000package wire import ( "io" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" ) // MaxDatagramSize is the maximum size of a DATAGRAM frame (RFC 9221). // By setting it to a large value, we allow all datagrams that fit into a QUIC packet. // The value is chosen such that it can still be encoded as a 2 byte varint. // This is a var and not a const so it can be set in tests. var MaxDatagramSize protocol.ByteCount = 16383 // A DatagramFrame is a DATAGRAM frame type DatagramFrame struct { DataLenPresent bool Data []byte } func parseDatagramFrame(b []byte, typ FrameType, _ protocol.Version) (*DatagramFrame, int, error) { startLen := len(b) f := &DatagramFrame{} f.DataLenPresent = uint64(typ)&0x1 > 0 var length uint64 if f.DataLenPresent { var err error var l int length, l, err = quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } b = b[l:] if length > uint64(len(b)) { return nil, 0, io.EOF } } else { length = uint64(len(b)) } f.Data = make([]byte, length) copy(f.Data, b) return f, startLen - len(b) + int(length), nil } func (f *DatagramFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { typ := uint8(0x30) if f.DataLenPresent { typ ^= 0b1 } b = append(b, typ) if f.DataLenPresent { b = quicvarint.Append(b, uint64(len(f.Data))) } b = append(b, f.Data...) 
return b, nil } // MaxDataLen returns the maximum data length func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Version) protocol.ByteCount { headerLen := protocol.ByteCount(1) if f.DataLenPresent { // pretend that the data size will be 1 bytes // if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards headerLen++ } if headerLen > maxSize { return 0 } maxDataLen := maxSize - headerLen if f.DataLenPresent && quicvarint.Len(uint64(maxDataLen)) != 1 { maxDataLen-- } return maxDataLen } // Length of a written frame func (f *DatagramFrame) Length(_ protocol.Version) protocol.ByteCount { length := 1 + protocol.ByteCount(len(f.Data)) if f.DataLenPresent { length += protocol.ByteCount(quicvarint.Len(uint64(len(f.Data)))) } return length } quic-go-0.59.0/internal/wire/datagram_frame_test.go000066400000000000000000000100041513066070600222520ustar00rootroot00000000000000package wire import ( "io" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestParseDatagramFrameWithLength(t *testing.T) { data := encodeVarInt(0x6) // length data = append(data, []byte("foobar")...) frame, l, err := parseDatagramFrame(data, 0x30^0x1, protocol.Version1) require.NoError(t, err) require.Equal(t, []byte("foobar"), frame.Data) require.True(t, frame.DataLenPresent) require.Equal(t, len(data), l) } func TestParseDatagramFrameWithoutLength(t *testing.T) { data := []byte("Lorem ipsum dolor sit amet") frame, l, err := parseDatagramFrame(data, 0x30, protocol.Version1) require.NoError(t, err) require.Equal(t, []byte("Lorem ipsum dolor sit amet"), frame.Data) require.False(t, frame.DataLenPresent) require.Equal(t, len(data), l) } func TestParseDatagramFrameErrorsOnLengthLongerThanFrame(t *testing.T) { data := encodeVarInt(0x6) // length data = append(data, []byte("fooba")...) 
_, _, err := parseDatagramFrame(data, 0x30^0x1, protocol.Version1) require.Equal(t, io.EOF, err) } func TestParseDatagramFrameErrorsOnEOFs(t *testing.T) { const typ = 0x30 ^ 0x1 data := encodeVarInt(6) // length data = append(data, []byte("foobar")...) _, l, err := parseDatagramFrame(data, typ, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), l) for i := range data { _, _, err = parseDatagramFrame(data[0:i], typ, protocol.Version1) require.Equal(t, io.EOF, err) } } func TestWriteDatagramFrameWithLength(t *testing.T) { f := &DatagramFrame{ DataLenPresent: true, Data: []byte("foobar"), } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{0x30 ^ 0x1} expected = append(expected, encodeVarInt(0x6)...) expected = append(expected, []byte("foobar")...) require.Equal(t, expected, b) require.Len(t, b, int(f.Length(protocol.Version1))) } func TestWriteDatagramFrameWithoutLength(t *testing.T) { f := &DatagramFrame{Data: []byte("Lorem ipsum")} b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{0x30} expected = append(expected, []byte("Lorem ipsum")...) require.Equal(t, expected, b) require.Len(t, b, int(f.Length(protocol.Version1))) } func TestMaxDatagramLenWithoutDataLenPresent(t *testing.T) { const maxSize = 3000 data := make([]byte, maxSize) f := &DatagramFrame{} for i := 1; i < 3000; i++ { f.Data = nil maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) if maxDataLen == 0 { // 0 means that no valid DATAGRAM frame can be written // check that writing a minimal size DATAGRAM frame (i.e. 
with 1 byte data) is actually larger than the desired size f.Data = []byte{0} b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) require.Greater(t, len(b), i) continue } f.Data = data[:int(maxDataLen)] b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) require.Len(t, b, i) } } func TestMaxDatagramLenWithDataLenPresent(t *testing.T) { const maxSize = 3000 data := make([]byte, maxSize) f := &DatagramFrame{DataLenPresent: true} var frameOneByteTooSmallCounter int for i := 1; i < 3000; i++ { f.Data = nil maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) if maxDataLen == 0 { // 0 means that no valid DATAGRAM frame can be written // check that writing a minimal size DATAGRAM frame (i.e. with 1 byte data) is actually larger than the desired size f.Data = []byte{0} b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) require.Greater(t, len(b), i) continue } f.Data = data[:int(maxDataLen)] b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) // There's *one* pathological case, where a data length of x can be encoded into 1 byte // but a data lengths of x+1 needs 2 bytes // In that case, it's impossible to create a DATAGRAM frame of the desired size if len(b) == i-1 { frameOneByteTooSmallCounter++ continue } require.Len(t, b, i) } require.Equal(t, 1, frameOneByteTooSmallCounter) } quic-go-0.59.0/internal/wire/extended_header.go000066400000000000000000000123621513066070600214020ustar00rootroot00000000000000package wire import ( "encoding/binary" "errors" "fmt" "io" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/quicvarint" ) // ErrInvalidReservedBits is returned when the reserved bits are incorrect. // When this error is returned, parsing continues, and an ExtendedHeader is returned. // This is necessary because we need to decrypt the packet in that case, // in order to avoid a timing side-channel. 
var ErrInvalidReservedBits = errors.New("invalid reserved bits") // ExtendedHeader is the header of a QUIC packet. type ExtendedHeader struct { Header typeByte byte KeyPhase protocol.KeyPhaseBit PacketNumberLen protocol.PacketNumberLen PacketNumber protocol.PacketNumber parsedLen protocol.ByteCount } func (h *ExtendedHeader) parse(data []byte) (bool /* reserved bits valid */, error) { // read the (now unencrypted) first byte h.typeByte = data[0] h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1 if protocol.ByteCount(len(data)) < h.Header.ParsedLen()+protocol.ByteCount(h.PacketNumberLen) { return false, io.EOF } pn, err := readPacketNumber(data[h.Header.ParsedLen():], h.PacketNumberLen) if err != nil { return true, nil } h.PacketNumber = pn reservedBitsValid := h.typeByte&0xc == 0 h.parsedLen = h.Header.ParsedLen() + protocol.ByteCount(h.PacketNumberLen) return reservedBitsValid, err } // Append appends the Header. func (h *ExtendedHeader) Append(b []byte, v protocol.Version) ([]byte, error) { if h.DestConnectionID.Len() > protocol.MaxConnIDLen { return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len()) } if h.SrcConnectionID.Len() > protocol.MaxConnIDLen { return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len()) } var packetType uint8 if v == protocol.Version2 { switch h.Type { case protocol.PacketTypeInitial: packetType = 0b01 case protocol.PacketType0RTT: packetType = 0b10 case protocol.PacketTypeHandshake: packetType = 0b11 case protocol.PacketTypeRetry: packetType = 0b00 } } else { switch h.Type { case protocol.PacketTypeInitial: packetType = 0b00 case protocol.PacketType0RTT: packetType = 0b01 case protocol.PacketTypeHandshake: packetType = 0b10 case protocol.PacketTypeRetry: packetType = 0b11 } } firstByte := 0xc0 | packetType<<4 if h.Type != protocol.PacketTypeRetry { // Retry packets don't have a packet number firstByte |= uint8(h.PacketNumberLen - 1) } b = append(b, firstByte) 
b = append(b, make([]byte, 4)...) binary.BigEndian.PutUint32(b[len(b)-4:], uint32(h.Version)) b = append(b, uint8(h.DestConnectionID.Len())) b = append(b, h.DestConnectionID.Bytes()...) b = append(b, uint8(h.SrcConnectionID.Len())) b = append(b, h.SrcConnectionID.Bytes()...) //nolint:exhaustive switch h.Type { case protocol.PacketTypeRetry: b = append(b, h.Token...) return b, nil case protocol.PacketTypeInitial: b = quicvarint.Append(b, uint64(len(h.Token))) b = append(b, h.Token...) } b = quicvarint.AppendWithLen(b, uint64(h.Length), 2) return appendPacketNumber(b, h.PacketNumber, h.PacketNumberLen) } // ParsedLen returns the number of bytes that were consumed when parsing the header func (h *ExtendedHeader) ParsedLen() protocol.ByteCount { return h.parsedLen } // GetLength determines the length of the Header. func (h *ExtendedHeader) GetLength(_ protocol.Version) protocol.ByteCount { length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */ if h.Type == protocol.PacketTypeInitial { length += protocol.ByteCount(quicvarint.Len(uint64(len(h.Token))) + len(h.Token)) } return length } // Log logs the Header func (h *ExtendedHeader) Log(logger utils.Logger) { var token string if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry { if len(h.Token) == 0 { token = "Token: (empty), " } else { token = fmt.Sprintf("Token: %#x, ", h.Token) } if h.Type == protocol.PacketTypeRetry { logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.Version) return } } logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, 
h.PacketNumber, h.PacketNumberLen, h.Length, h.Version) } func appendPacketNumber(b []byte, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) ([]byte, error) { switch pnLen { case protocol.PacketNumberLen1: b = append(b, uint8(pn)) case protocol.PacketNumberLen2: buf := make([]byte, 2) binary.BigEndian.PutUint16(buf, uint16(pn)) b = append(b, buf...) case protocol.PacketNumberLen3: buf := make([]byte, 4) binary.BigEndian.PutUint32(buf, uint32(pn)) b = append(b, buf[1:]...) case protocol.PacketNumberLen4: buf := make([]byte, 4) binary.BigEndian.PutUint32(buf, uint32(pn)) b = append(b, buf...) default: return nil, fmt.Errorf("invalid packet number length: %d", pnLen) } return b, nil } quic-go-0.59.0/internal/wire/extended_header_test.go000066400000000000000000000223171513066070600224420ustar00rootroot00000000000000package wire import ( "bytes" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" "github.com/stretchr/testify/require" ) func TestWritesLongHeaderVersion1(t *testing.T) { header := &ExtendedHeader{ Header: Header{ Type: protocol.PacketTypeHandshake, DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}), SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37}), Version: 0x1020304, Length: 1234, }, PacketNumber: 0xdecaf, PacketNumberLen: protocol.PacketNumberLen3, } b, err := header.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{ 0xc0 | 0x2<<4 | 0x2, 0x1, 0x2, 0x3, 0x4, // version number 0x6, // dest connection ID length 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // dest connection ID 0x8, // src connection ID length 0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37, // source connection ID } expected = append(expected, encodeVarInt(1234)...) // length expected = append(expected, []byte{0xd, 0xec, 0xaf}...) 
// packet number require.Equal(t, expected, b) require.Equal(t, protocol.ByteCount(len(b)), header.GetLength(protocol.Version1)) } func TestWritesHandshakePacketVersion2(t *testing.T) { header := &ExtendedHeader{ Header: Header{ Version: protocol.Version2, Type: protocol.PacketTypeHandshake, }, PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen4, } b, err := header.Append(nil, protocol.Version2) require.NoError(t, err) require.Equal(t, byte(0b11), b[0]>>4&0b11) require.Equal(t, protocol.ByteCount(len(b)), header.GetLength(protocol.Version2)) } func TestWritesHeaderWith20ByteConnectionID(t *testing.T) { srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) header := &ExtendedHeader{ Header: Header{ SrcConnectionID: srcConnID, DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}), // connection IDs must be at most 20 bytes long Version: 0x1020304, Type: 0x5, }, PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen4, } b, err := header.Append(nil, protocol.Version1) require.NoError(t, err) require.Contains(t, string(b), string([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20})) require.Equal(t, protocol.ByteCount(len(b)), header.GetLength(protocol.Version1)) } func TestWritesInitialContainingToken(t *testing.T) { token := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") header := &ExtendedHeader{ Header: Header{ Version: 0x1020304, Type: protocol.PacketTypeInitial, Token: token, }, PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen4, } b, err := header.Append(nil, protocol.Version1) require.NoError(t, err) require.Equal(t, byte(0), b[0]>>4&0b11) expectedSubstring := append(encodeVarInt(uint64(len(token))), token...) 
require.Contains(t, string(b), string(expectedSubstring)) require.Equal(t, protocol.ByteCount(len(b)), header.GetLength(protocol.Version1)) } func TestUses2ByteEncodingForLengthOnInitialPackets(t *testing.T) { header := &ExtendedHeader{ Header: Header{ Version: 0x1020304, Type: protocol.PacketTypeInitial, Length: 37, }, PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen4, } b, err := header.Append(nil, protocol.Version1) require.NoError(t, err) lengthEncoded := quicvarint.AppendWithLen(nil, 37, 2) require.Equal(t, lengthEncoded, b[len(b)-6:len(b)-4]) require.Equal(t, protocol.ByteCount(len(b)), header.GetLength(protocol.Version1)) } func TestWritesInitialPacketVersion2(t *testing.T) { header := &ExtendedHeader{ Header: Header{ Version: protocol.Version2, Type: protocol.PacketTypeInitial, }, PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen4, } b, err := header.Append(nil, protocol.Version2) require.NoError(t, err) require.Equal(t, byte(0b01), b[0]>>4&0b11) require.Equal(t, protocol.ByteCount(len(b)), header.GetLength(protocol.Version2)) } func TestWrites0RTTPacketVersion2(t *testing.T) { header := &ExtendedHeader{ Header: Header{ Version: protocol.Version2, Type: protocol.PacketType0RTT, }, PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen4, } b, err := header.Append(nil, protocol.Version2) require.NoError(t, err) require.Equal(t, byte(0b10), b[0]>>4&0b11) require.Equal(t, protocol.ByteCount(len(b)), header.GetLength(protocol.Version2)) } func TestWritesRetryPacket(t *testing.T) { token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.") for _, version := range []protocol.Version{protocol.Version1, protocol.Version2} { t.Run(version.String(), func(t *testing.T) { header := &ExtendedHeader{Header: Header{ Version: version, Type: protocol.PacketTypeRetry, Token: token, }} b, err := header.Append(nil, version) require.NoError(t, err) var expected 
[]byte switch version { case protocol.Version1: expected = append(expected, 0xc0|0b11<<4) case protocol.Version2: expected = append(expected, 0xc0) } expected = appendVersion(expected, version) expected = append(expected, 0x0) // dest connection ID length expected = append(expected, 0x0) // src connection ID length expected = append(expected, token...) require.Equal(t, expected, b) }) } } func TestLogsLongHeaders(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) (&ExtendedHeader{ Header: Header{ DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}), SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad, 0x013, 0x37, 0x13, 0x37}), Type: protocol.PacketTypeHandshake, Length: 54321, Version: 0xfeed, }, PacketNumber: 1337, PacketNumberLen: protocol.PacketNumberLen2, }).Log(logger) require.Contains(t, buf.String(), "Long Header{Type: Handshake, DestConnectionID: deadbeefcafe1337, SrcConnectionID: decafbad13371337, PacketNumber: 1337, PacketNumberLen: 2, Length: 54321, Version: 0xfeed}") } func TestLogsInitialPacketsWithToken(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) (&ExtendedHeader{ Header: Header{ DestConnectionID: protocol.ParseConnectionID([]byte{0xca, 0xfe, 0x13, 0x37}), SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), Type: protocol.PacketTypeInitial, Token: []byte{0xde, 0xad, 0xbe, 0xef}, Length: 100, Version: 0xfeed, }, PacketNumber: 42, PacketNumberLen: protocol.PacketNumberLen2, }).Log(logger) require.Contains(t, buf.String(), "Long Header{Type: Initial, DestConnectionID: cafe1337, SrcConnectionID: decafbad, Token: 0xdeadbeef, PacketNumber: 42, PacketNumberLen: 2, Length: 100, Version: 0xfeed}") } func TestLogsInitialPacketsWithoutToken(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) (&ExtendedHeader{ Header: Header{ DestConnectionID: protocol.ParseConnectionID([]byte{0xca, 0xfe, 0x13, 0x37}), 
SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), Type: protocol.PacketTypeInitial, Length: 100, Version: 0xfeed, }, PacketNumber: 42, PacketNumberLen: protocol.PacketNumberLen2, }).Log(logger) require.Contains(t, buf.String(), "Long Header{Type: Initial, DestConnectionID: cafe1337, SrcConnectionID: decafbad, Token: (empty), PacketNumber: 42, PacketNumberLen: 2, Length: 100, Version: 0xfeed}") } func TestLogsRetryPacketsWithToken(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) (&ExtendedHeader{ Header: Header{ DestConnectionID: protocol.ParseConnectionID([]byte{0xca, 0xfe, 0x13, 0x37}), SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), Type: protocol.PacketTypeRetry, Token: []byte{0x12, 0x34, 0x56}, Version: 0xfeed, }, }).Log(logger) require.Contains(t, buf.String(), "Long Header{Type: Retry, DestConnectionID: cafe1337, SrcConnectionID: decafbad, Token: 0x123456, Version: 0xfeed}") } func BenchmarkParseExtendedHeader(b *testing.B) { b.ReportAllocs() data, err := (&ExtendedHeader{ Header: Header{ Type: protocol.PacketTypeHandshake, DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}), SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37}), Version: protocol.Version1, Length: 1234, }, PacketNumber: 0xdecaf, PacketNumberLen: protocol.PacketNumberLen3, }).Append(nil, protocol.Version1) if err != nil { b.Fatal(err) } data = append(data, make([]byte, 1231)...) 
for b.Loop() { hdr, _, _, err := ParsePacket(data) if err != nil { b.Fatal(err) } if _, err := hdr.ParseExtended(data); err != nil { b.Fatal(err) } } } quic-go-0.59.0/internal/wire/frame.go000066400000000000000000000015521513066070600173630ustar00rootroot00000000000000package wire import ( "github.com/quic-go/quic-go/internal/protocol" ) // A Frame in QUIC type Frame interface { Append(b []byte, version protocol.Version) ([]byte, error) Length(version protocol.Version) protocol.ByteCount } // IsProbingFrame returns true if the frame is a probing frame. // See section 9.1 of RFC 9000. func IsProbingFrame(f Frame) bool { switch f.(type) { case *PathChallengeFrame, *PathResponseFrame, *NewConnectionIDFrame: return true } return false } // IsProbingFrameType returns true if the FrameType is a probing frame. // See section 9.1 of RFC 9000. func IsProbingFrameType(f FrameType) bool { //nolint:exhaustive // PATH_CHALLENGE, PATH_RESPONSE and NEW_CONNECTION_ID are the only probing frames switch f { case FrameTypePathChallenge, FrameTypePathResponse, FrameTypeNewConnectionID: return true default: return false } } quic-go-0.59.0/internal/wire/frame_parser.go000066400000000000000000000137431513066070600207440ustar00rootroot00000000000000package wire import ( "errors" "fmt" "io" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/quicvarint" ) var errUnknownFrameType = errors.New("unknown frame type") // The FrameParser parses QUIC frames, one by one. type FrameParser struct { ackDelayExponent uint8 supportsDatagrams bool supportsResetStreamAt bool supportsAckFrequency bool // To avoid allocating when parsing, keep a single ACK frame struct. // It is used over and over again. ackFrame *AckFrame } // NewFrameParser creates a new frame parser. 
func NewFrameParser(supportsDatagrams, supportsResetStreamAt, supportsAckFrequency bool) *FrameParser { return &FrameParser{ supportsDatagrams: supportsDatagrams, supportsResetStreamAt: supportsResetStreamAt, supportsAckFrequency: supportsAckFrequency, ackFrame: &AckFrame{}, } } // ParseType parses the frame type of the next frame. // It skips over PADDING frames. func (p *FrameParser) ParseType(b []byte, encLevel protocol.EncryptionLevel) (FrameType, int, error) { var parsed int for len(b) != 0 { typ, l, err := quicvarint.Parse(b) parsed += l if err != nil { return 0, parsed, &qerr.TransportError{ ErrorCode: qerr.FrameEncodingError, ErrorMessage: err.Error(), } } b = b[l:] if typ == 0x0 { // skip PADDING frames continue } ft := FrameType(typ) valid := ft.isValidRFC9000() || (p.supportsDatagrams && ft.IsDatagramFrameType()) || (p.supportsResetStreamAt && ft == FrameTypeResetStreamAt) || (p.supportsAckFrequency && (ft == FrameTypeAckFrequency || ft == FrameTypeImmediateAck)) if !valid { return 0, parsed, &qerr.TransportError{ ErrorCode: qerr.FrameEncodingError, FrameType: typ, ErrorMessage: errUnknownFrameType.Error(), } } if !ft.isAllowedAtEncLevel(encLevel) { return 0, parsed, &qerr.TransportError{ ErrorCode: qerr.FrameEncodingError, FrameType: typ, ErrorMessage: fmt.Sprintf("%d not allowed at encryption level %s", ft, encLevel), } } return ft, parsed, nil } return 0, parsed, io.EOF } func (p *FrameParser) ParseStreamFrame(frameType FrameType, data []byte, v protocol.Version) (*StreamFrame, int, error) { frame, n, err := ParseStreamFrame(data, frameType, v) if err != nil { return nil, n, &qerr.TransportError{ ErrorCode: qerr.FrameEncodingError, FrameType: uint64(frameType), ErrorMessage: err.Error(), } } return frame, n, nil } func (p *FrameParser) ParseAckFrame(frameType FrameType, data []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (*AckFrame, int, error) { ackDelayExponent := p.ackDelayExponent if encLevel != protocol.Encryption1RTT { 
ackDelayExponent = protocol.DefaultAckDelayExponent } p.ackFrame.Reset() l, err := parseAckFrame(p.ackFrame, data, frameType, ackDelayExponent, v) if err != nil { return nil, l, &qerr.TransportError{ ErrorCode: qerr.FrameEncodingError, FrameType: uint64(frameType), ErrorMessage: err.Error(), } } return p.ackFrame, l, nil } func (p *FrameParser) ParseDatagramFrame(frameType FrameType, data []byte, v protocol.Version) (*DatagramFrame, int, error) { f, l, err := parseDatagramFrame(data, frameType, v) if err != nil { return nil, 0, &qerr.TransportError{ ErrorCode: qerr.FrameEncodingError, FrameType: uint64(frameType), ErrorMessage: err.Error(), } } return f, l, nil } // ParseLessCommonFrame parses everything except STREAM, ACK or DATAGRAM. // These cases should be handled separately for performance reasons. func (p *FrameParser) ParseLessCommonFrame(frameType FrameType, data []byte, v protocol.Version) (Frame, int, error) { var frame Frame var l int var err error //nolint:exhaustive // Common frames should already be handled. 
switch frameType { case FrameTypePing: frame = &PingFrame{} case FrameTypeResetStream: frame, l, err = parseResetStreamFrame(data, false, v) case FrameTypeStopSending: frame, l, err = parseStopSendingFrame(data, v) case FrameTypeCrypto: frame, l, err = parseCryptoFrame(data, v) case FrameTypeNewToken: frame, l, err = parseNewTokenFrame(data, v) case FrameTypeMaxData: frame, l, err = parseMaxDataFrame(data, v) case FrameTypeMaxStreamData: frame, l, err = parseMaxStreamDataFrame(data, v) case FrameTypeBidiMaxStreams, FrameTypeUniMaxStreams: frame, l, err = parseMaxStreamsFrame(data, frameType, v) case FrameTypeDataBlocked: frame, l, err = parseDataBlockedFrame(data, v) case FrameTypeStreamDataBlocked: frame, l, err = parseStreamDataBlockedFrame(data, v) case FrameTypeBidiStreamBlocked, FrameTypeUniStreamBlocked: frame, l, err = parseStreamsBlockedFrame(data, frameType, v) case FrameTypeNewConnectionID: frame, l, err = parseNewConnectionIDFrame(data, v) case FrameTypeRetireConnectionID: frame, l, err = parseRetireConnectionIDFrame(data, v) case FrameTypePathChallenge: frame, l, err = parsePathChallengeFrame(data, v) case FrameTypePathResponse: frame, l, err = parsePathResponseFrame(data, v) case FrameTypeConnectionClose, FrameTypeApplicationClose: frame, l, err = parseConnectionCloseFrame(data, frameType, v) case FrameTypeHandshakeDone: frame = &HandshakeDoneFrame{} case FrameTypeResetStreamAt: frame, l, err = parseResetStreamFrame(data, true, v) case FrameTypeAckFrequency: frame, l, err = parseAckFrequencyFrame(data, v) case FrameTypeImmediateAck: frame = &ImmediateAckFrame{} default: err = errUnknownFrameType } if err != nil { return frame, l, &qerr.TransportError{ ErrorCode: qerr.FrameEncodingError, FrameType: uint64(frameType), ErrorMessage: err.Error(), } } return frame, l, err } // SetAckDelayExponent sets the acknowledgment delay exponent (sent in the transport parameters). // This value is used to scale the ACK Delay field in the ACK frame. 
func (p *FrameParser) SetAckDelayExponent(exp uint8) { p.ackDelayExponent = exp } func replaceUnexpectedEOF(e error) error { if e == io.ErrUnexpectedEOF { return io.EOF } return e } quic-go-0.59.0/internal/wire/frame_parser_test.go000066400000000000000000000606071513066070600220040ustar00rootroot00000000000000package wire import ( "bytes" "crypto/rand" "fmt" "io" "slices" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/quicvarint" "github.com/stretchr/testify/require" ) func TestFrameTypeParsingReturnsNilWhenNothingToRead(t *testing.T) { parser := NewFrameParser(true, true, true) frameType, l, err := parser.ParseType(nil, protocol.Encryption1RTT) require.Equal(t, io.EOF, err) require.Zero(t, frameType) require.Zero(t, l) } func TestParseLessCommonFrameReturnsEOFWhenNothingToRead(t *testing.T) { parser := NewFrameParser(true, true, true) l, f, err := parser.ParseLessCommonFrame(FrameTypeMaxStreamData, nil, protocol.Version1) require.IsType(t, &qerr.TransportError{}, err) require.Zero(t, l) require.Zero(t, f) } func TestFrameParsingSkipsPaddingFrames(t *testing.T) { parser := NewFrameParser(true, true, true) b := []byte{0, 0} // 2 PADDING frames b, err := (&PingFrame{}).Append(b, protocol.Version1) require.NoError(t, err) frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT) require.NoError(t, err) require.Equal(t, 3, l) require.Equal(t, FrameTypePing, frameType) frame, l, err := parser.ParseLessCommonFrame(frameType, b[1:], protocol.Version1) require.NoError(t, err) require.Zero(t, l) require.IsType(t, &PingFrame{}, frame) } func TestFrameParsingHandlesPaddingAtEnd(t *testing.T) { parser := NewFrameParser(true, true, true) b := []byte{0, 0, 0} _, l, err := parser.ParseType(b, protocol.Encryption1RTT) require.Equal(t, io.EOF, err) require.Equal(t, 3, l) } func TestFrameParsingParsesSingleFrame(t *testing.T) { parser := NewFrameParser(true, true, true) var b []byte for 
range 10 { var err error b, err = (&PingFrame{}).Append(b, protocol.Version1) require.NoError(t, err) } frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT) require.NoError(t, err) require.Equal(t, FrameTypePing, frameType) require.Equal(t, 1, l) frame, l, err := parser.ParseLessCommonFrame(frameType, b, protocol.Version1) require.NoError(t, err) require.Zero(t, l) require.IsType(t, &PingFrame{}, frame) } func TestFrameParserACK(t *testing.T) { parser := NewFrameParser(true, true, true) f := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 0x13}}} b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT) require.NoError(t, err) require.Equal(t, FrameTypeAck, frameType) require.Equal(t, 1, l) frame, l, err := parser.ParseAckFrame(frameType, b[l:], protocol.Encryption1RTT, protocol.Version1) require.NoError(t, err) require.NotNil(t, frame) require.Equal(t, protocol.PacketNumber(0x13), frame.LargestAcked()) require.Equal(t, len(b)-1, l) } func TestFrameParserAckDelay(t *testing.T) { t.Run("1-RTT", func(t *testing.T) { testFrameParserAckDelay(t, protocol.Encryption1RTT) }) t.Run("Handshake", func(t *testing.T) { testFrameParserAckDelay(t, protocol.EncryptionHandshake) }) } func testFrameParserAckDelay(t *testing.T, encLevel protocol.EncryptionLevel) { parser := NewFrameParser(true, true, true) parser.SetAckDelayExponent(protocol.AckDelayExponent + 2) f := &AckFrame{ AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, DelayTime: time.Second, } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) frameType, l, err := parser.ParseType(b, encLevel) require.NoError(t, err) require.Equal(t, FrameTypeAck, frameType) require.Equal(t, 1, l) frame, l, err := parser.ParseAckFrame(frameType, b[l:], encLevel, protocol.Version1) require.NoError(t, err) require.Equal(t, len(b)-1, l) if encLevel == protocol.Encryption1RTT { require.Equal(t, 4*time.Second, frame.DelayTime) } else { 
require.Equal(t, time.Second, frame.DelayTime) } } func checkFrameUnsupported(t *testing.T, err error, expectedFrameType uint64) { t.Helper() require.ErrorContains(t, err, errUnknownFrameType.Error()) var transportErr *qerr.TransportError require.ErrorAs(t, err, &transportErr) require.Equal(t, qerr.FrameEncodingError, transportErr.ErrorCode) require.Equal(t, expectedFrameType, transportErr.FrameType) require.Equal(t, "unknown frame type", transportErr.ErrorMessage) } func TestFrameParserStreamFrames(t *testing.T) { parser := NewFrameParser(true, true, true) f := &StreamFrame{ StreamID: 0x42, Offset: 0x1337, Fin: true, Data: []byte("foobar"), } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT) require.NoError(t, err) require.Equal(t, FrameType(0xd), frameType) require.True(t, frameType.IsStreamFrameType()) require.Equal(t, 1, l) // ParseLessCommonFrame should not handle Stream Frames frame, l, err := parser.ParseLessCommonFrame(frameType, b[l:], protocol.Version1) checkFrameUnsupported(t, err, 0xd) require.Nil(t, frame) require.Zero(t, l) } func TestParseStreamFrameWrapsError(t *testing.T) { parser := NewFrameParser(true, true, true) f := &StreamFrame{ StreamID: 0x1234, Offset: 0x1000, Data: []byte("hello world"), DataLenPresent: true, } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) // Corrupt the buffer to trigger a parse error b = b[:len(b)-2] // Remove last 2 bytes to cause an EOF frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT) require.NoError(t, err) frame, n, err := parser.ParseStreamFrame(frameType, b[l:], protocol.Version1) require.Nil(t, frame) require.Zero(t, n) var transportErr *qerr.TransportError require.ErrorAs(t, err, &transportErr) require.Equal(t, qerr.FrameEncodingError, transportErr.ErrorCode) require.Equal(t, uint64(frameType), transportErr.FrameType) require.Contains(t, transportErr.Error(), "EOF") } func 
TestParseStreamFrameSuccess(t *testing.T) { parser := NewFrameParser(true, true, true) original := &StreamFrame{ StreamID: 0x1234, Offset: 0x1000, Fin: true, Data: []byte("hello world"), DataLenPresent: true, } b, err := original.Append(nil, protocol.Version1) require.NoError(t, err) frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT) require.NoError(t, err) require.True(t, frameType.IsStreamFrameType()) require.Equal(t, FrameType(0x0f), frameType) // STREAM | OFF | LEN | FIN parsed, n, err := parser.ParseStreamFrame(frameType, b[l:], protocol.Version1) require.NoError(t, err) require.NotNil(t, parsed) require.Equal(t, len(b)-l, n) require.Equal(t, original.StreamID, parsed.StreamID) require.Equal(t, original.Offset, parsed.Offset) require.Equal(t, original.Fin, parsed.Fin) require.Equal(t, original.DataLenPresent, parsed.DataLenPresent) require.Equal(t, original.Data, parsed.Data) } func TestFrameParserFrames(t *testing.T) { tests := []struct { name string frameType FrameType frame Frame }{ { name: "MAX_DATA", frameType: FrameTypeMaxData, frame: &MaxDataFrame{MaximumData: 0xcafe}, }, { name: "MAX_STREAM_DATA", frameType: FrameTypeMaxStreamData, frame: &MaxStreamDataFrame{StreamID: 0xdeadbeef, MaximumStreamData: 0xdecafbad}, }, { name: "RESET_STREAM", frameType: FrameTypeResetStream, frame: &ResetStreamFrame{ StreamID: 0xdeadbeef, FinalSize: 0xdecafbad1234, ErrorCode: 0x1337, }, }, { name: "STOP_SENDING", frameType: FrameTypeStopSending, frame: &StopSendingFrame{StreamID: 0x42}, }, { name: "CRYPTO", frameType: FrameTypeCrypto, frame: &CryptoFrame{Offset: 0x1337, Data: []byte("lorem ipsum")}, }, { name: "NEW_TOKEN", frameType: FrameTypeNewToken, frame: &NewTokenFrame{Token: []byte("foobar")}, }, { name: "MAX_STREAMS", frameType: FrameTypeBidiMaxStreams, frame: &MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 0x1337}, }, { name: "DATA_BLOCKED", frameType: FrameTypeDataBlocked, frame: &DataBlockedFrame{MaximumData: 0x1234}, }, { name: 
"STREAM_DATA_BLOCKED", frameType: FrameTypeStreamDataBlocked, frame: &StreamDataBlockedFrame{StreamID: 0xdeadbeef, MaximumStreamData: 0xdead}, }, { name: "STREAMS_BLOCKED", frameType: FrameTypeBidiStreamBlocked, frame: &StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 0x1234567}, }, { name: "NEW_CONNECTION_ID", frameType: FrameTypeNewConnectionID, frame: &NewConnectionIDFrame{ SequenceNumber: 0x1337, ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, }, }, { name: "RETIRE_CONNECTION_ID", frameType: FrameTypeRetireConnectionID, frame: &RetireConnectionIDFrame{SequenceNumber: 0x1337}, }, { name: "PATH_CHALLENGE", frameType: FrameTypePathChallenge, frame: &PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, }, { name: "PATH_RESPONSE", frameType: FrameTypePathResponse, frame: &PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, }, { name: "CONNECTION_CLOSE", frameType: FrameTypeConnectionClose, frame: &ConnectionCloseFrame{IsApplicationError: false, ReasonPhrase: "foobar"}, }, { name: "APPLICATION_CLOSE", frameType: FrameTypeApplicationClose, frame: &ConnectionCloseFrame{IsApplicationError: true, ReasonPhrase: "foobar"}, }, { name: "HANDSHAKE_DONE", frameType: FrameTypeHandshakeDone, frame: &HandshakeDoneFrame{}, }, { name: "RESET_STREAM_AT", frameType: FrameTypeResetStreamAt, frame: &ResetStreamFrame{StreamID: 0x1337, ReliableSize: 0x42, FinalSize: 0xdeadbeef}, }, { name: "ACK_FREQUENCY", frameType: FrameTypeAckFrequency, frame: &AckFrequencyFrame{ SequenceNumber: 0x1337, AckElicitingThreshold: 0x42, RequestMaxAckDelay: 123 * time.Second, ReorderingThreshold: 0xcafe, }, }, { name: "IMMEDIATE_ACK", frameType: FrameTypeImmediateAck, frame: &ImmediateAckFrame{}, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { parser := NewFrameParser(true, true, true) b, err := test.frame.Append(nil, 
protocol.Version1) require.NoError(t, err) frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT) require.NoError(t, err) require.Equal(t, test.frameType, frameType) require.Equal(t, quicvarint.Len(uint64(test.frameType)), l) frame, l, err := parser.ParseLessCommonFrame(frameType, b[l:], protocol.Version1) require.NoError(t, err) require.Equal(t, test.frame, frame) require.Equal(t, len(b)-quicvarint.Len(uint64(test.frameType)), l) }) } } func TestFrameAllowedAtEncLevel(t *testing.T) { type testCase struct { name string frameType FrameType frame Frame allowedInitial bool allowedHandshake bool allowedZeroRTT bool allowedOneRTT bool } for _, tc := range []testCase{ { name: "CRYPTO_FRAME", frameType: FrameTypeCrypto, frame: &CryptoFrame{Offset: 0, Data: []byte("foo")}, allowedInitial: true, allowedHandshake: true, allowedZeroRTT: false, allowedOneRTT: true, }, { name: "ACK_FRAME", frameType: FrameTypeAck, frame: &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 1}}}, allowedInitial: true, allowedHandshake: true, allowedZeroRTT: false, allowedOneRTT: true, }, { name: "CONNECTION_CLOSE_FRAME", frameType: FrameTypeConnectionClose, frame: &ConnectionCloseFrame{IsApplicationError: false, ReasonPhrase: "err"}, allowedInitial: true, allowedHandshake: true, allowedZeroRTT: false, allowedOneRTT: true, }, { name: "PING_FRAME", frameType: FrameTypePing, frame: &PingFrame{}, allowedInitial: true, allowedHandshake: true, allowedZeroRTT: true, allowedOneRTT: true, }, { name: "NEW_TOKEN_FRAME", frameType: FrameTypeNewToken, frame: &NewTokenFrame{Token: []byte("tok")}, allowedInitial: false, allowedHandshake: false, allowedZeroRTT: false, allowedOneRTT: true, }, { name: "PATH_RESPONSE_FRAME", frameType: FrameTypePathResponse, frame: &PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, allowedInitial: false, allowedHandshake: false, allowedZeroRTT: false, allowedOneRTT: true, }, { name: "RETIRE_CONNECTION_ID_FRAME", frameType: FrameTypeRetireConnectionID, frame: 
&RetireConnectionIDFrame{SequenceNumber: 1}, allowedInitial: false, allowedHandshake: false, allowedZeroRTT: false, allowedOneRTT: true, }, { name: "MAX_DATA_FRAME", frameType: FrameTypeMaxData, frame: &MaxDataFrame{MaximumData: 1}, allowedInitial: false, allowedHandshake: false, allowedZeroRTT: true, allowedOneRTT: true, }, { name: "STREAM_FRAME", frameType: FrameType(0x8), frame: &StreamFrame{StreamID: 1, Data: []byte("foobar")}, allowedInitial: false, allowedHandshake: false, allowedZeroRTT: true, allowedOneRTT: true, }, } { for _, encLevel := range []protocol.EncryptionLevel{ protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption0RTT, protocol.Encryption1RTT, } { t.Run(fmt.Sprintf("%s/%v", tc.name, encLevel), func(t *testing.T) { var allowed bool switch encLevel { case protocol.EncryptionInitial: allowed = tc.allowedInitial case protocol.EncryptionHandshake: allowed = tc.allowedHandshake case protocol.Encryption0RTT: allowed = tc.allowedZeroRTT case protocol.Encryption1RTT: allowed = tc.allowedOneRTT } parser := NewFrameParser(true, true, true) b, err := tc.frame.Append(nil, protocol.Version1) require.NoError(t, err) frameType, _, err := parser.ParseType(b, encLevel) if allowed { require.NoError(t, err) require.Equal(t, tc.frameType, frameType) } else { require.Error(t, err) var transportErr *qerr.TransportError require.ErrorAs(t, err, &transportErr) require.Equal(t, qerr.FrameEncodingError, transportErr.ErrorCode) } }) } } } func TestFrameParserDatagramFrame(t *testing.T) { parser := NewFrameParser(true, true, true) f := &DatagramFrame{ Data: []byte("foobar"), } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT) require.NoError(t, err) require.Equal(t, FrameTypeDatagramNoLength, frameType) require.Equal(t, 1, l) // ParseLessCommonFrame should not be used to handle DATAGRAM frames _, _, err = parser.ParseLessCommonFrame(frameType, b[l:], protocol.Version1) 
require.Error(t, err) // parseDatagramFrame should be used for this type datagramFrame, l, err := parser.ParseDatagramFrame(frameType, b[l:], protocol.Version1) require.NoError(t, err) require.IsType(t, &DatagramFrame{}, datagramFrame) require.Equal(t, 6, l) require.Equal(t, f.Data, datagramFrame.Data) } func TestFrameParserDatagramUnsupported(t *testing.T) { parser := NewFrameParser(false, true, true) f := &DatagramFrame{Data: []byte("foobar")} b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) _, _, err = parser.ParseType(b, protocol.Encryption1RTT) checkFrameUnsupported(t, err, 0x30) } func TestFrameParserResetStreamAtUnsupported(t *testing.T) { parser := NewFrameParser(true, false, true) f := &ResetStreamFrame{StreamID: 0x1337, ReliableSize: 0x42, FinalSize: 0xdeadbeef} b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) _, _, err = parser.ParseType(b, protocol.Encryption1RTT) checkFrameUnsupported(t, err, uint64(FrameTypeResetStreamAt)) } func TestFrameParserAckFrequencyUnsupported(t *testing.T) { parser := NewFrameParser(true, true, false) t.Run("ACK_FREQUENCY", func(t *testing.T) { f := &AckFrequencyFrame{ SequenceNumber: 1337, AckElicitingThreshold: 42, RequestMaxAckDelay: 42 * time.Millisecond, ReorderingThreshold: 1234, } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) _, _, err = parser.ParseType(b, protocol.Encryption1RTT) checkFrameUnsupported(t, err, uint64(FrameTypeAckFrequency)) }) t.Run("IMMEDIATE_ACK", func(t *testing.T) { f := &ImmediateAckFrame{} b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) _, _, err = parser.ParseType(b, protocol.Encryption1RTT) checkFrameUnsupported(t, err, uint64(FrameTypeImmediateAck)) }) } func TestFrameParserInvalidFrameType(t *testing.T) { parser := NewFrameParser(true, true, true) _, l, err := parser.ParseType(encodeVarInt(0x42), protocol.Encryption1RTT) require.Equal(t, 2, l) require.Error(t, err) var transportErr *qerr.TransportError 
require.ErrorAs(t, err, &transportErr) require.Equal(t, qerr.FrameEncodingError, transportErr.ErrorCode) } func TestFrameParsingErrorsOnInvalidFrames(t *testing.T) { parser := NewFrameParser(true, true, true) f := &MaxStreamDataFrame{ StreamID: 0x1337, MaximumStreamData: 0xdeadbeef, } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) frameType, l, err := parser.ParseType(b[:len(b)-2], protocol.Encryption1RTT) require.NoError(t, err) require.Equal(t, FrameTypeMaxStreamData, frameType) require.Equal(t, 1, l) _, _, err = parser.ParseLessCommonFrame(frameType, b[1:len(b)-2], protocol.Version1) require.Error(t, err) var transportErr *qerr.TransportError require.ErrorAs(t, err, &transportErr) require.Equal(t, qerr.FrameEncodingError, transportErr.ErrorCode) } func writeFrames(tb testing.TB, frames ...Frame) []byte { var b []byte for _, f := range frames { var err error b, err = f.Append(b, protocol.Version1) require.NoError(tb, err) } return b } // This function is used in benchmarks, and also to ensure zero allocation for STREAM frame parsing. // We can therefore not use the require framework, as it allocates. 
func parseFrames(tb testing.TB, parser *FrameParser, data []byte, frames ...Frame) { for _, expectedFrame := range frames { frameType, l, err := parser.ParseType(data, protocol.Encryption1RTT) if err != nil { tb.Fatal(err) } data = data[l:] if frameType.IsStreamFrameType() { sf := expectedFrame.(*StreamFrame) frame, l, err := ParseStreamFrame(data, frameType, protocol.Version1) if err != nil { tb.Fatal(err) } if sf.StreamID != frame.StreamID || sf.Offset != frame.Offset { tb.Fatalf("STREAM frame does not match: %v vs %v", sf, frame) } frame.PutBack() data = data[l:] continue } if frameType.IsAckFrameType() { af, ok := expectedFrame.(*AckFrame) if !ok { tb.Fatalf("expected ACK, but got %v", expectedFrame) } f, l, err := parser.ParseAckFrame(frameType, data, protocol.Encryption1RTT, protocol.Version1) if f.DelayTime != af.DelayTime || f.ECNCE != af.ECNCE || f.ECT0 != af.ECT0 || f.ECT1 != af.ECT1 { tb.Fatal(err) } if f.DelayTime != af.DelayTime { tb.Fatalf("ACK frame does not match: %v vs %v", af, f) } if !slices.Equal(f.AckRanges, af.AckRanges) { tb.Fatalf("ACK frame ACK ranges don't match: %v vs %v", af, f) } data = data[l:] continue } if frameType.IsDatagramFrameType() { df, ok := expectedFrame.(*DatagramFrame) if !ok { tb.Fatalf("expected DATAGRAM, but got %v", expectedFrame) } f, l, err := parser.ParseDatagramFrame(frameType, data, protocol.Version1) if err != nil { tb.Fatal(err) } if df.DataLenPresent != f.DataLenPresent || !bytes.Equal(df.Data, f.Data) { tb.Fatalf("DATAGRAM frame does not match: %v vs %v", df, f) } data = data[l:] continue } f, l, err := parser.ParseLessCommonFrame(frameType, data, protocol.Version1) if err != nil { tb.Fatal(err) } data = data[l:] switch frameType { case FrameTypeMaxData: mdf, ok := expectedFrame.(*MaxDataFrame) if !ok { tb.Fatalf("expected MAX_DATA, but got %v", expectedFrame) } if *f.(*MaxDataFrame) != *mdf { tb.Fatalf("MAX_DATA frame does not match: %v vs %v", f, mdf) } case FrameTypeUniMaxStreams: msf, ok := 
expectedFrame.(*MaxStreamsFrame) if !ok { tb.Fatalf("expected MAX_STREAMS, but got %v", expectedFrame) } if *f.(*MaxStreamsFrame) != *msf { tb.Fatalf("MAX_STREAMS frame does not match: %v vs %v", f, msf) } case FrameTypeMaxStreamData: mdf, ok := expectedFrame.(*MaxStreamDataFrame) if !ok { tb.Fatalf("expected MAX_STREAM_DATA, but got %v", expectedFrame) } if *f.(*MaxStreamDataFrame) != *mdf { tb.Fatalf("MAX_STREAM_DATA frame does not match: %v vs %v", f, mdf) } case FrameTypeCrypto: cf, ok := expectedFrame.(*CryptoFrame) if !ok { tb.Fatalf("expected CRYPTO, but got %v", expectedFrame) } frame := f.(*CryptoFrame) if frame.Offset != cf.Offset || !bytes.Equal(frame.Data, cf.Data) { tb.Fatalf("CRYPTO frame does not match: %v vs %v", f, cf) } case FrameTypePing: _ = f.(*PingFrame) case FrameTypeResetStream: rsf, ok := expectedFrame.(*ResetStreamFrame) if !ok { tb.Fatalf("expected RESET_STREAM, but got %v", expectedFrame) } if *f.(*ResetStreamFrame) != *rsf { tb.Fatalf("RESET_STREAM frame does not match: %v vs %v", f, rsf) } continue default: tb.Fatalf("Frame type not supported in benchmark or should not occur: %v", frameType) } } } func TestFrameParserAllocs(t *testing.T) { t.Run("STREAM", func(t *testing.T) { var frames []Frame for i := range 10 { frames = append(frames, &StreamFrame{ StreamID: protocol.StreamID(1337 + i), Offset: protocol.ByteCount(1e7 + i), Data: make([]byte, 200+i), DataLenPresent: true, }) } require.Zero(t, testFrameParserAllocs(t, frames)) }) t.Run("ACK", func(t *testing.T) { var frames []Frame for i := range 10 { frames = append(frames, &AckFrame{ AckRanges: []AckRange{ {Smallest: protocol.PacketNumber(5000 + i), Largest: protocol.PacketNumber(5200 + i)}, {Smallest: protocol.PacketNumber(1 + i), Largest: protocol.PacketNumber(4200 + i)}, }, DelayTime: time.Duration(int64(time.Millisecond) * int64(i)), ECT0: uint64(5000 + i), ECT1: uint64(i), ECNCE: uint64(10 + i), }) } require.Zero(t, testFrameParserAllocs(t, frames)) }) } func 
testFrameParserAllocs(t *testing.T, frames []Frame) float64 { buf := writeFrames(t, frames...) parser := NewFrameParser(true, true, true) parser.SetAckDelayExponent(3) return testing.AllocsPerRun(100, func() { parseFrames(t, parser, buf, frames...) }) } func BenchmarkParseOtherFrames(b *testing.B) { frames := []Frame{ &MaxDataFrame{MaximumData: 123456}, &MaxStreamsFrame{MaxStreamNum: 10}, &MaxStreamDataFrame{StreamID: 1337, MaximumStreamData: 1e6}, &CryptoFrame{Offset: 1000, Data: make([]byte, 128)}, &PingFrame{}, &ResetStreamFrame{StreamID: 87654, ErrorCode: 1234, FinalSize: 1e8}, } benchmarkFrames(b, frames...) } func BenchmarkParseAckFrame(b *testing.B) { var frames []Frame for i := range 10 { frames = append(frames, &AckFrame{ AckRanges: []AckRange{ {Smallest: protocol.PacketNumber(5000 + i), Largest: protocol.PacketNumber(5200 + i)}, {Smallest: protocol.PacketNumber(1 + i), Largest: protocol.PacketNumber(4200 + i)}, }, DelayTime: time.Duration(int64(time.Millisecond) * int64(i)), ECT0: uint64(5000 + i), ECT1: uint64(i), ECNCE: uint64(10 + i), }) } benchmarkFrames(b, frames...) } func BenchmarkParseStreamFrame(b *testing.B) { var frames []Frame for i := range 10 { data := make([]byte, 200+i) rand.Read(data) frames = append(frames, &StreamFrame{ StreamID: protocol.StreamID(1337 + i), Offset: protocol.ByteCount(1e7 + i), Data: data, DataLenPresent: true, }) } benchmarkFrames(b, frames...) } func BenchmarkParseDatagramFrame(b *testing.B) { var frames []Frame for i := range 10 { data := make([]byte, 200+i) rand.Read(data) frames = append(frames, &DatagramFrame{ Data: data, DataLenPresent: true, }) } benchmarkFrames(b, frames...) } func benchmarkFrames(b *testing.B, frames ...Frame) { b.ReportAllocs() buf := writeFrames(b, frames...) parser := NewFrameParser(true, true, true) parser.SetAckDelayExponent(3) for b.Loop() { parseFrames(b, parser, buf, frames...) 
} } quic-go-0.59.0/internal/wire/frame_test.go000066400000000000000000000017771513066070600204330ustar00rootroot00000000000000package wire import ( "testing" "github.com/stretchr/testify/require" ) func TestProbingFrames(t *testing.T) { testCases := map[Frame]bool{ &AckFrame{}: false, &ConnectionCloseFrame{}: false, &DataBlockedFrame{}: false, &PingFrame{}: false, &ResetStreamFrame{}: false, &StreamFrame{}: false, &DatagramFrame{}: false, &MaxDataFrame{}: false, &MaxStreamDataFrame{}: false, &StopSendingFrame{}: false, &PathChallengeFrame{}: true, &PathResponseFrame{}: true, &NewConnectionIDFrame{}: true, } for f, expected := range testCases { require.Equal(t, expected, IsProbingFrame(f)) } } func TestIsProbingFrameType(t *testing.T) { tests := map[FrameType]bool{ FrameTypePathChallenge: true, FrameTypePathResponse: true, FrameTypeNewConnectionID: true, FrameType(0x01): false, FrameType(0xFF): false, } for ft, expected := range tests { require.Equal(t, expected, IsProbingFrameType(ft)) } } quic-go-0.59.0/internal/wire/frame_type.go000066400000000000000000000051101513066070600204160ustar00rootroot00000000000000package wire import "github.com/quic-go/quic-go/internal/protocol" type FrameType uint64 // These constants correspond to those defined in RFC 9000. // Stream frame types are not listed explicitly here; use FrameType.IsStreamFrameType() to identify them. 
const ( FrameTypePing FrameType = 0x1 FrameTypeAck FrameType = 0x2 FrameTypeAckECN FrameType = 0x3 FrameTypeResetStream FrameType = 0x4 FrameTypeStopSending FrameType = 0x5 FrameTypeCrypto FrameType = 0x6 FrameTypeNewToken FrameType = 0x7 FrameTypeMaxData FrameType = 0x10 FrameTypeMaxStreamData FrameType = 0x11 FrameTypeBidiMaxStreams FrameType = 0x12 FrameTypeUniMaxStreams FrameType = 0x13 FrameTypeDataBlocked FrameType = 0x14 FrameTypeStreamDataBlocked FrameType = 0x15 FrameTypeBidiStreamBlocked FrameType = 0x16 FrameTypeUniStreamBlocked FrameType = 0x17 FrameTypeNewConnectionID FrameType = 0x18 FrameTypeRetireConnectionID FrameType = 0x19 FrameTypePathChallenge FrameType = 0x1a FrameTypePathResponse FrameType = 0x1b FrameTypeConnectionClose FrameType = 0x1c FrameTypeApplicationClose FrameType = 0x1d FrameTypeHandshakeDone FrameType = 0x1e // https://datatracker.ietf.org/doc/draft-ietf-quic-reliable-stream-reset/07/ FrameTypeResetStreamAt FrameType = 0x24 // https://datatracker.ietf.org/doc/draft-ietf-quic-ack-frequency/11/ FrameTypeAckFrequency FrameType = 0xaf FrameTypeImmediateAck FrameType = 0x1f FrameTypeDatagramNoLength FrameType = 0x30 FrameTypeDatagramWithLength FrameType = 0x31 ) func (t FrameType) IsStreamFrameType() bool { return t >= 0x8 && t <= 0xf } func (t FrameType) isValidRFC9000() bool { return t <= 0x1e } func (t FrameType) IsAckFrameType() bool { return t == FrameTypeAck || t == FrameTypeAckECN } func (t FrameType) IsDatagramFrameType() bool { return t == FrameTypeDatagramNoLength || t == FrameTypeDatagramWithLength } func (t FrameType) isAllowedAtEncLevel(encLevel protocol.EncryptionLevel) bool { //nolint:exhaustive switch encLevel { case protocol.EncryptionInitial, protocol.EncryptionHandshake: switch t { case FrameTypeCrypto, FrameTypeAck, FrameTypeAckECN, FrameTypeConnectionClose, FrameTypePing: return true default: return false } case protocol.Encryption0RTT: switch t { case FrameTypeCrypto, FrameTypeAck, FrameTypeAckECN, 
FrameTypeConnectionClose, FrameTypeNewToken, FrameTypePathResponse, FrameTypeRetireConnectionID: return false default: return true } case protocol.Encryption1RTT: return true default: panic("unknown encryption level") } } quic-go-0.59.0/internal/wire/frame_type_test.go000066400000000000000000000024241513066070600214620ustar00rootroot00000000000000package wire import ( "testing" "github.com/stretchr/testify/require" ) func TestIsStreamFrameType(t *testing.T) { for i := 0x08; i <= 0x0f; i++ { require.Truef(t, FrameType(i).IsStreamFrameType(), "FrameType(0x%x).IsStreamFrameType() = false, want true", i) } require.False(t, FrameType(0x1).IsStreamFrameType()) } func TestIsAckFrameType(t *testing.T) { require.True(t, FrameTypeAck.IsAckFrameType(), "AckFrameType should be recognized as ACK") require.True(t, FrameTypeAckECN.IsAckFrameType(), "AckECNFrameType should be recognized as ACK") require.False(t, FrameTypePing.IsAckFrameType(), "PingFrameType should not be recognized as ACK") require.False(t, FrameType(0x10).IsAckFrameType(), "MaxDataFrameType should not be recognized as ACK") } func TestIsDatagramFrameType(t *testing.T) { require.True(t, FrameTypeDatagramNoLength.IsDatagramFrameType(), "DatagramNoLengthFrameType should be recognized as DATAGRAM") require.True(t, FrameTypeDatagramWithLength.IsDatagramFrameType(), "DatagramWithLengthFrameType should be recognized as DATAGRAM") require.False(t, FrameTypePing.IsDatagramFrameType(), "PingFrameType should not be recognized as DATAGRAM") require.False(t, FrameType(0x1e).IsDatagramFrameType(), "HandshakeDoneFrameType should not be recognized as DATAGRAM") } quic-go-0.59.0/internal/wire/handshake_done_frame.go000066400000000000000000000006401513066070600223730ustar00rootroot00000000000000package wire import ( "github.com/quic-go/quic-go/internal/protocol" ) // A HandshakeDoneFrame is a HANDSHAKE_DONE frame type HandshakeDoneFrame struct{} func (f *HandshakeDoneFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { 
return append(b, byte(FrameTypeHandshakeDone)), nil } // Length of a written frame func (f *HandshakeDoneFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 } quic-go-0.59.0/internal/wire/handshake_done_frame_test.go000066400000000000000000000006431513066070600234350ustar00rootroot00000000000000package wire import ( "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestWriteHandshakeDoneSampleFrame(t *testing.T) { frame := HandshakeDoneFrame{} b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) require.Equal(t, []byte{byte(FrameTypeHandshakeDone)}, b) require.Equal(t, protocol.ByteCount(1), frame.Length(protocol.Version1)) } quic-go-0.59.0/internal/wire/header.go000066400000000000000000000214431513066070600175220ustar00rootroot00000000000000package wire import ( "encoding/binary" "errors" "fmt" "io" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" ) // ParseConnectionID parses the destination connection ID of a packet. func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) { if len(data) == 0 { return protocol.ConnectionID{}, io.EOF } if !IsLongHeaderPacket(data[0]) { if len(data) < shortHeaderConnIDLen+1 { return protocol.ConnectionID{}, io.EOF } return protocol.ParseConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil } if len(data) < 6 { return protocol.ConnectionID{}, io.EOF } destConnIDLen := int(data[5]) if destConnIDLen > protocol.MaxConnIDLen { return protocol.ConnectionID{}, protocol.ErrInvalidConnectionIDLen } if len(data) < 6+destConnIDLen { return protocol.ConnectionID{}, io.EOF } return protocol.ParseConnectionID(data[6 : 6+destConnIDLen]), nil } // ParseArbitraryLenConnectionIDs parses the most general form of a Long Header packet, // using only the version-independent packet format as described in Section 5.1 of RFC 8999: // https://datatracker.ietf.org/doc/html/rfc8999#section-5.1. 
// This function should only be called on Long Header packets for which we don't support the version. func ParseArbitraryLenConnectionIDs(data []byte) (bytesParsed int, dest, src protocol.ArbitraryLenConnectionID, _ error) { startLen := len(data) if len(data) < 6 { return 0, nil, nil, io.EOF } data = data[5:] // skip first byte and version field destConnIDLen := data[0] data = data[1:] destConnID := make(protocol.ArbitraryLenConnectionID, destConnIDLen) if len(data) < int(destConnIDLen)+1 { return 0, nil, nil, io.EOF } copy(destConnID, data) data = data[destConnIDLen:] srcConnIDLen := data[0] data = data[1:] if len(data) < int(srcConnIDLen) { return 0, nil, nil, io.EOF } srcConnID := make(protocol.ArbitraryLenConnectionID, srcConnIDLen) copy(srcConnID, data) return startLen - len(data) + int(srcConnIDLen), destConnID, srcConnID, nil } func IsPotentialQUICPacket(firstByte byte) bool { return firstByte&0x40 > 0 } // IsLongHeaderPacket says if this is a Long Header packet func IsLongHeaderPacket(firstByte byte) bool { return firstByte&0x80 > 0 } // ParseVersion parses the QUIC version. // It should only be called for Long Header packets (Short Header packets don't contain a version number). func ParseVersion(data []byte) (protocol.Version, error) { if len(data) < 5 { return 0, io.EOF } return protocol.Version(binary.BigEndian.Uint32(data[1:5])), nil } // IsVersionNegotiationPacket says if this is a version negotiation packet func IsVersionNegotiationPacket(b []byte) bool { if len(b) < 5 { return false } return IsLongHeaderPacket(b[0]) && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0 } // Is0RTTPacket says if this is a 0-RTT packet. // A packet sent with a version we don't understand can never be a 0-RTT packet. func Is0RTTPacket(b []byte) bool { if len(b) < 5 { return false } if !IsLongHeaderPacket(b[0]) { return false } version := protocol.Version(binary.BigEndian.Uint32(b[1:5])) //nolint:exhaustive // We only need to test QUIC versions that we support. 
switch version { case protocol.Version1: return b[0]>>4&0b11 == 0b01 case protocol.Version2: return b[0]>>4&0b11 == 0b10 default: return false } } var ErrUnsupportedVersion = errors.New("unsupported version") // The Header is the version independent part of the header type Header struct { typeByte byte Type protocol.PacketType Version protocol.Version SrcConnectionID protocol.ConnectionID DestConnectionID protocol.ConnectionID Length protocol.ByteCount Token []byte parsedLen protocol.ByteCount // how many bytes were read while parsing this header } // ParsePacket parses a long header packet. // The packet is cut according to the length field. // If we understand the version, the packet is parsed up unto the packet number. // Otherwise, only the invariant part of the header is parsed. func ParsePacket(data []byte) (*Header, []byte, []byte, error) { if len(data) == 0 || !IsLongHeaderPacket(data[0]) { return nil, nil, nil, errors.New("not a long header packet") } hdr, err := parseHeader(data) if err != nil { if errors.Is(err, ErrUnsupportedVersion) { return hdr, nil, nil, err } return nil, nil, nil, err } if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length { return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) } packetLen := int(hdr.ParsedLen() + hdr.Length) return hdr, data[:packetLen], data[packetLen:], nil } // ParseHeader parses the header: // * if we understand the version: up to the packet number // * if not, only the invariant part of the header func parseHeader(b []byte) (*Header, error) { if len(b) == 0 { return nil, io.EOF } typeByte := b[0] h := &Header{typeByte: typeByte} l, err := h.parseLongHeader(b[1:]) h.parsedLen = protocol.ByteCount(l) + 1 return h, err } func (h *Header) parseLongHeader(b []byte) (int, error) { startLen := len(b) if len(b) < 5 { return 0, io.EOF } h.Version = protocol.Version(binary.BigEndian.Uint32(b[:4])) if h.Version != 0 
&& h.typeByte&0x40 == 0 { return startLen - len(b), errors.New("not a QUIC packet") } destConnIDLen := int(b[4]) if destConnIDLen > protocol.MaxConnIDLen { return startLen - len(b), protocol.ErrInvalidConnectionIDLen } b = b[5:] if len(b) < destConnIDLen+1 { return startLen - len(b), io.EOF } h.DestConnectionID = protocol.ParseConnectionID(b[:destConnIDLen]) srcConnIDLen := int(b[destConnIDLen]) if srcConnIDLen > protocol.MaxConnIDLen { return startLen - len(b), protocol.ErrInvalidConnectionIDLen } b = b[destConnIDLen+1:] if len(b) < srcConnIDLen { return startLen - len(b), io.EOF } h.SrcConnectionID = protocol.ParseConnectionID(b[:srcConnIDLen]) b = b[srcConnIDLen:] if h.Version == 0 { // version negotiation packet return startLen - len(b), nil } // If we don't understand the version, we have no idea how to interpret the rest of the bytes if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) { return startLen - len(b), ErrUnsupportedVersion } if h.Version == protocol.Version2 { switch h.typeByte >> 4 & 0b11 { case 0b00: h.Type = protocol.PacketTypeRetry case 0b01: h.Type = protocol.PacketTypeInitial case 0b10: h.Type = protocol.PacketType0RTT case 0b11: h.Type = protocol.PacketTypeHandshake } } else { switch h.typeByte >> 4 & 0b11 { case 0b00: h.Type = protocol.PacketTypeInitial case 0b01: h.Type = protocol.PacketType0RTT case 0b10: h.Type = protocol.PacketTypeHandshake case 0b11: h.Type = protocol.PacketTypeRetry } } if h.Type == protocol.PacketTypeRetry { tokenLen := len(b) - 16 if tokenLen <= 0 { return startLen - len(b), io.EOF } h.Token = make([]byte, tokenLen) copy(h.Token, b[:tokenLen]) return startLen - len(b) + tokenLen + 16, nil } if h.Type == protocol.PacketTypeInitial { tokenLen, n, err := quicvarint.Parse(b) if err != nil { return startLen - len(b), err } b = b[n:] if tokenLen > uint64(len(b)) { return startLen - len(b), io.EOF } h.Token = make([]byte, tokenLen) copy(h.Token, b[:tokenLen]) b = b[tokenLen:] } pl, n, err := 
quicvarint.Parse(b) if err != nil { return 0, err } h.Length = protocol.ByteCount(pl) return startLen - len(b) + n, nil } // ParsedLen returns the number of bytes that were consumed when parsing the header func (h *Header) ParsedLen() protocol.ByteCount { return h.parsedLen } // ParseExtended parses the version dependent part of the header. // The Reader has to be set such that it points to the first byte of the header. func (h *Header) ParseExtended(data []byte) (*ExtendedHeader, error) { extHdr := h.toExtendedHeader() reservedBitsValid, err := extHdr.parse(data) if err != nil { return nil, err } if !reservedBitsValid { return extHdr, ErrInvalidReservedBits } return extHdr, nil } func (h *Header) toExtendedHeader() *ExtendedHeader { return &ExtendedHeader{Header: *h} } // PacketType is the type of the packet, for logging purposes func (h *Header) PacketType() string { return h.Type.String() } func readPacketNumber(data []byte, pnLen protocol.PacketNumberLen) (protocol.PacketNumber, error) { var pn protocol.PacketNumber switch pnLen { case protocol.PacketNumberLen1: pn = protocol.PacketNumber(data[0]) case protocol.PacketNumberLen2: pn = protocol.PacketNumber(binary.BigEndian.Uint16(data[:2])) case protocol.PacketNumberLen3: pn = protocol.PacketNumber(uint32(data[2]) + uint32(data[1])<<8 + uint32(data[0])<<16) case protocol.PacketNumberLen4: pn = protocol.PacketNumber(binary.BigEndian.Uint32(data[:4])) default: return 0, fmt.Errorf("invalid packet number length: %d", pnLen) } return pn, nil } quic-go-0.59.0/internal/wire/header_test.go000066400000000000000000000517061513066070600205660ustar00rootroot00000000000000package wire import ( "bytes" "crypto/rand" "encoding/binary" "io" mrand "math/rand/v2" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestParseConnIDLongHeaderPacket(t *testing.T) { b, err := (&ExtendedHeader{ Header: Header{ Type: protocol.PacketTypeHandshake, DestConnectionID: 
protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}), Version: protocol.Version1, }, PacketNumberLen: 2, }).Append(nil, protocol.Version1) require.NoError(t, err) connID, err := ParseConnectionID(b, 8) require.NoError(t, err) require.Equal(t, protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), connID) } func TestParseConnIDTooLong(t *testing.T) { b := []byte{0x80, 0, 0, 0, 0} binary.BigEndian.PutUint32(b[1:], uint32(protocol.Version1)) b = append(b, 21) // dest conn id len b = append(b, make([]byte, 21)...) _, err := ParseConnectionID(b, 4) require.Error(t, err) require.ErrorIs(t, err, protocol.ErrInvalidConnectionIDLen) } func TestParseConnIDEOFLongHeader(t *testing.T) { b, err := (&ExtendedHeader{ Header: Header{ Type: protocol.PacketTypeHandshake, DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad, 0x13, 0x37}), SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 8, 9}), Version: protocol.Version1, }, PacketNumberLen: 2, }).Append(nil, protocol.Version1) require.NoError(t, err) data := b[:len(b)-2] // cut the packet number _, err = ParseConnectionID(data, 8) require.NoError(t, err) for i := 0; i < 1 /* first byte */ +4 /* version */ +1 /* conn ID lengths */ +6; /* dest conn ID */ i++ { b := make([]byte, i) copy(b, data[:i]) _, err := ParseConnectionID(b, 8) require.Error(t, err) require.ErrorIs(t, err, io.EOF) } } func TestIs0RTT(t *testing.T) { t.Run("QUIC v1", func(t *testing.T) { zeroRTTHeader := make([]byte, 5) zeroRTTHeader[0] = 0x80 | 0b01<<4 binary.BigEndian.PutUint32(zeroRTTHeader[1:], uint32(protocol.Version1)) require.True(t, Is0RTTPacket(zeroRTTHeader)) require.False(t, Is0RTTPacket(zeroRTTHeader[:4])) // too short require.False(t, Is0RTTPacket([]byte{zeroRTTHeader[0], 1, 2, 3, 4})) // unknown version require.False(t, Is0RTTPacket([]byte{zeroRTTHeader[0] | 0x80, 1, 2, 3, 4})) // short header require.True(t, 
Is0RTTPacket(append(zeroRTTHeader, []byte("foobar")...))) }) t.Run("QUIC v2", func(t *testing.T) { zeroRTTHeader := make([]byte, 5) zeroRTTHeader[0] = 0x80 | 0b10<<4 binary.BigEndian.PutUint32(zeroRTTHeader[1:], uint32(protocol.Version2)) require.True(t, Is0RTTPacket(zeroRTTHeader)) require.False(t, Is0RTTPacket(zeroRTTHeader[:4])) // too short require.False(t, Is0RTTPacket([]byte{zeroRTTHeader[0], 1, 2, 3, 4})) // unknown version require.False(t, Is0RTTPacket([]byte{zeroRTTHeader[0] | 0x80, 1, 2, 3, 4})) // short header require.True(t, Is0RTTPacket(append(zeroRTTHeader, []byte("foobar")...))) }) } func TestParseVersion(t *testing.T) { b := []byte{0x80, 0xde, 0xad, 0xbe, 0xef} v, err := ParseVersion(b) require.NoError(t, err) require.Equal(t, protocol.Version(0xdeadbeef), v) for i := range b { _, err := ParseVersion(b[:i]) require.ErrorIs(t, err, io.EOF) } } func TestParseArbitraryLengthConnectionIDs(t *testing.T) { generateConnID := func(l int) protocol.ArbitraryLenConnectionID { c := make(protocol.ArbitraryLenConnectionID, l) rand.Read(c) return c } src := generateConnID(mrand.IntN(255) + 1) dest := generateConnID(mrand.IntN(255) + 1) b := []byte{0x80, 1, 2, 3, 4} b = append(b, uint8(dest.Len())) b = append(b, dest.Bytes()...) b = append(b, uint8(src.Len())) b = append(b, src.Bytes()...) l := len(b) b = append(b, []byte("foobar")...) 
// add some payload parsed, d, s, err := ParseArbitraryLenConnectionIDs(b) require.Equal(t, l, parsed) require.NoError(t, err) require.Equal(t, src, s) require.Equal(t, dest, d) for i := range b[:l] { _, _, _, err := ParseArbitraryLenConnectionIDs(b[:i]) require.ErrorIs(t, err, io.EOF) } } func TestIdentifyVersionNegotiationPackets(t *testing.T) { require.True(t, IsVersionNegotiationPacket([]byte{0x80 | 0x56, 0, 0, 0, 0})) require.False(t, IsVersionNegotiationPacket([]byte{0x56, 0, 0, 0, 0})) require.False(t, IsVersionNegotiationPacket([]byte{0x80, 1, 0, 0, 0})) require.False(t, IsVersionNegotiationPacket([]byte{0x80, 0, 1, 0, 0})) require.False(t, IsVersionNegotiationPacket([]byte{0x80, 0, 0, 1, 0})) require.False(t, IsVersionNegotiationPacket([]byte{0x80, 0, 0, 0, 1})) } func TestVersionNegotiationPacketEOF(t *testing.T) { vnp := []byte{0x80, 0, 0, 0, 0} for i := range vnp { require.False(t, IsVersionNegotiationPacket(vnp[:i])) } } func TestParseLongHeader(t *testing.T) { destConnID := protocol.ParseConnectionID([]byte{9, 8, 7, 6, 5, 4, 3, 2, 1}) srcConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) data := []byte{0xc0 ^ 0x3} data = appendVersion(data, protocol.Version1) data = append(data, 0x9) // dest conn id length data = append(data, destConnID.Bytes()...) data = append(data, 0x4) // src conn id length data = append(data, srcConnID.Bytes()...) data = append(data, encodeVarInt(6)...) // token length data = append(data, []byte("foobar")...) // token data = append(data, encodeVarInt(10)...) // length hdrLen := len(data) data = append(data, []byte{0, 0, 0xbe, 0xef}...) // packet number data = append(data, []byte("foobar")...) 
require.False(t, IsVersionNegotiationPacket(data)) hdr, pdata, rest, err := ParsePacket(data) require.NoError(t, err) require.Equal(t, data, pdata) require.Equal(t, destConnID, hdr.DestConnectionID) require.Equal(t, srcConnID, hdr.SrcConnectionID) require.Equal(t, protocol.PacketTypeInitial, hdr.Type) require.Equal(t, []byte("foobar"), hdr.Token) require.Equal(t, protocol.ByteCount(10), hdr.Length) require.Equal(t, protocol.Version1, hdr.Version) require.Empty(t, rest) extHdr, err := hdr.ParseExtended(data) require.NoError(t, err) require.Equal(t, protocol.PacketNumberLen4, extHdr.PacketNumberLen) require.Equal(t, protocol.PacketNumber(0xbeef), extHdr.PacketNumber) require.Equal(t, hdrLen, int(hdr.ParsedLen())) require.Equal(t, hdr.ParsedLen()+4, extHdr.ParsedLen()) } func TestErrorIfReservedBitNotSet(t *testing.T) { data := []byte{ 0x80 | 0x2<<4, 0x11, // connection ID lengths 0xde, 0xca, 0xfb, 0xad, // dest conn ID 0xde, 0xad, 0xbe, 0xef, // src conn ID } _, _, _, err := ParsePacket(data) require.EqualError(t, err, "not a QUIC packet") } func TestStopParsingWhenEncounteringUnsupportedVersion(t *testing.T) { data := []byte{ 0xc0, 0xde, 0xad, 0xbe, 0xef, 0x8, // dest conn ID len 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, // dest conn ID 0x8, // src conn ID len 0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1, // src conn ID 'f', 'o', 'o', 'b', 'a', 'r', // unspecified bytes } hdr, _, rest, err := ParsePacket(data) require.EqualError(t, err, ErrUnsupportedVersion.Error()) require.Equal(t, protocol.Version(0xdeadbeef), hdr.Version) require.Equal(t, protocol.ParseConnectionID([]byte{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}), hdr.DestConnectionID) require.Equal(t, protocol.ParseConnectionID([]byte{0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1}), hdr.SrcConnectionID) require.Empty(t, rest) } func TestParseLongHeaderWithoutDestinationConnectionID(t *testing.T) { data := []byte{0xc0 ^ 0x1<<4} data = appendVersion(data, protocol.Version1) data = append(data, 0) // dest conn ID len data = 
append(data, 4) // src conn ID len data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // source connection ID data = append(data, encodeVarInt(0)...) // length data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) hdr, _, _, err := ParsePacket(data) require.NoError(t, err) require.Equal(t, protocol.PacketType0RTT, hdr.Type) require.Equal(t, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), hdr.SrcConnectionID) require.Zero(t, hdr.DestConnectionID) } func TestParseLongHeaderWithoutSourceConnectionID(t *testing.T) { data := []byte{0xc0 ^ 0x2<<4} data = appendVersion(data, protocol.Version1) data = append(data, 10) // dest conn ID len data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // dest connection ID data = append(data, 0) // src conn ID len data = append(data, encodeVarInt(0)...) // length data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) hdr, _, _, err := ParsePacket(data) require.NoError(t, err) require.Zero(t, hdr.SrcConnectionID) require.Equal(t, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), hdr.DestConnectionID) } func TestErrorOnTooLongDestinationConnectionID(t *testing.T) { data := []byte{0xc0 ^ 0x2<<4} data = appendVersion(data, protocol.Version1) data = append(data, 21) // dest conn ID len data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}...) // dest connection ID data = append(data, 0x0) // src conn ID len data = append(data, encodeVarInt(0)...) // length data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) _, _, _, err := ParsePacket(data) require.EqualError(t, err, protocol.ErrInvalidConnectionIDLen.Error()) } func TestParseLongHeaderWith2BytePacketNumber(t *testing.T) { data := []byte{0xc0 ^ 0x1} data = appendVersion(data, protocol.Version1) // version number data = append(data, []byte{0x0, 0x0}...) // connection ID lengths data = append(data, encodeVarInt(0)...) // token length data = append(data, encodeVarInt(0)...) 
// length data = append(data, []byte{0x1, 0x23}...) hdr, _, _, err := ParsePacket(data) require.NoError(t, err) extHdr, err := hdr.ParseExtended(data) require.NoError(t, err) require.Equal(t, protocol.PacketNumber(0x123), extHdr.PacketNumber) require.Equal(t, protocol.PacketNumberLen2, extHdr.PacketNumberLen) require.Equal(t, len(data), int(extHdr.ParsedLen())) } func TestParseRetryPacket(t *testing.T) { for _, version := range []protocol.Version{protocol.Version1, protocol.Version2} { t.Run(version.String(), func(t *testing.T) { var packetType byte if version == protocol.Version1 { packetType = 0b11 << 4 } else { packetType = 0b00 << 4 } data := []byte{0xc0 | packetType | (10 - 3) /* connection ID length */} data = appendVersion(data, version) data = append(data, []byte{6}...) // dest conn ID len data = append(data, []byte{6, 5, 4, 3, 2, 1}...) // dest conn ID data = append(data, []byte{10}...) // src conn ID len data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token data = append(data, []byte{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}...) hdr, pdata, rest, err := ParsePacket(data) require.NoError(t, err) require.Equal(t, protocol.PacketTypeRetry, hdr.Type) require.Equal(t, version, hdr.Version) require.Equal(t, protocol.ParseConnectionID([]byte{6, 5, 4, 3, 2, 1}), hdr.DestConnectionID) require.Equal(t, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), hdr.SrcConnectionID) require.Equal(t, []byte("foobar"), hdr.Token) require.Equal(t, data, pdata) require.Empty(t, rest) }) } } func TestRetryPacketTooShortForIntegrityTag(t *testing.T) { data := []byte{0xc0 | 0x3<<4 | (10 - 3) /* connection ID length */} data = appendVersion(data, protocol.Version1) data = append(data, []byte{0, 0}...) // conn ID lens data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) 
// this results in a token length of 0 _, _, _, err := ParsePacket(data) require.Equal(t, io.EOF, err) } func TestTokenLengthTooLarge(t *testing.T) { data := []byte{0xc0 ^ 0x1} data = appendVersion(data, protocol.Version1) data = append(data, 0x0) // connection ID lengths data = append(data, encodeVarInt(4)...) // token length: 4 bytes (1 byte too long) data = append(data, encodeVarInt(0x42)...) // length, 1 byte data = append(data, []byte{0x12, 0x34}...) // packet number _, _, _, err := ParsePacket(data) require.Equal(t, io.EOF, err) } func TestErrorOn5thOr6thBitSet(t *testing.T) { data := []byte{0xc0 | 0x2<<4 | 0x8 /* set the 5th bit */ | 0x1 /* 2 byte packet number */} data = appendVersion(data, protocol.Version1) data = append(data, []byte{0x0, 0x0}...) // connection ID lengths data = append(data, encodeVarInt(2)...) // length data = append(data, []byte{0x12, 0x34}...) // packet number hdr, _, _, err := ParsePacket(data) require.NoError(t, err) require.Equal(t, protocol.PacketTypeHandshake, hdr.Type) extHdr, err := hdr.ParseExtended(data) require.EqualError(t, err, ErrInvalidReservedBits.Error()) require.NotNil(t, extHdr) require.Equal(t, protocol.PacketNumber(0x1234), extHdr.PacketNumber) } func TestHeaderEOF(t *testing.T) { data := []byte{0xc0 ^ 0x2<<4} data = appendVersion(data, protocol.Version1) data = append(data, 0x8) // dest conn ID len data = append(data, []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}...) // dest conn ID data = append(data, 0x8) // src conn ID len data = append(data, []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}...) // src conn ID for i := 1; i < len(data); i++ { _, _, _, err := ParsePacket(data[:i]) require.Equal(t, io.EOF, err) } } func TestParseExtendedHeaderEOF(t *testing.T) { data := []byte{0xc0 | 0x2<<4 | 0x3} data = appendVersion(data, protocol.Version1) data = append(data, []byte{0x0, 0x0}...) // connection ID lengths data = append(data, encodeVarInt(0)...) 
// length hdrLen := len(data) data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // packet number for i := hdrLen; i < len(data); i++ { b := data[:i] hdr, _, _, err := ParsePacket(b) require.NoError(t, err) _, err = hdr.ParseExtended(b) require.Equal(t, io.EOF, err) } } func TestParseRetryEOF(t *testing.T) { data := []byte{0xc0 ^ 0x3<<4} data = appendVersion(data, protocol.Version1) data = append(data, []byte{0x0, 0x0}...) // connection ID lengths data = append(data, 0xa) // Orig Destination Connection ID length data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID hdrLen := len(data) for i := hdrLen; i < len(data); i++ { data = data[:i] hdr, _, _, err := ParsePacket(data) require.NoError(t, err) _, err = hdr.ParseExtended(data) require.Equal(t, io.EOF, err) } } func TestCoalescedPacketParsing(t *testing.T) { hdr := Header{ Type: protocol.PacketTypeInitial, DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), Length: 2 + 6, Version: protocol.Version1, } b, err := (&ExtendedHeader{ Header: hdr, PacketNumber: 0x1337, PacketNumberLen: 2, }).Append(nil, protocol.Version1) require.NoError(t, err) hdrRaw := append([]byte{}, b...) b = append(b, []byte("foobar")...) // payload of the first packet b = append(b, []byte("raboof")...) 
// second packet parsedHdr, data, rest, err := ParsePacket(b) require.NoError(t, err) require.Equal(t, hdr.Type, parsedHdr.Type) require.Equal(t, hdr.DestConnectionID, parsedHdr.DestConnectionID) require.Equal(t, append(hdrRaw, []byte("foobar")...), data) require.Equal(t, []byte("raboof"), rest) } func TestCoalescedPacketErrorOnTooSmallPacketNumber(t *testing.T) { b, err := (&ExtendedHeader{ Header: Header{ Type: protocol.PacketTypeInitial, DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), Length: 3, Version: protocol.Version1, }, PacketNumber: 0x1337, PacketNumberLen: 2, }).Append(nil, protocol.Version1) require.NoError(t, err) _, _, _, err = ParsePacket(b) require.Error(t, err) require.Contains(t, err.Error(), "packet length (2 bytes) is smaller than the expected length (3 bytes)") } func TestCoalescedPacketErrorOnTooSmallPayload(t *testing.T) { b, err := (&ExtendedHeader{ Header: Header{ Type: protocol.PacketTypeInitial, DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), Length: 1000, Version: protocol.Version1, }, PacketNumber: 0x1337, PacketNumberLen: 2, }).Append(nil, protocol.Version1) require.NoError(t, err) b = append(b, make([]byte, 500-2 /* for packet number length */)...) 
_, _, _, err = ParsePacket(b) require.EqualError(t, err, "packet length (500 bytes) is smaller than the expected length (1000 bytes)") } func TestDistinguishesLongAndShortHeaderPackets(t *testing.T) { require.False(t, IsLongHeaderPacket(0x40)) require.True(t, IsLongHeaderPacket(0x80^0x40^0x12)) } func TestPacketTypeForLogging(t *testing.T) { require.Equal(t, "Initial", (&Header{Type: protocol.PacketTypeInitial}).PacketType()) require.Equal(t, "Handshake", (&Header{Type: protocol.PacketTypeHandshake}).PacketType()) } func BenchmarkIs0RTTPacket(b *testing.B) { src := mrand.NewChaCha8([32]byte{'f', 'o', 'o', 'b', 'a', 'r'}) random := mrand.New(src) packets := make([][]byte, 1024) for i := range len(packets) { packets[i] = make([]byte, random.IntN(256)) src.Read(packets[i]) } var i int for b.Loop() { Is0RTTPacket(packets[i%len(packets)]) i++ } } func BenchmarkParseInitial(b *testing.B) { b.Run("without token", func(b *testing.B) { benchmarkInitialPacketParsing(b, nil) }) b.Run("with token", func(b *testing.B) { token := make([]byte, 32) rand.Read(token) benchmarkInitialPacketParsing(b, token) }) } func benchmarkInitialPacketParsing(b *testing.B, token []byte) { b.ReportAllocs() hdr := Header{ Type: protocol.PacketTypeInitial, DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), SrcConnectionID: protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}), Length: 1000, Token: token, Version: protocol.Version1, } data, err := (&ExtendedHeader{ Header: hdr, PacketNumber: 0x1337, PacketNumberLen: 4, }).Append(nil, protocol.Version1) if err != nil { b.Fatal(err) } data = append(data, make([]byte, 1000)...) 
for b.Loop() { h, _, _, err := ParsePacket(data) if err != nil { b.Fatal(err) } if h.Type != hdr.Type || h.DestConnectionID != hdr.DestConnectionID || h.SrcConnectionID != hdr.SrcConnectionID || !bytes.Equal(h.Token, hdr.Token) { b.Fatalf("headers don't match: %v vs %v", h, hdr) } } } func BenchmarkParseRetry(b *testing.B) { b.ReportAllocs() token := make([]byte, 64) rand.Read(token) hdr := &ExtendedHeader{ Header: Header{ Type: protocol.PacketTypeRetry, SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), DestConnectionID: protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}), Token: token, Version: protocol.Version1, }, } data, err := hdr.Append(nil, hdr.Version) if err != nil { b.Fatal(err) } for b.Loop() { h, _, _, err := ParsePacket(data) if err != nil { b.Fatal(err) } if h.Type != hdr.Type || h.DestConnectionID != hdr.DestConnectionID || h.SrcConnectionID != hdr.SrcConnectionID || !bytes.Equal(h.Token, hdr.Token[:len(hdr.Token)-16]) { b.Fatalf("headers don't match: %#v vs %#v", h, hdr) } } } func BenchmarkArbitraryHeaderParsing(b *testing.B) { b.Run("dest 8/ src 10", func(b *testing.B) { benchmarkArbitraryHeaderParsing(b, 8, 10) }) b.Run("dest 20 / src 20", func(b *testing.B) { benchmarkArbitraryHeaderParsing(b, 20, 20) }) b.Run("dest 100 / src 150", func(b *testing.B) { benchmarkArbitraryHeaderParsing(b, 100, 150) }) } func benchmarkArbitraryHeaderParsing(b *testing.B, destLen, srcLen int) { destConnID := make([]byte, destLen) rand.Read(destConnID) srcConnID := make([]byte, srcLen) rand.Read(srcConnID) buf := []byte{0x80, 1, 2, 3, 4} buf = append(buf, uint8(destLen)) buf = append(buf, destConnID...) buf = append(buf, uint8(srcLen)) buf = append(buf, srcConnID...) 
b.ReportAllocs() for b.Loop() { parsed, d, s, err := ParseArbitraryLenConnectionIDs(buf) if err != nil { b.Fatal(err) } if parsed != len(buf) { b.Fatal("expected to parse entire slice") } if !bytes.Equal(destConnID, d.Bytes()) { b.Fatalf("destination connection IDs don't match: %v vs %v", destConnID, d.Bytes()) } if !bytes.Equal(srcConnID, s.Bytes()) { b.Fatalf("source connection IDs don't match: %v vs %v", srcConnID, s.Bytes()) } } } quic-go-0.59.0/internal/wire/immediate_ack_frame.go000066400000000000000000000010221513066070600222070ustar00rootroot00000000000000package wire import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" ) // An ImmediateAckFrame is an IMMEDIATE_ACK frame type ImmediateAckFrame struct{} func (f *ImmediateAckFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { return quicvarint.Append(b, uint64(FrameTypeImmediateAck)), nil } // Length of a written frame func (f *ImmediateAckFrame) Length(_ protocol.Version) protocol.ByteCount { return protocol.ByteCount(quicvarint.Len(uint64(FrameTypeImmediateAck))) } quic-go-0.59.0/internal/wire/immediate_ack_frame_test.go000066400000000000000000000010051513066070600232470ustar00rootroot00000000000000package wire import ( "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" "github.com/stretchr/testify/require" ) func TestImmediateAckFrame(t *testing.T) { frame := ImmediateAckFrame{} b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) val, l, err := quicvarint.Parse(b) require.NoError(t, err) require.Equal(t, uint64(FrameTypeImmediateAck), val) require.Equal(t, len(b), l) require.Len(t, b, int(frame.Length(protocol.Version1))) } quic-go-0.59.0/internal/wire/log.go000066400000000000000000000064151513066070600170550ustar00rootroot00000000000000package wire import ( "fmt" "strings" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" ) // LogFrame logs a frame, 
either sent or received func LogFrame(logger utils.Logger, frame Frame, sent bool) { if !logger.Debug() { return } dir := "<-" if sent { dir = "->" } switch f := frame.(type) { case *CryptoFrame: dataLen := protocol.ByteCount(len(f.Data)) logger.Debugf("\t%s &wire.CryptoFrame{Offset: %d, Data length: %d, Offset + Data length: %d}", dir, f.Offset, dataLen, f.Offset+dataLen) case *StreamFrame: logger.Debugf("\t%s &wire.StreamFrame{StreamID: %d, Fin: %t, Offset: %d, Data length: %d, Offset + Data length: %d}", dir, f.StreamID, f.Fin, f.Offset, f.DataLen(), f.Offset+f.DataLen()) case *ResetStreamFrame: logger.Debugf("\t%s &wire.ResetStreamFrame{StreamID: %d, ErrorCode: %#x, FinalSize: %d}", dir, f.StreamID, f.ErrorCode, f.FinalSize) case *AckFrame: hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 var ecn string if hasECN { ecn = fmt.Sprintf(", ECT0: %d, ECT1: %d, CE: %d", f.ECT0, f.ECT1, f.ECNCE) } if len(f.AckRanges) > 1 { ackRanges := make([]string, len(f.AckRanges)) for i, r := range f.AckRanges { ackRanges[i] = fmt.Sprintf("{Largest: %d, Smallest: %d}", r.Largest, r.Smallest) } logger.Debugf("\t%s &wire.AckFrame{LargestAcked: %d, LowestAcked: %d, AckRanges: {%s}, DelayTime: %s%s}", dir, f.LargestAcked(), f.LowestAcked(), strings.Join(ackRanges, ", "), f.DelayTime.String(), ecn) } else { logger.Debugf("\t%s &wire.AckFrame{LargestAcked: %d, LowestAcked: %d, DelayTime: %s%s}", dir, f.LargestAcked(), f.LowestAcked(), f.DelayTime.String(), ecn) } case *MaxDataFrame: logger.Debugf("\t%s &wire.MaxDataFrame{MaximumData: %d}", dir, f.MaximumData) case *MaxStreamDataFrame: logger.Debugf("\t%s &wire.MaxStreamDataFrame{StreamID: %d, MaximumStreamData: %d}", dir, f.StreamID, f.MaximumStreamData) case *DataBlockedFrame: logger.Debugf("\t%s &wire.DataBlockedFrame{MaximumData: %d}", dir, f.MaximumData) case *StreamDataBlockedFrame: logger.Debugf("\t%s &wire.StreamDataBlockedFrame{StreamID: %d, MaximumStreamData: %d}", dir, f.StreamID, f.MaximumStreamData) case *MaxStreamsFrame: 
switch f.Type { case protocol.StreamTypeUni: logger.Debugf("\t%s &wire.MaxStreamsFrame{Type: uni, MaxStreamNum: %d}", dir, f.MaxStreamNum) case protocol.StreamTypeBidi: logger.Debugf("\t%s &wire.MaxStreamsFrame{Type: bidi, MaxStreamNum: %d}", dir, f.MaxStreamNum) } case *StreamsBlockedFrame: switch f.Type { case protocol.StreamTypeUni: logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: uni, MaxStreams: %d}", dir, f.StreamLimit) case protocol.StreamTypeBidi: logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: %d}", dir, f.StreamLimit) } case *NewConnectionIDFrame: logger.Debugf("\t%s &wire.NewConnectionIDFrame{SequenceNumber: %d, RetirePriorTo: %d, ConnectionID: %s, StatelessResetToken: %#x}", dir, f.SequenceNumber, f.RetirePriorTo, f.ConnectionID, f.StatelessResetToken) case *RetireConnectionIDFrame: logger.Debugf("\t%s &wire.RetireConnectionIDFrame{SequenceNumber: %d}", dir, f.SequenceNumber) case *NewTokenFrame: logger.Debugf("\t%s &wire.NewTokenFrame{Token: %#x}", dir, f.Token) default: logger.Debugf("\t%s %#v", dir, frame) } } quic-go-0.59.0/internal/wire/log_test.go000066400000000000000000000134621513066070600201140ustar00rootroot00000000000000package wire import ( "bytes" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/stretchr/testify/require" ) func TestLogFrameNoDebug(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) logger.SetLogLevel(utils.LogLevelInfo) LogFrame(logger, &ResetStreamFrame{}, true) require.Zero(t, buf.Len()) } func TestLogSentFrame(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) LogFrame(logger, &ResetStreamFrame{}, true) require.Contains(t, buf.String(), "\t-> &wire.ResetStreamFrame{StreamID: 0, ErrorCode: 0x0, FinalSize: 0}\n") } func TestLogReceivedFrame(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) LogFrame(logger, &ResetStreamFrame{}, false) require.Contains(t, buf.String(), "\t<- 
&wire.ResetStreamFrame{StreamID: 0, ErrorCode: 0x0, FinalSize: 0}\n") } func TestLogCryptoFrame(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) frame := &CryptoFrame{ Offset: 42, Data: make([]byte, 123), } LogFrame(logger, frame, false) require.Contains(t, buf.String(), "\t<- &wire.CryptoFrame{Offset: 42, Data length: 123, Offset + Data length: 165}\n") } func TestLogStreamFrame(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) frame := &StreamFrame{ StreamID: 42, Offset: 1337, Data: bytes.Repeat([]byte{'f'}, 100), } LogFrame(logger, frame, false) require.Contains(t, buf.String(), "\t<- &wire.StreamFrame{StreamID: 42, Fin: false, Offset: 1337, Data length: 100, Offset + Data length: 1437}\n") } func TestLogAckFrameWithoutMissingPackets(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) frame := &AckFrame{ AckRanges: []AckRange{{Smallest: 42, Largest: 1337}}, DelayTime: 1 * time.Millisecond, } LogFrame(logger, frame, false) require.Contains(t, buf.String(), "\t<- &wire.AckFrame{LargestAcked: 1337, LowestAcked: 42, DelayTime: 1ms}\n") } func TestLogAckFrameWithECN(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) frame := &AckFrame{ AckRanges: []AckRange{{Smallest: 42, Largest: 1337}}, DelayTime: 1 * time.Millisecond, ECT0: 5, ECT1: 66, ECNCE: 777, } LogFrame(logger, frame, false) require.Contains(t, buf.String(), "\t<- &wire.AckFrame{LargestAcked: 1337, LowestAcked: 42, DelayTime: 1ms, ECT0: 5, ECT1: 66, CE: 777}\n") } func TestLogAckFrameWithMissingPackets(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) frame := &AckFrame{ AckRanges: []AckRange{ {Smallest: 5, Largest: 8}, {Smallest: 2, Largest: 3}, }, DelayTime: 12 * time.Millisecond, } LogFrame(logger, frame, false) require.Contains(t, buf.String(), "\t<- &wire.AckFrame{LargestAcked: 8, LowestAcked: 2, AckRanges: {{Largest: 8, Smallest: 5}, {Largest: 3, Smallest: 2}}, DelayTime: 12ms}\n") } func 
TestLogMaxStreamsFrame(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) frame := &MaxStreamsFrame{ Type: protocol.StreamTypeBidi, MaxStreamNum: 42, } LogFrame(logger, frame, false) require.Contains(t, buf.String(), "\t<- &wire.MaxStreamsFrame{Type: bidi, MaxStreamNum: 42}\n") } func TestLogMaxDataFrame(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) frame := &MaxDataFrame{ MaximumData: 42, } LogFrame(logger, frame, false) require.Contains(t, buf.String(), "\t<- &wire.MaxDataFrame{MaximumData: 42}\n") } func TestLogMaxStreamDataFrame(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) frame := &MaxStreamDataFrame{ StreamID: 10, MaximumStreamData: 42, } LogFrame(logger, frame, false) require.Contains(t, buf.String(), "\t<- &wire.MaxStreamDataFrame{StreamID: 10, MaximumStreamData: 42}\n") } func TestLogDataBlockedFrame(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) frame := &DataBlockedFrame{ MaximumData: 1000, } LogFrame(logger, frame, false) require.Contains(t, buf.String(), "\t<- &wire.DataBlockedFrame{MaximumData: 1000}\n") } func TestLogStreamDataBlockedFrame(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) frame := &StreamDataBlockedFrame{ StreamID: 42, MaximumStreamData: 1000, } LogFrame(logger, frame, false) require.Contains(t, buf.String(), "\t<- &wire.StreamDataBlockedFrame{StreamID: 42, MaximumStreamData: 1000}\n") } func TestLogStreamsBlockedFrame(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) frame := &StreamsBlockedFrame{ Type: protocol.StreamTypeBidi, StreamLimit: 42, } LogFrame(logger, frame, false) require.Contains(t, buf.String(), "\t<- &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: 42}\n") } func TestLogNewConnectionIDFrame(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) LogFrame(logger, &NewConnectionIDFrame{ SequenceNumber: 42, RetirePriorTo: 24, ConnectionID: 
protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), StatelessResetToken: protocol.StatelessResetToken{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10}, }, false) require.Contains(t, buf.String(), "\t<- &wire.NewConnectionIDFrame{SequenceNumber: 42, RetirePriorTo: 24, ConnectionID: deadbeef, StatelessResetToken: 0x0102030405060708090a0b0c0d0e0f10}") } func TestLogRetireConnectionIDFrame(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) LogFrame(logger, &RetireConnectionIDFrame{SequenceNumber: 42}, false) require.Contains(t, buf.String(), "\t<- &wire.RetireConnectionIDFrame{SequenceNumber: 42}") } func TestLogNewTokenFrame(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) LogFrame(logger, &NewTokenFrame{ Token: []byte{0xde, 0xad, 0xbe, 0xef}, }, true) require.Contains(t, buf.String(), "\t-> &wire.NewTokenFrame{Token: 0xdeadbeef") } quic-go-0.59.0/internal/wire/max_data_frame.go000066400000000000000000000016541513066070600212240ustar00rootroot00000000000000package wire import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" ) // A MaxDataFrame carries flow control information for the connection type MaxDataFrame struct { MaximumData protocol.ByteCount } // parseMaxDataFrame parses a MAX_DATA frame func parseMaxDataFrame(b []byte, _ protocol.Version) (*MaxDataFrame, int, error) { frame := &MaxDataFrame{} byteOffset, l, err := quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } frame.MaximumData = protocol.ByteCount(byteOffset) return frame, l, nil } func (f *MaxDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, byte(FrameTypeMaxData)) b = quicvarint.Append(b, uint64(f.MaximumData)) return b, nil } // Length of a written frame func (f *MaxDataFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.MaximumData))) } 
quic-go-0.59.0/internal/wire/max_data_frame_test.go000066400000000000000000000021421513066070600222540ustar00rootroot00000000000000package wire import ( "io" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestParseMaxDataFrame(t *testing.T) { data := encodeVarInt(0xdecafbad123456) // byte offset frame, l, err := parseMaxDataFrame(data, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.ByteCount(0xdecafbad123456), frame.MaximumData) require.Equal(t, len(data), l) } func TestParseMaxDataErrorsOnEOFs(t *testing.T) { data := encodeVarInt(0xdecafbad1234567) // byte offset _, l, err := parseMaxDataFrame(data, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), l) for i := range data { _, _, err := parseMaxDataFrame(data[:i], protocol.Version1) require.Equal(t, io.EOF, err) } } func TestWriteMaxDataFrame(t *testing.T) { f := &MaxDataFrame{MaximumData: 0xdeadbeefcafe} b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{byte(FrameTypeMaxData)} expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) 
require.Equal(t, expected, b) require.Len(t, b, int(f.Length(protocol.Version1))) } quic-go-0.59.0/internal/wire/max_stream_data_frame.go000066400000000000000000000023221513066070600225700ustar00rootroot00000000000000package wire import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" ) // A MaxStreamDataFrame is a MAX_STREAM_DATA frame type MaxStreamDataFrame struct { StreamID protocol.StreamID MaximumStreamData protocol.ByteCount } func parseMaxStreamDataFrame(b []byte, _ protocol.Version) (*MaxStreamDataFrame, int, error) { startLen := len(b) sid, l, err := quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } b = b[l:] offset, l, err := quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } b = b[l:] return &MaxStreamDataFrame{ StreamID: protocol.StreamID(sid), MaximumStreamData: protocol.ByteCount(offset), }, startLen - len(b), nil } func (f *MaxStreamDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, byte(FrameTypeMaxStreamData)) b = quicvarint.Append(b, uint64(f.StreamID)) b = quicvarint.Append(b, uint64(f.MaximumStreamData)) return b, nil } // Length of a written frame func (f *MaxStreamDataFrame) Length(protocol.Version) protocol.ByteCount { return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamID))+quicvarint.Len(uint64(f.MaximumStreamData))) } quic-go-0.59.0/internal/wire/max_stream_data_frame_test.go000066400000000000000000000027101513066070600236300ustar00rootroot00000000000000package wire import ( "io" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestParseMaxStreamFrame(t *testing.T) { data := encodeVarInt(0xdeadbeef) // Stream ID data = append(data, encodeVarInt(0x12345678)...) 
// Offset frame, l, err := parseMaxStreamDataFrame(data, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamID(0xdeadbeef), frame.StreamID) require.Equal(t, protocol.ByteCount(0x12345678), frame.MaximumStreamData) require.Equal(t, len(data), l) } func TestParseMaxStreamDataErrorsOnEOFs(t *testing.T) { data := encodeVarInt(0xdeadbeef) // Stream ID data = append(data, encodeVarInt(0x12345678)...) // Offset _, l, err := parseMaxStreamDataFrame(data, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), l) for i := range data { _, _, err := parseMaxStreamDataFrame(data[:i], protocol.Version1) require.Equal(t, io.EOF, err) } } func TestWriteMaxStreamDataFrame(t *testing.T) { f := &MaxStreamDataFrame{ StreamID: 0xdecafbad, MaximumStreamData: 0xdeadbeefcafe42, } expected := []byte{byte(FrameTypeMaxStreamData)} expected = append(expected, encodeVarInt(0xdecafbad)...) expected = append(expected, encodeVarInt(0xdeadbeefcafe42)...) b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) require.Equal(t, expected, b) require.Equal(t, len(b), int(f.Length(protocol.Version1))) } quic-go-0.59.0/internal/wire/max_streams_frame.go000066400000000000000000000026551513066070600217730ustar00rootroot00000000000000package wire import ( "fmt" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" ) // A MaxStreamsFrame is a MAX_STREAMS frame type MaxStreamsFrame struct { Type protocol.StreamType MaxStreamNum protocol.StreamNum } func parseMaxStreamsFrame(b []byte, typ FrameType, _ protocol.Version) (*MaxStreamsFrame, int, error) { f := &MaxStreamsFrame{} //nolint:exhaustive // Function will only be called with BidiMaxStreamsFrameType or UniMaxStreamsFrameType switch typ { case FrameTypeBidiMaxStreams: f.Type = protocol.StreamTypeBidi case FrameTypeUniMaxStreams: f.Type = protocol.StreamTypeUni } streamID, l, err := quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } 
f.MaxStreamNum = protocol.StreamNum(streamID) if f.MaxStreamNum > protocol.MaxStreamCount { return nil, 0, fmt.Errorf("%d exceeds the maximum stream count", f.MaxStreamNum) } return f, l, nil } func (f *MaxStreamsFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { switch f.Type { case protocol.StreamTypeBidi: b = append(b, byte(FrameTypeBidiMaxStreams)) case protocol.StreamTypeUni: b = append(b, byte(FrameTypeUniMaxStreams)) } b = quicvarint.Append(b, uint64(f.MaxStreamNum)) return b, nil } // Length of a written frame func (f *MaxStreamsFrame) Length(protocol.Version) protocol.ByteCount { return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.MaxStreamNum))) } quic-go-0.59.0/internal/wire/max_streams_frame_test.go000066400000000000000000000070401513066070600230230ustar00rootroot00000000000000package wire import ( "fmt" "io" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" "github.com/stretchr/testify/require" ) func TestParseMaxStreamsFrameBidirectional(t *testing.T) { data := encodeVarInt(0xdecaf) f, l, err := parseMaxStreamsFrame(data, FrameTypeBidiMaxStreams, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamTypeBidi, f.Type) require.EqualValues(t, 0xdecaf, f.MaxStreamNum) require.Equal(t, len(data), l) } func TestParseMaxStreamsFrameUnidirectional(t *testing.T) { data := encodeVarInt(0xdecaf) f, l, err := parseMaxStreamsFrame(data, FrameTypeUniMaxStreams, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamTypeUni, f.Type) require.EqualValues(t, 0xdecaf, f.MaxStreamNum) require.Equal(t, len(data), l) } func TestParseMaxStreamsErrorsOnEOF(t *testing.T) { const typ = 0x1d data := encodeVarInt(0xdeadbeefcafe13) _, l, err := parseMaxStreamsFrame(data, typ, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), l) for i := range data { _, _, err := parseMaxStreamsFrame(data[:i], typ, protocol.Version1) require.Equal(t, io.EOF, err) } } func 
TestParseMaxStreamsMaxValue(t *testing.T) { for _, streamType := range []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi} { var streamTypeStr string if streamType == protocol.StreamTypeUni { streamTypeStr = "unidirectional" } else { streamTypeStr = "bidirectional" } t.Run(streamTypeStr, func(t *testing.T) { f := &MaxStreamsFrame{ Type: streamType, MaxStreamNum: protocol.MaxStreamCount, } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) typ, l, err := quicvarint.Parse(b) require.NoError(t, err) b = b[l:] frame, _, err := parseMaxStreamsFrame(b, FrameType(typ), protocol.Version1) require.NoError(t, err) require.Equal(t, f, frame) }) } } func TestParseMaxStreamsErrorsOnTooLargeStreamCount(t *testing.T) { for _, streamType := range []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi} { var streamTypeStr string if streamType == protocol.StreamTypeUni { streamTypeStr = "unidirectional" } else { streamTypeStr = "bidirectional" } t.Run(streamTypeStr, func(t *testing.T) { f := &MaxStreamsFrame{ Type: streamType, MaxStreamNum: protocol.MaxStreamCount + 1, } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) typ, l, err := quicvarint.Parse(b) require.NoError(t, err) b = b[l:] _, _, err = parseMaxStreamsFrame(b, FrameType(typ), protocol.Version1) require.EqualError(t, err, fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1)) }) } } func TestWriteMaxStreamsBidirectional(t *testing.T) { f := &MaxStreamsFrame{ Type: protocol.StreamTypeBidi, MaxStreamNum: 0xdeadbeef, } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{byte(FrameTypeBidiMaxStreams)} expected = append(expected, encodeVarInt(0xdeadbeef)...) 
require.Equal(t, expected, b) require.Len(t, b, int(f.Length(protocol.Version1))) } func TestWriteMaxStreamsUnidirectional(t *testing.T) { f := &MaxStreamsFrame{ Type: protocol.StreamTypeUni, MaxStreamNum: 0xdecafbad, } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{byte(FrameTypeUniMaxStreams)} expected = append(expected, encodeVarInt(0xdecafbad)...) require.Equal(t, expected, b) require.Len(t, b, int(f.Length(protocol.Version1))) } quic-go-0.59.0/internal/wire/new_connection_id_frame.go000066400000000000000000000044531513066070600231320ustar00rootroot00000000000000package wire import ( "errors" "fmt" "io" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" ) // A NewConnectionIDFrame is a NEW_CONNECTION_ID frame type NewConnectionIDFrame struct { SequenceNumber uint64 RetirePriorTo uint64 ConnectionID protocol.ConnectionID StatelessResetToken protocol.StatelessResetToken } func parseNewConnectionIDFrame(b []byte, _ protocol.Version) (*NewConnectionIDFrame, int, error) { startLen := len(b) seq, l, err := quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } b = b[l:] ret, l, err := quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } b = b[l:] if ret > seq { //nolint:staticcheck // SA1021: Retire Prior To is the name of the field return nil, 0, fmt.Errorf("Retire Prior To value (%d) larger than Sequence Number (%d)", ret, seq) } if len(b) == 0 { return nil, 0, io.EOF } connIDLen := int(b[0]) b = b[1:] if connIDLen == 0 { return nil, 0, errors.New("invalid zero-length connection ID") } if connIDLen > protocol.MaxConnIDLen { return nil, 0, protocol.ErrInvalidConnectionIDLen } if len(b) < connIDLen { return nil, 0, io.EOF } frame := &NewConnectionIDFrame{ SequenceNumber: seq, RetirePriorTo: ret, ConnectionID: protocol.ParseConnectionID(b[:connIDLen]), } b = b[connIDLen:] if len(b) < len(frame.StatelessResetToken) { return nil, 0, io.EOF } 
copy(frame.StatelessResetToken[:], b) return frame, startLen - len(b) + len(frame.StatelessResetToken), nil } func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, byte(FrameTypeNewConnectionID)) b = quicvarint.Append(b, f.SequenceNumber) b = quicvarint.Append(b, f.RetirePriorTo) connIDLen := f.ConnectionID.Len() if connIDLen > protocol.MaxConnIDLen { return nil, fmt.Errorf("invalid connection ID length: %d", connIDLen) } b = append(b, uint8(connIDLen)) b = append(b, f.ConnectionID.Bytes()...) b = append(b, f.StatelessResetToken[:]...) return b, nil } // Length of a written frame func (f *NewConnectionIDFrame) Length(protocol.Version) protocol.ByteCount { return 1 + protocol.ByteCount(quicvarint.Len(f.SequenceNumber)+quicvarint.Len(f.RetirePriorTo)+1 /* connection ID length */ +f.ConnectionID.Len()) + 16 } quic-go-0.59.0/internal/wire/new_connection_id_frame_test.go000066400000000000000000000103041513066070600241610ustar00rootroot00000000000000package wire import ( "io" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestParseNewConnectionIDFrame(t *testing.T) { data := encodeVarInt(0xdeadbeef) // sequence number data = append(data, encodeVarInt(0xcafe)...) // retire prior to data = append(data, 10) // connection ID length data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // connection ID data = append(data, []byte("deadbeefdecafbad")...) 
// stateless reset token frame, l, err := parseNewConnectionIDFrame(data, protocol.Version1) require.NoError(t, err) require.Equal(t, uint64(0xdeadbeef), frame.SequenceNumber) require.Equal(t, uint64(0xcafe), frame.RetirePriorTo) require.Equal(t, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), frame.ConnectionID) require.Equal(t, "deadbeefdecafbad", string(frame.StatelessResetToken[:])) require.Equal(t, len(data), l) } func TestParseNewConnectionIDRetirePriorToLargerThanSequenceNumber(t *testing.T) { data := encodeVarInt(1000) // sequence number data = append(data, encodeVarInt(1001)...) // retire prior to data = append(data, 3) data = append(data, []byte{1, 2, 3}...) data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token _, _, err := parseNewConnectionIDFrame(data, protocol.Version1) require.EqualError(t, err, "Retire Prior To value (1001) larger than Sequence Number (1000)") } func TestParseNewConnectionIDZeroLengthConnID(t *testing.T) { data := encodeVarInt(42) // sequence number data = append(data, encodeVarInt(12)...) // retire prior to data = append(data, 0) // connection ID length _, _, err := parseNewConnectionIDFrame(data, protocol.Version1) require.EqualError(t, err, "invalid zero-length connection ID") } func TestParseNewConnectionIDInvalidConnIDLength(t *testing.T) { data := encodeVarInt(0xdeadbeef) // sequence number data = append(data, encodeVarInt(0xcafe)...) // retire prior to data = append(data, 21) // connection ID length data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}...) // connection ID data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token _, _, err := parseNewConnectionIDFrame(data, protocol.Version1) require.Equal(t, protocol.ErrInvalidConnectionIDLen, err) } func TestParseNewConnectionIDErrorsOnEOFs(t *testing.T) { data := encodeVarInt(0xdeadbeef) // sequence number data = append(data, encodeVarInt(0xcafe1234)...) 
// retire prior to data = append(data, 10) // connection ID length data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // connection ID data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token _, l, err := parseNewConnectionIDFrame(data, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), l) for i := range data { _, _, err := parseNewConnectionIDFrame(data[:i], protocol.Version1) require.Equal(t, io.EOF, err) } } func TestWriteNewConnectionIDFrame(t *testing.T) { token := protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} frame := &NewConnectionIDFrame{ SequenceNumber: 0x1337, RetirePriorTo: 0x42, ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}), StatelessResetToken: token, } b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{byte(FrameTypeNewConnectionID)} expected = append(expected, encodeVarInt(0x1337)...) expected = append(expected, encodeVarInt(0x42)...) expected = append(expected, 6) expected = append(expected, []byte{1, 2, 3, 4, 5, 6}...) expected = append(expected, token[:]...) 
require.Equal(t, expected, b) require.Equal(t, int(frame.Length(protocol.Version1)), len(b)) } quic-go-0.59.0/internal/wire/new_token_frame.go000066400000000000000000000020531513066070600214310ustar00rootroot00000000000000package wire import ( "errors" "io" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" ) // A NewTokenFrame is a NEW_TOKEN frame type NewTokenFrame struct { Token []byte } func parseNewTokenFrame(b []byte, _ protocol.Version) (*NewTokenFrame, int, error) { tokenLen, l, err := quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } b = b[l:] if tokenLen == 0 { return nil, 0, errors.New("token must not be empty") } if uint64(len(b)) < tokenLen { return nil, 0, io.EOF } token := make([]byte, int(tokenLen)) copy(token, b) return &NewTokenFrame{Token: token}, l + int(tokenLen), nil } func (f *NewTokenFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, byte(FrameTypeNewToken)) b = quicvarint.Append(b, uint64(len(f.Token))) b = append(b, f.Token...) return b, nil } // Length of a written frame func (f *NewTokenFrame) Length(protocol.Version) protocol.ByteCount { return 1 + protocol.ByteCount(quicvarint.Len(uint64(len(f.Token)))+len(f.Token)) } quic-go-0.59.0/internal/wire/new_token_frame_test.go000066400000000000000000000034431513066070600224740ustar00rootroot00000000000000package wire import ( "io" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestParseNewTokenFrame(t *testing.T) { token := "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." data := encodeVarInt(uint64(len(token))) data = append(data, token...) 
f, l, err := parseNewTokenFrame(data, protocol.Version1) require.NoError(t, err) require.Equal(t, token, string(f.Token)) require.Equal(t, len(data), l) } func TestParseNewTokenFrameRejectsEmptyTokens(t *testing.T) { data := encodeVarInt(0) _, _, err := parseNewTokenFrame(data, protocol.Version1) require.EqualError(t, err, "token must not be empty") } func TestParseNewTokenFrameErrorsOnEOFs(t *testing.T) { token := "Lorem ipsum dolor sit amet, consectetur adipiscing elit" data := encodeVarInt(uint64(len(token))) data = append(data, token...) _, l, err := parseNewTokenFrame(data, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), l) for i := range data { _, _, err := parseNewTokenFrame(data[:i], protocol.Version1) require.Equal(t, io.EOF, err) } } func TestWriteNewTokenFrame(t *testing.T) { token := "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat." f := &NewTokenFrame{Token: []byte(token)} b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{byte(FrameTypeNewToken)} expected = append(expected, encodeVarInt(uint64(len(token)))...) expected = append(expected, token...) 
require.Equal(t, expected, b) require.Equal(t, len(b), int(f.Length(protocol.Version1))) } quic-go-0.59.0/internal/wire/path_challenge_frame.go000066400000000000000000000012571513066070600224030ustar00rootroot00000000000000package wire import ( "io" "github.com/quic-go/quic-go/internal/protocol" ) // A PathChallengeFrame is a PATH_CHALLENGE frame type PathChallengeFrame struct { Data [8]byte } func parsePathChallengeFrame(b []byte, _ protocol.Version) (*PathChallengeFrame, int, error) { f := &PathChallengeFrame{} if len(b) < 8 { return nil, 0, io.EOF } copy(f.Data[:], b) return f, 8, nil } func (f *PathChallengeFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, byte(FrameTypePathChallenge)) b = append(b, f.Data[:]...) return b, nil } // Length of a written frame func (f *PathChallengeFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 + 8 } quic-go-0.59.0/internal/wire/path_challenge_frame_test.go000066400000000000000000000021251513066070600234350ustar00rootroot00000000000000package wire import ( "io" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestParsePathChallenge(t *testing.T) { b := []byte{1, 2, 3, 4, 5, 6, 7, 8} f, l, err := parsePathChallengeFrame(b, protocol.Version1) require.NoError(t, err) require.Equal(t, [8]byte{1, 2, 3, 4, 5, 6, 7, 8}, f.Data) require.Equal(t, len(b), l) } func TestParsePathChallengeErrorsOnEOFs(t *testing.T) { data := []byte{1, 2, 3, 4, 5, 6, 7, 8} _, l, err := parsePathChallengeFrame(data, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), l) for i := range data { _, _, err := parsePathChallengeFrame(data[:i], protocol.Version1) require.Equal(t, io.EOF, err) } } func TestWritePathChallenge(t *testing.T) { frame := PathChallengeFrame{Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}} b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) require.Equal(t, []byte{byte(FrameTypePathChallenge), 
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, b) require.Len(t, b, int(frame.Length(protocol.Version1))) } quic-go-0.59.0/internal/wire/path_response_frame.go000066400000000000000000000012461513066070600223150ustar00rootroot00000000000000package wire import ( "io" "github.com/quic-go/quic-go/internal/protocol" ) // A PathResponseFrame is a PATH_RESPONSE frame type PathResponseFrame struct { Data [8]byte } func parsePathResponseFrame(b []byte, _ protocol.Version) (*PathResponseFrame, int, error) { f := &PathResponseFrame{} if len(b) < 8 { return nil, 0, io.EOF } copy(f.Data[:], b) return f, 8, nil } func (f *PathResponseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, byte(FrameTypePathResponse)) b = append(b, f.Data[:]...) return b, nil } // Length of a written frame func (f *PathResponseFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 + 8 } quic-go-0.59.0/internal/wire/path_response_frame_test.go000066400000000000000000000021151513066070600233500ustar00rootroot00000000000000package wire import ( "io" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestParsePathResponse(t *testing.T) { b := []byte{1, 2, 3, 4, 5, 6, 7, 8} f, l, err := parsePathResponseFrame(b, protocol.Version1) require.NoError(t, err) require.Equal(t, [8]byte{1, 2, 3, 4, 5, 6, 7, 8}, f.Data) require.Equal(t, len(b), l) } func TestParsePathResponseErrorsOnEOFs(t *testing.T) { data := []byte{1, 2, 3, 4, 5, 6, 7, 8} _, l, err := parsePathResponseFrame(data, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), l) for i := range data { _, _, err := parsePathResponseFrame(data[:i], protocol.Version1) require.Equal(t, io.EOF, err) } } func TestWritePathResponse(t *testing.T) { frame := PathResponseFrame{Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}} b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) require.Equal(t, []byte{byte(FrameTypePathResponse), 
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, b) require.Len(t, b, int(frame.Length(protocol.Version1))) } quic-go-0.59.0/internal/wire/ping_frame.go000066400000000000000000000005511513066070600203760ustar00rootroot00000000000000package wire import ( "github.com/quic-go/quic-go/internal/protocol" ) // A PingFrame is a PING frame type PingFrame struct{} func (f *PingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { return append(b, byte(FrameTypePing)), nil } // Length of a written frame func (f *PingFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 } quic-go-0.59.0/internal/wire/ping_frame_test.go000066400000000000000000000005421513066070600214350ustar00rootroot00000000000000package wire import ( "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestWritePingFrame(t *testing.T) { frame := PingFrame{} b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) require.Equal(t, []byte{0x1}, b) require.Len(t, b, int(frame.Length(protocol.Version1))) } quic-go-0.59.0/internal/wire/pool.go000066400000000000000000000010551513066070600172400ustar00rootroot00000000000000package wire import ( "sync" "github.com/quic-go/quic-go/internal/protocol" ) var pool sync.Pool func init() { pool.New = func() any { return &StreamFrame{ Data: make([]byte, 0, protocol.MaxPacketBufferSize), fromPool: true, } } } func GetStreamFrame() *StreamFrame { f := pool.Get().(*StreamFrame) return f } func putStreamFrame(f *StreamFrame) { if !f.fromPool { return } if protocol.ByteCount(cap(f.Data)) != protocol.MaxPacketBufferSize { panic("wire.PutStreamFrame called with packet of wrong size!") } pool.Put(f) } quic-go-0.59.0/internal/wire/pool_test.go000066400000000000000000000010141513066070600202720ustar00rootroot00000000000000package wire import ( "testing" "github.com/stretchr/testify/require" ) func TestGetAndPutStreamFrames(t *testing.T) { f := GetStreamFrame() putStreamFrame(f) } func 
TestPanicOnPuttingStreamFrameWithWrongCapacity(t *testing.T) { f := GetStreamFrame() f.Data = []byte("foobar") require.Panics(t, func() { putStreamFrame(f) }) } func TestAcceptStreamFramesNotFromBuffer(t *testing.T) { f := &StreamFrame{Data: []byte("foobar")} putStreamFrame(f) // No assertion needed as we're just checking it doesn't panic } quic-go-0.59.0/internal/wire/reset_stream_frame.go000066400000000000000000000044731513066070600221450ustar00rootroot00000000000000package wire import ( "fmt" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/quicvarint" ) // A ResetStreamFrame is a RESET_STREAM or RESET_STREAM_AT frame in QUIC type ResetStreamFrame struct { StreamID protocol.StreamID ErrorCode qerr.StreamErrorCode FinalSize protocol.ByteCount ReliableSize protocol.ByteCount } func parseResetStreamFrame(b []byte, isResetStreamAt bool, _ protocol.Version) (*ResetStreamFrame, int, error) { startLen := len(b) streamID, l, err := quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } b = b[l:] errorCode, l, err := quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } b = b[l:] finalSize, l, err := quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } b = b[l:] var reliableSize uint64 if isResetStreamAt { reliableSize, l, err = quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } b = b[l:] } if reliableSize > finalSize { return nil, 0, fmt.Errorf("RESET_STREAM_AT: reliable size can't be larger than final size (%d vs %d)", reliableSize, finalSize) } return &ResetStreamFrame{ StreamID: protocol.StreamID(streamID), ErrorCode: qerr.StreamErrorCode(errorCode), FinalSize: protocol.ByteCount(finalSize), ReliableSize: protocol.ByteCount(reliableSize), }, startLen - len(b), nil } func (f *ResetStreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { if f.ReliableSize == 0 { b = quicvarint.Append(b, 
uint64(FrameTypeResetStream)) } else { b = quicvarint.Append(b, uint64(FrameTypeResetStreamAt)) } b = quicvarint.Append(b, uint64(f.StreamID)) b = quicvarint.Append(b, uint64(f.ErrorCode)) b = quicvarint.Append(b, uint64(f.FinalSize)) if f.ReliableSize > 0 { b = quicvarint.Append(b, uint64(f.ReliableSize)) } return b, nil } // Length of a written frame func (f *ResetStreamFrame) Length(protocol.Version) protocol.ByteCount { size := 1 // the frame type for both RESET_STREAM and RESET_STREAM_AT fits into 1 byte if f.ReliableSize > 0 { size += quicvarint.Len(uint64(f.ReliableSize)) } return protocol.ByteCount(size + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) + quicvarint.Len(uint64(f.FinalSize))) } quic-go-0.59.0/internal/wire/reset_stream_frame_test.go000066400000000000000000000075071513066070600232050ustar00rootroot00000000000000package wire import ( "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/stretchr/testify/require" ) func TestParseResetStream(t *testing.T) { data := encodeVarInt(0xdeadbeef) // stream ID data = append(data, encodeVarInt(0x1337)...) // error code data = append(data, encodeVarInt(0x987654321)...) // byte offset frame, l, err := parseResetStreamFrame(data, false, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamID(0xdeadbeef), frame.StreamID) require.Equal(t, protocol.ByteCount(0x987654321), frame.FinalSize) require.Equal(t, qerr.StreamErrorCode(0x1337), frame.ErrorCode) require.Equal(t, len(data), l) } func TestParseResetStreamAt(t *testing.T) { data := encodeVarInt(0xabcdef12) // stream ID data = append(data, encodeVarInt(0x2468)...) // error code data = append(data, encodeVarInt(0x123456789)...) // byte offset data = append(data, encodeVarInt(0x789abc)...) 
// reliable size frame, l, err := parseResetStreamFrame(data, true, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamID(0xabcdef12), frame.StreamID) require.Equal(t, protocol.ByteCount(0x123456789), frame.FinalSize) require.Equal(t, qerr.StreamErrorCode(0x2468), frame.ErrorCode) require.Equal(t, protocol.ByteCount(0x789abc), frame.ReliableSize) require.Equal(t, len(data), l) } func TestParseResetStreamAtSizeTooLarge(t *testing.T) { data := encodeVarInt(0xabcdef12) // stream ID data = append(data, encodeVarInt(0x2468)...) // error code data = append(data, encodeVarInt(1000)...) // byte offset data = append(data, encodeVarInt(1001)...) // reliable size _, _, err := parseResetStreamFrame(data, true, protocol.Version1) require.EqualError(t, err, "RESET_STREAM_AT: reliable size can't be larger than final size (1001 vs 1000)") } func TestParseResetStreamErrorsOnEOFs(t *testing.T) { t.Run("RESET_STREAM", func(t *testing.T) { testParseResetStreamErrorsOnEOFs(t, false) }) t.Run("RESET_STREAM_AT", func(t *testing.T) { testParseResetStreamErrorsOnEOFs(t, true) }) } func testParseResetStreamErrorsOnEOFs(t *testing.T, isResetStreamAt bool) { data := encodeVarInt(0xdeadbeef) // stream ID data = append(data, encodeVarInt(0x1337)...) // error code data = append(data, encodeVarInt(0x987654321)...) // byte offset if isResetStreamAt { data = append(data, encodeVarInt(0x123456)...) 
// reliable size } _, l, err := parseResetStreamFrame(data, isResetStreamAt, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), l) for i := range data { _, _, err := parseResetStreamFrame(data[:i], isResetStreamAt, protocol.Version1) require.Error(t, err) } } func TestWriteResetStream(t *testing.T) { frame := ResetStreamFrame{ StreamID: 0x1337, FinalSize: 0x11223344decafbad, ErrorCode: 0xcafe, } b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{byte(FrameTypeResetStream)} expected = append(expected, encodeVarInt(0x1337)...) expected = append(expected, encodeVarInt(0xcafe)...) expected = append(expected, encodeVarInt(0x11223344decafbad)...) require.Equal(t, expected, b) require.Len(t, b, int(frame.Length(protocol.Version1))) } func TestWriteResetStreamAt(t *testing.T) { frame := ResetStreamFrame{ StreamID: 1337, FinalSize: 42, ErrorCode: 0xcafe, ReliableSize: 12, } b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{byte(FrameTypeResetStreamAt)} expected = append(expected, encodeVarInt(1337)...) expected = append(expected, encodeVarInt(0xcafe)...) expected = append(expected, encodeVarInt(42)...) expected = append(expected, encodeVarInt(12)...) 
require.Equal(t, expected, b) require.Len(t, b, int(frame.Length(protocol.Version1))) } quic-go-0.59.0/internal/wire/retire_connection_id_frame.go000066400000000000000000000015661513066070600236350ustar00rootroot00000000000000package wire import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" ) // A RetireConnectionIDFrame is a RETIRE_CONNECTION_ID frame type RetireConnectionIDFrame struct { SequenceNumber uint64 } func parseRetireConnectionIDFrame(b []byte, _ protocol.Version) (*RetireConnectionIDFrame, int, error) { seq, l, err := quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } return &RetireConnectionIDFrame{SequenceNumber: seq}, l, nil } func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, byte(FrameTypeRetireConnectionID)) b = quicvarint.Append(b, f.SequenceNumber) return b, nil } // Length of a written frame func (f *RetireConnectionIDFrame) Length(protocol.Version) protocol.ByteCount { return 1 + protocol.ByteCount(quicvarint.Len(f.SequenceNumber)) } quic-go-0.59.0/internal/wire/retire_connection_id_frame_test.go000066400000000000000000000022331513066070600246640ustar00rootroot00000000000000package wire import ( "io" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestParseRetireConnectionID(t *testing.T) { data := encodeVarInt(0xdeadbeef) // sequence number frame, l, err := parseRetireConnectionIDFrame(data, protocol.Version1) require.NoError(t, err) require.Equal(t, uint64(0xdeadbeef), frame.SequenceNumber) require.Equal(t, len(data), l) } func TestParseRetireConnectionIDErrorsOnEOFs(t *testing.T) { data := encodeVarInt(0xdeadbeef) // sequence number _, l, err := parseRetireConnectionIDFrame(data, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), l) for i := range data { _, _, err := parseRetireConnectionIDFrame(data[:i], protocol.Version1) require.Equal(t, 
io.EOF, err) } } func TestWriteRetireConnectionID(t *testing.T) { frame := &RetireConnectionIDFrame{SequenceNumber: 0x1337} b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{byte(FrameTypeRetireConnectionID)} expected = append(expected, encodeVarInt(0x1337)...) require.Equal(t, expected, b) require.Len(t, b, int(frame.Length(protocol.Version1))) } quic-go-0.59.0/internal/wire/short_header.go000066400000000000000000000037141513066070600207420ustar00rootroot00000000000000package wire import ( "errors" "io" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" ) // ParseShortHeader parses a short header packet. // It must be called after header protection was removed. // Otherwise, the check for the reserved bits will (most likely) fail. func ParseShortHeader(data []byte, connIDLen int) (length int, _ protocol.PacketNumber, _ protocol.PacketNumberLen, _ protocol.KeyPhaseBit, _ error) { if len(data) == 0 { return 0, 0, 0, 0, io.EOF } if data[0]&0x80 > 0 { return 0, 0, 0, 0, errors.New("not a short header packet") } if data[0]&0x40 == 0 { return 0, 0, 0, 0, errors.New("not a QUIC packet") } pnLen := protocol.PacketNumberLen(data[0]&0b11) + 1 if len(data) < 1+int(pnLen)+connIDLen { return 0, 0, 0, 0, io.EOF } pos := 1 + connIDLen pn, err := readPacketNumber(data[pos:], pnLen) if err != nil { return 0, 0, 0, 0, err } kp := protocol.KeyPhaseZero if data[0]&0b100 > 0 { kp = protocol.KeyPhaseOne } if data[0]&0x18 != 0 { err = ErrInvalidReservedBits } return 1 + connIDLen + int(pnLen), pn, pnLen, kp, err } // AppendShortHeader writes a short header. func AppendShortHeader(b []byte, connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) ([]byte, error) { typeByte := 0x40 | uint8(pnLen-1) if kp == protocol.KeyPhaseOne { typeByte |= byte(1 << 2) } b = append(b, typeByte) b = append(b, connID.Bytes()...) 
return appendPacketNumber(b, pn, pnLen) } func ShortHeaderLen(dest protocol.ConnectionID, pnLen protocol.PacketNumberLen) protocol.ByteCount { return 1 + protocol.ByteCount(dest.Len()) + protocol.ByteCount(pnLen) } func LogShortHeader(logger utils.Logger, dest protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) { logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", dest, pn, pnLen, kp) } quic-go-0.59.0/internal/wire/short_header_test.go000066400000000000000000000060761513066070600220050ustar00rootroot00000000000000package wire import ( "bytes" "io" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestParseShortHeader(t *testing.T) { data := []byte{ 0b01000110, 0xde, 0xad, 0xbe, 0xef, 0x13, 0x37, 0x99, } l, pn, pnLen, kp, err := ParseShortHeader(data, 4) require.NoError(t, err) require.Equal(t, len(data), l) require.Equal(t, protocol.KeyPhaseOne, kp) require.Equal(t, protocol.PacketNumber(0x133799), pn) require.Equal(t, protocol.PacketNumberLen3, pnLen) } func TestParseShortHeaderNoQUICBit(t *testing.T) { data := []byte{ 0b00000101, 0xde, 0xad, 0xbe, 0xef, 0x13, 0x37, } _, _, _, _, err := ParseShortHeader(data, 4) require.EqualError(t, err, "not a QUIC packet") } func TestParseShortHeaderReservedBitsSet(t *testing.T) { data := []byte{ 0b01010101, 0xde, 0xad, 0xbe, 0xef, 0x13, 0x37, } _, pn, _, _, err := ParseShortHeader(data, 4) require.EqualError(t, err, ErrInvalidReservedBits.Error()) require.Equal(t, protocol.PacketNumber(0x1337), pn) } func TestParseShortHeaderErrorsWhenPassedLongHeaderPacket(t *testing.T) { _, _, _, _, err := ParseShortHeader([]byte{0x80}, 4) require.EqualError(t, err, "not a short header packet") } func TestParseShortHeaderErrorsOnEOF(t *testing.T) { data := []byte{ 0b01000110, 0xde, 0xad, 0xbe, 0xef, 0x13, 0x37, 0x99, } _, _, _, _, err := ParseShortHeader(data, 4) require.NoError(t, 
err) for i := range data { _, _, _, _, err := ParseShortHeader(data[:i], 4) require.EqualError(t, err, io.EOF.Error()) } } func TestShortHeaderLen(t *testing.T) { require.Equal(t, protocol.ByteCount(8), ShortHeaderLen(protocol.ParseConnectionID([]byte{1, 2, 3, 4}), protocol.PacketNumberLen3)) require.Equal(t, protocol.ByteCount(2), ShortHeaderLen(protocol.ParseConnectionID([]byte{}), protocol.PacketNumberLen1)) } func TestWriteShortHeaderPacket(t *testing.T) { connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) b, err := AppendShortHeader(nil, connID, 1337, 4, protocol.KeyPhaseOne) require.NoError(t, err) l, pn, pnLen, kp, err := ParseShortHeader(b, 4) require.NoError(t, err) require.Equal(t, protocol.PacketNumber(1337), pn) require.Equal(t, protocol.PacketNumberLen4, pnLen) require.Equal(t, protocol.KeyPhaseOne, kp) require.Equal(t, len(b), l) } func TestLogShortHeaderWithConnectionID(t *testing.T) { buf := &bytes.Buffer{} logger := setupLogTest(t, buf) connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}) LogShortHeader(logger, connID, 1337, protocol.PacketNumberLen4, protocol.KeyPhaseOne) require.Contains(t, buf.String(), "Short Header{DestConnectionID: deadbeefcafe1337, PacketNumber: 1337, PacketNumberLen: 4, KeyPhase: 1}") } func BenchmarkWriteShortHeader(b *testing.B) { b.ReportAllocs() buf := make([]byte, 100) connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) for b.Loop() { var err error buf, err = AppendShortHeader(buf, connID, 1337, protocol.PacketNumberLen4, protocol.KeyPhaseOne) if err != nil { b.Fatalf("failed to write short header: %s", err) } buf = buf[:0] } } quic-go-0.59.0/internal/wire/stop_sending_frame.go000066400000000000000000000024061513066070600221360ustar00rootroot00000000000000package wire import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/quicvarint" ) // A StopSendingFrame is a STOP_SENDING frame type 
StopSendingFrame struct { StreamID protocol.StreamID ErrorCode qerr.StreamErrorCode } // parseStopSendingFrame parses a STOP_SENDING frame func parseStopSendingFrame(b []byte, _ protocol.Version) (*StopSendingFrame, int, error) { startLen := len(b) streamID, l, err := quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } b = b[l:] errorCode, l, err := quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } b = b[l:] return &StopSendingFrame{ StreamID: protocol.StreamID(streamID), ErrorCode: qerr.StreamErrorCode(errorCode), }, startLen - len(b), nil } // Length of a written frame func (f *StopSendingFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamID))+quicvarint.Len(uint64(f.ErrorCode))) } func (f *StopSendingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, byte(FrameTypeStopSending)) b = quicvarint.Append(b, uint64(f.StreamID)) b = quicvarint.Append(b, uint64(f.ErrorCode)) return b, nil } quic-go-0.59.0/internal/wire/stop_sending_frame_test.go000066400000000000000000000027061513066070600232000ustar00rootroot00000000000000package wire import ( "io" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/stretchr/testify/require" ) func TestParseStopSending(t *testing.T) { data := encodeVarInt(0xdecafbad) // stream ID data = append(data, encodeVarInt(0x1337)...) // error code frame, l, err := parseStopSendingFrame(data, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamID(0xdecafbad), frame.StreamID) require.Equal(t, qerr.StreamErrorCode(0x1337), frame.ErrorCode) require.Equal(t, len(data), l) } func TestParseStopSendingErrorsOnEOFs(t *testing.T) { data := encodeVarInt(0xdecafbad) // stream ID data = append(data, encodeVarInt(0x123456)...) 
// error code _, l, err := parseStopSendingFrame(data, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), l) for i := range data { _, _, err := parseStopSendingFrame(data[:i], protocol.Version1) require.Equal(t, io.EOF, err) } } func TestWriteStopSendingFrame(t *testing.T) { frame := &StopSendingFrame{ StreamID: 0xdeadbeefcafe, ErrorCode: 0xdecafbad, } b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{byte(FrameTypeStopSending)} expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) expected = append(expected, encodeVarInt(0xdecafbad)...) require.Equal(t, expected, b) require.Len(t, b, int(frame.Length(protocol.Version1))) } quic-go-0.59.0/internal/wire/stream_data_blocked_frame.go000066400000000000000000000023231513066070600234070ustar00rootroot00000000000000package wire import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" ) // A StreamDataBlockedFrame is a STREAM_DATA_BLOCKED frame type StreamDataBlockedFrame struct { StreamID protocol.StreamID MaximumStreamData protocol.ByteCount } func parseStreamDataBlockedFrame(b []byte, _ protocol.Version) (*StreamDataBlockedFrame, int, error) { startLen := len(b) sid, l, err := quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } b = b[l:] offset, l, err := quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } return &StreamDataBlockedFrame{ StreamID: protocol.StreamID(sid), MaximumStreamData: protocol.ByteCount(offset), }, startLen - len(b) + l, nil } func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, 0x15) b = quicvarint.Append(b, uint64(f.StreamID)) b = quicvarint.Append(b, uint64(f.MaximumStreamData)) return b, nil } // Length of a written frame func (f *StreamDataBlockedFrame) Length(protocol.Version) protocol.ByteCount { return 1 + 
protocol.ByteCount(quicvarint.Len(uint64(f.StreamID))+quicvarint.Len(uint64(f.MaximumStreamData))) } quic-go-0.59.0/internal/wire/stream_data_blocked_frame_test.go000066400000000000000000000027031513066070600244500ustar00rootroot00000000000000package wire import ( "io" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestParseStreamDataBlocked(t *testing.T) { data := encodeVarInt(0xdeadbeef) // stream ID data = append(data, encodeVarInt(0xdecafbad)...) // offset frame, l, err := parseStreamDataBlockedFrame(data, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamID(0xdeadbeef), frame.StreamID) require.Equal(t, protocol.ByteCount(0xdecafbad), frame.MaximumStreamData) require.Equal(t, len(data), l) } func TestParseStreamDataBlockedErrorsOnEOFs(t *testing.T) { data := encodeVarInt(0xdeadbeef) data = append(data, encodeVarInt(0xc0010ff)...) _, l, err := parseStreamDataBlockedFrame(data, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), l) for i := range data { _, _, err := parseStreamDataBlockedFrame(data[:i], protocol.Version1) require.Equal(t, io.EOF, err) } } func TestWriteStreamDataBlocked(t *testing.T) { f := &StreamDataBlockedFrame{ StreamID: 0xdecafbad, MaximumStreamData: 0x1337, } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{byte(FrameTypeStreamDataBlocked)} expected = append(expected, encodeVarInt(uint64(f.StreamID))...) expected = append(expected, encodeVarInt(uint64(f.MaximumStreamData))...) 
require.Equal(t, expected, b) require.Equal(t, int(f.Length(protocol.Version1)), len(b)) } quic-go-0.59.0/internal/wire/stream_frame.go000066400000000000000000000114441513066070600207370ustar00rootroot00000000000000package wire import ( "errors" "io" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" ) // A StreamFrame of QUIC type StreamFrame struct { StreamID protocol.StreamID Offset protocol.ByteCount Data []byte Fin bool DataLenPresent bool fromPool bool } func ParseStreamFrame(b []byte, typ FrameType, _ protocol.Version) (*StreamFrame, int, error) { startLen := len(b) hasOffset := typ&0b100 > 0 fin := typ&0b1 > 0 hasDataLen := typ&0b10 > 0 streamID, l, err := quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } b = b[l:] var offset uint64 if hasOffset { offset, l, err = quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } b = b[l:] } var dataLen uint64 if hasDataLen { var err error var l int dataLen, l, err = quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } b = b[l:] if dataLen > uint64(len(b)) { return nil, 0, io.EOF } } else { // The rest of the packet is data dataLen = uint64(len(b)) } var frame *StreamFrame if dataLen < protocol.MinStreamFrameBufferSize { frame = &StreamFrame{} if dataLen > 0 { frame.Data = make([]byte, dataLen) } } else { frame = GetStreamFrame() // The STREAM frame can't be larger than the StreamFrame we obtained from the buffer, // since those StreamFrames have a buffer length of the maximum packet size. 
if dataLen > uint64(cap(frame.Data)) { return nil, 0, io.EOF } frame.Data = frame.Data[:dataLen] } frame.StreamID = protocol.StreamID(streamID) frame.Offset = protocol.ByteCount(offset) frame.Fin = fin frame.DataLenPresent = hasDataLen if dataLen > 0 { copy(frame.Data, b) } if frame.Offset+frame.DataLen() > protocol.MaxByteCount { return nil, 0, errors.New("stream data overflows maximum offset") } return frame, startLen - len(b) + int(dataLen), nil } func (f *StreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { if len(f.Data) == 0 && !f.Fin { return nil, errors.New("StreamFrame: attempting to write empty frame without FIN") } typ := byte(0x8) if f.Fin { typ ^= 0b1 } hasOffset := f.Offset != 0 if f.DataLenPresent { typ ^= 0b10 } if hasOffset { typ ^= 0b100 } b = append(b, typ) b = quicvarint.Append(b, uint64(f.StreamID)) if hasOffset { b = quicvarint.Append(b, uint64(f.Offset)) } if f.DataLenPresent { b = quicvarint.Append(b, uint64(f.DataLen())) } b = append(b, f.Data...) return b, nil } // Length returns the total length of the STREAM frame func (f *StreamFrame) Length(protocol.Version) protocol.ByteCount { length := 1 + quicvarint.Len(uint64(f.StreamID)) if f.Offset != 0 { length += quicvarint.Len(uint64(f.Offset)) } if f.DataLenPresent { length += quicvarint.Len(uint64(f.DataLen())) } return protocol.ByteCount(length) + f.DataLen() } // DataLen gives the length of data in bytes func (f *StreamFrame) DataLen() protocol.ByteCount { return protocol.ByteCount(len(f.Data)) } // MaxDataLen returns the maximum data length // If 0 is returned, writing will fail (a STREAM frame must contain at least 1 byte of data). func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, _ protocol.Version) protocol.ByteCount { headerLen := 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamID))) if f.Offset != 0 { headerLen += protocol.ByteCount(quicvarint.Len(uint64(f.Offset))) } if f.DataLenPresent { // Pretend that the data size will be 1 byte. 
// If it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterward headerLen++ } if headerLen > maxSize { return 0 } maxDataLen := maxSize - headerLen if f.DataLenPresent && quicvarint.Len(uint64(maxDataLen)) != 1 { maxDataLen-- } return maxDataLen } // MaybeSplitOffFrame splits a frame such that it is not bigger than n bytes. // It returns if the frame was actually split. // The frame might not be split if: // * the size is large enough to fit the whole frame // * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil. func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.Version) (*StreamFrame, bool /* was splitting required */) { if maxSize >= f.Length(version) { return nil, false } n := f.MaxDataLen(maxSize, version) if n == 0 { return nil, true } new := GetStreamFrame() new.StreamID = f.StreamID new.Offset = f.Offset new.Fin = false new.DataLenPresent = f.DataLenPresent // swap the data slices new.Data, f.Data = f.Data, new.Data new.fromPool, f.fromPool = f.fromPool, new.fromPool f.Data = f.Data[:protocol.ByteCount(len(new.Data))-n] copy(f.Data, new.Data[n:]) new.Data = new.Data[:n] f.Offset += n return new, true } func (f *StreamFrame) PutBack() { putStreamFrame(f) } quic-go-0.59.0/internal/wire/stream_frame_test.go000066400000000000000000000317401513066070600217770ustar00rootroot00000000000000package wire import ( "bytes" "io" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestParseStreamFrameWithOffBit(t *testing.T) { data := encodeVarInt(0x12345) // stream ID data = append(data, encodeVarInt(0xdecafbad)...) // offset data = append(data, []byte("foobar")...) 
frame, l, err := ParseStreamFrame(data, 0x8^0x4, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamID(0x12345), frame.StreamID) require.Equal(t, []byte("foobar"), frame.Data) require.False(t, frame.Fin) require.Equal(t, protocol.ByteCount(0xdecafbad), frame.Offset) require.Equal(t, len(data), l) } func TestParseStreamFrameRespectsLEN(t *testing.T) { data := encodeVarInt(0x12345) // stream ID data = append(data, encodeVarInt(4)...) // data length data = append(data, []byte("foobar")...) frame, l, err := ParseStreamFrame(data, 0x8^0x2, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamID(0x12345), frame.StreamID) require.Equal(t, []byte("foob"), frame.Data) require.False(t, frame.Fin) require.Zero(t, frame.Offset) require.Equal(t, len(data)-2, l) } func TestParseStreamFrameWithFINBit(t *testing.T) { data := encodeVarInt(9) // stream ID data = append(data, []byte("foobar")...) frame, l, err := ParseStreamFrame(data, 0x8^0x1, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamID(9), frame.StreamID) require.Equal(t, []byte("foobar"), frame.Data) require.True(t, frame.Fin) require.Zero(t, frame.Offset) require.Equal(t, len(data), l) } func TestParseStreamFrameAllowsEmpty(t *testing.T) { data := encodeVarInt(0x1337) // stream ID data = append(data, encodeVarInt(0x12345)...) // offset f, l, err := ParseStreamFrame(data, 0x8^0x4, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamID(0x1337), f.StreamID) require.Equal(t, protocol.ByteCount(0x12345), f.Offset) require.Nil(t, f.Data) require.False(t, f.Fin) require.Equal(t, len(data), l) } func TestParseStreamFrameRejectsOverflow(t *testing.T) { data := encodeVarInt(0x12345) // stream ID data = append(data, encodeVarInt(uint64(protocol.MaxByteCount-5))...) // offset data = append(data, []byte("foobar")...) 
_, _, err := ParseStreamFrame(data, 0x8^0x4, protocol.Version1) require.EqualError(t, err, "stream data overflows maximum offset") } func TestParseStreamFrameRejectsLongFrames(t *testing.T) { data := encodeVarInt(0x12345) // stream ID data = append(data, encodeVarInt(uint64(protocol.MaxPacketBufferSize)+1)...) // data length data = append(data, make([]byte, protocol.MaxPacketBufferSize+1)...) _, _, err := ParseStreamFrame(data, 0x8^0x2, protocol.Version1) require.Equal(t, io.EOF, err) } func TestParseStreamFrameRejectsFramesExceedingRemainingSize(t *testing.T) { data := encodeVarInt(0x12345) // stream ID data = append(data, encodeVarInt(7)...) // data length data = append(data, []byte("foobar")...) _, _, err := ParseStreamFrame(data, 0x8^0x2, protocol.Version1) require.Equal(t, io.EOF, err) } func TestParseStreamFrameErrorsOnEOFs(t *testing.T) { typ := uint64(0x8 ^ 0x4 ^ 0x2) data := encodeVarInt(0x12345) // stream ID data = append(data, encodeVarInt(0xdecafbad)...) // offset data = append(data, encodeVarInt(6)...) // data length data = append(data, []byte("foobar")...) _, _, err := ParseStreamFrame(data, FrameType(typ), protocol.Version1) require.NoError(t, err) for i := range data { _, _, err = ParseStreamFrame(data[:i], FrameType(typ), protocol.Version1) require.Error(t, err) } } func TestParseStreamUsesBufferForLongFrames(t *testing.T) { data := encodeVarInt(0x12345) // stream ID data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize)...) 
frame, l, err := ParseStreamFrame(data, 0x8, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamID(0x12345), frame.StreamID) require.Equal(t, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize), frame.Data) require.Equal(t, protocol.ByteCount(protocol.MinStreamFrameBufferSize), frame.DataLen()) require.False(t, frame.Fin) require.True(t, frame.fromPool) require.Equal(t, len(data), l) require.NotPanics(t, frame.PutBack) } func TestParseStreamDoesNotUseBufferForShortFrames(t *testing.T) { data := encodeVarInt(0x12345) // stream ID data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1)...) frame, l, err := ParseStreamFrame(data, 0x8, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamID(0x12345), frame.StreamID) require.Equal(t, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1), frame.Data) require.Equal(t, protocol.ByteCount(protocol.MinStreamFrameBufferSize-1), frame.DataLen()) require.False(t, frame.Fin) require.False(t, frame.fromPool) require.Equal(t, len(data), l) require.NotPanics(t, frame.PutBack) } func TestWriteStreamFrameWithoutOffset(t *testing.T) { f := &StreamFrame{ StreamID: 0x1337, Data: []byte("foobar"), } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{0x8} expected = append(expected, encodeVarInt(0x1337)...) // stream ID expected = append(expected, []byte("foobar")...) require.Equal(t, expected, b) require.Equal(t, int(f.Length(protocol.Version1)), len(b)) } func TestWriteStreamFrameWithOffset(t *testing.T) { f := &StreamFrame{ StreamID: 0x1337, Offset: 0x123456, Data: []byte("foobar"), } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{0x8 ^ 0x4} expected = append(expected, encodeVarInt(0x1337)...) // stream ID expected = append(expected, encodeVarInt(0x123456)...) // offset expected = append(expected, []byte("foobar")...) 
require.Equal(t, expected, b) require.Equal(t, int(f.Length(protocol.Version1)), len(b)) } func TestWriteStreamFrameWithFIN(t *testing.T) { f := &StreamFrame{ StreamID: 0x1337, Offset: 0x123456, Fin: true, } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{0x8 ^ 0x4 ^ 0x1} expected = append(expected, encodeVarInt(0x1337)...) // stream ID expected = append(expected, encodeVarInt(0x123456)...) // offset require.Equal(t, expected, b) require.Equal(t, int(f.Length(protocol.Version1)), len(b)) } func TestWriteStreamFrameWithDataLength(t *testing.T) { f := &StreamFrame{ StreamID: 0x1337, Data: []byte("foobar"), DataLenPresent: true, } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{0x8 ^ 0x2} expected = append(expected, encodeVarInt(0x1337)...) // stream ID expected = append(expected, encodeVarInt(6)...) // data length expected = append(expected, []byte("foobar")...) require.Equal(t, expected, b) require.Equal(t, int(f.Length(protocol.Version1)), len(b)) } func TestWriteStreamFrameWithDataLengthAndOffset(t *testing.T) { f := &StreamFrame{ StreamID: 0x1337, Data: []byte("foobar"), DataLenPresent: true, Offset: 0x123456, } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{0x8 ^ 0x4 ^ 0x2} expected = append(expected, encodeVarInt(0x1337)...) // stream ID expected = append(expected, encodeVarInt(0x123456)...) // offset expected = append(expected, encodeVarInt(6)...) // data length expected = append(expected, []byte("foobar")...) 
require.Equal(t, expected, b) require.Equal(t, int(f.Length(protocol.Version1)), len(b)) } func TestWriteStreamFrameEmptyFrameWithoutFIN(t *testing.T) { f := &StreamFrame{ StreamID: 0x42, Offset: 0x1337, } _, err := f.Append(nil, protocol.Version1) require.EqualError(t, err, "StreamFrame: attempting to write empty frame without FIN") } func TestStreamMaxDataLength(t *testing.T) { const maxSize = 3000 data := make([]byte, maxSize) f := &StreamFrame{ StreamID: 0x1337, Offset: 0xdeadbeef, } for i := 1; i < 3000; i++ { f.Data = nil maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size f.Data = []byte{0} b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) require.Greater(t, len(b), i) continue } f.Data = data[:int(maxDataLen)] b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) require.Equal(t, i, len(b)) } } func TestStreamMaxDataLengthWithDataLenPresent(t *testing.T) { const maxSize = 3000 data := make([]byte, maxSize) f := &StreamFrame{ StreamID: 0x1337, Offset: 0xdeadbeef, DataLenPresent: true, } var frameOneByteTooSmallCounter int for i := 1; i < 3000; i++ { f.Data = nil maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written // check that writing a minimal size STREAM frame (i.e. 
with 1 byte data) is actually larger than the desired size f.Data = []byte{0} b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) require.Greater(t, len(b), i) continue } f.Data = data[:int(maxDataLen)] b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) // There's *one* pathological case, where a data length of x can be encoded into 1 byte // but a data lengths of x+1 needs 2 bytes // In that case, it's impossible to create a STREAM frame of the desired size if len(b) == i-1 { frameOneByteTooSmallCounter++ continue } require.Equal(t, i, len(b)) } require.Equal(t, 1, frameOneByteTooSmallCounter) } func TestStreamSplitting(t *testing.T) { f := &StreamFrame{ StreamID: 0x1337, DataLenPresent: true, Offset: 0x100, Data: []byte("foobar"), } frame, needsSplit := f.MaybeSplitOffFrame(f.Length(protocol.Version1)-3, protocol.Version1) require.True(t, needsSplit) require.NotNil(t, frame) require.True(t, f.DataLenPresent) require.True(t, frame.DataLenPresent) require.Equal(t, protocol.ByteCount(0x100), frame.Offset) require.Equal(t, []byte("foo"), frame.Data) require.Equal(t, protocol.ByteCount(0x100+3), f.Offset) require.Equal(t, []byte("bar"), f.Data) } func TestStreamSplittingNoSplitForShortFrame(t *testing.T) { f := &StreamFrame{ StreamID: 0x1337, DataLenPresent: true, Offset: 0xdeadbeef, Data: make([]byte, 100), } frame, needsSplit := f.MaybeSplitOffFrame(f.Length(protocol.Version1), protocol.Version1) require.False(t, needsSplit) require.Nil(t, frame) require.Equal(t, protocol.ByteCount(100), f.DataLen()) frame, needsSplit = f.MaybeSplitOffFrame(f.Length(protocol.Version1)-1, protocol.Version1) require.True(t, needsSplit) require.Equal(t, protocol.ByteCount(99), frame.DataLen()) f.PutBack() } func TestStreamSplittingPreservesFINBit(t *testing.T) { f := &StreamFrame{ StreamID: 0x1337, Fin: true, Offset: 0xdeadbeef, Data: make([]byte, 100), } frame, needsSplit := f.MaybeSplitOffFrame(50, protocol.Version1) require.True(t, needsSplit) 
require.NotNil(t, frame) require.Less(t, frame.Offset, f.Offset) require.True(t, f.Fin) require.False(t, frame.Fin) } func TestStreamSplittingProducesCorrectLengthFramesWithoutDataLen(t *testing.T) { const size = 1000 f := &StreamFrame{ StreamID: 0xdecafbad, Offset: 0x1234, Data: []byte{0}, } minFrameSize := f.Length(protocol.Version1) for i := protocol.ByteCount(0); i < minFrameSize; i++ { frame, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) require.True(t, needsSplit) require.Nil(t, frame) } for i := minFrameSize; i < size; i++ { f.fromPool = false f.Data = make([]byte, size) frame, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) require.True(t, needsSplit) require.Equal(t, i, frame.Length(protocol.Version1)) } } func TestStreamSplittingProducesCorrectLengthFramesWithDataLen(t *testing.T) { const size = 1000 f := &StreamFrame{ StreamID: 0xdecafbad, Offset: 0x1234, DataLenPresent: true, Data: []byte{0}, } minFrameSize := f.Length(protocol.Version1) for i := protocol.ByteCount(0); i < minFrameSize; i++ { frame, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) require.True(t, needsSplit) require.Nil(t, frame) } var frameOneByteTooSmallCounter int for i := minFrameSize; i < size; i++ { f.fromPool = false f.Data = make([]byte, size) newFrame, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) require.True(t, needsSplit) // There's *one* pathological case, where a data length of x can be encoded into 1 byte // but a data lengths of x+1 needs 2 bytes // In that case, it's impossible to create a STREAM frame of the desired size if newFrame.Length(protocol.Version1) == i-1 { frameOneByteTooSmallCounter++ continue } require.Equal(t, i, newFrame.Length(protocol.Version1)) } require.Equal(t, 1, frameOneByteTooSmallCounter) } quic-go-0.59.0/internal/wire/streams_blocked_frame.go000066400000000000000000000027411513066070600226050ustar00rootroot00000000000000package wire import ( "fmt" "github.com/quic-go/quic-go/internal/protocol" 
"github.com/quic-go/quic-go/quicvarint" ) // A StreamsBlockedFrame is a STREAMS_BLOCKED frame type StreamsBlockedFrame struct { Type protocol.StreamType StreamLimit protocol.StreamNum } func parseStreamsBlockedFrame(b []byte, typ FrameType, _ protocol.Version) (*StreamsBlockedFrame, int, error) { f := &StreamsBlockedFrame{} //nolint:exhaustive // This will only be called with a BidiStreamBlockedFrameType or a UniStreamBlockedFrameType. switch typ { case FrameTypeBidiStreamBlocked: f.Type = protocol.StreamTypeBidi case FrameTypeUniStreamBlocked: f.Type = protocol.StreamTypeUni } streamLimit, l, err := quicvarint.Parse(b) if err != nil { return nil, 0, replaceUnexpectedEOF(err) } f.StreamLimit = protocol.StreamNum(streamLimit) if f.StreamLimit > protocol.MaxStreamCount { return nil, 0, fmt.Errorf("%d exceeds the maximum stream count", f.StreamLimit) } return f, l, nil } func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { switch f.Type { case protocol.StreamTypeBidi: b = append(b, byte(FrameTypeBidiStreamBlocked)) case protocol.StreamTypeUni: b = append(b, byte(FrameTypeUniStreamBlocked)) } b = quicvarint.Append(b, uint64(f.StreamLimit)) return b, nil } // Length of a written frame func (f *StreamsBlockedFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamLimit))) } quic-go-0.59.0/internal/wire/streams_blocked_frame_test.go000066400000000000000000000073221513066070600236440ustar00rootroot00000000000000package wire import ( "fmt" "io" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" "github.com/stretchr/testify/require" ) func TestParseStreamsBlockedFrameBidirectional(t *testing.T) { data := encodeVarInt(0x1337) f, l, err := parseStreamsBlockedFrame(data, FrameTypeBidiStreamBlocked, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamTypeBidi, f.Type) require.EqualValues(t, 0x1337, f.StreamLimit) 
require.Equal(t, len(data), l) } func TestParseStreamsBlockedFrameUnidirectional(t *testing.T) { data := encodeVarInt(0x7331) f, l, err := parseStreamsBlockedFrame(data, FrameTypeUniStreamBlocked, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamTypeUni, f.Type) require.EqualValues(t, 0x7331, f.StreamLimit) require.Equal(t, len(data), l) } func TestParseStreamsBlockedFrameErrorsOnEOFs(t *testing.T) { data := encodeVarInt(0x12345678) _, l, err := parseStreamsBlockedFrame(data, FrameTypeBidiStreamBlocked, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), l) for i := range data { _, _, err := parseStreamsBlockedFrame(data[:i], FrameTypeBidiStreamBlocked, protocol.Version1) require.Equal(t, io.EOF, err) } } func TestParseStreamsBlockedFrameMaxStreamCount(t *testing.T) { for _, streamType := range []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi} { var streamTypeStr string if streamType == protocol.StreamTypeUni { streamTypeStr = "unidirectional" } else { streamTypeStr = "bidirectional" } t.Run(streamTypeStr, func(t *testing.T) { f := &StreamsBlockedFrame{ Type: streamType, StreamLimit: protocol.MaxStreamCount, } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) typ, l, err := quicvarint.Parse(b) require.NoError(t, err) b = b[l:] frame, l, err := parseStreamsBlockedFrame(b, FrameType(typ), protocol.Version1) require.NoError(t, err) require.Equal(t, f, frame) require.Equal(t, len(b), l) }) } } func TestParseStreamsBlockedFrameErrorOnTooLargeStreamCount(t *testing.T) { for _, streamType := range []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi} { var streamTypeStr string if streamType == protocol.StreamTypeUni { streamTypeStr = "unidirectional" } else { streamTypeStr = "bidirectional" } t.Run(streamTypeStr, func(t *testing.T) { f := &StreamsBlockedFrame{ Type: streamType, StreamLimit: protocol.MaxStreamCount + 1, } b, err := f.Append(nil, protocol.Version1) 
require.NoError(t, err) typ, l, err := quicvarint.Parse(b) require.NoError(t, err) b = b[l:] _, _, err = parseStreamsBlockedFrame(b, FrameType(typ), protocol.Version1) require.EqualError(t, err, fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1)) }) } } func TestWriteStreamsBlockedFrameBidirectional(t *testing.T) { f := StreamsBlockedFrame{ Type: protocol.StreamTypeBidi, StreamLimit: 0xdeadbeefcafe, } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{byte(FrameTypeBidiStreamBlocked)} expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) require.Equal(t, expected, b) require.Equal(t, int(f.Length(protocol.Version1)), len(b)) } func TestWriteStreamsBlockedFrameUnidirectional(t *testing.T) { f := StreamsBlockedFrame{ Type: protocol.StreamTypeUni, StreamLimit: 0xdeadbeefcafe, } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) expected := []byte{byte(FrameTypeUniStreamBlocked)} expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) require.Equal(t, expected, b) require.Equal(t, int(f.Length(protocol.Version1)), len(b)) } quic-go-0.59.0/internal/wire/test_helpers_test.go000066400000000000000000000013411513066070600220250ustar00rootroot00000000000000package wire import ( "bytes" "encoding/binary" "log" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/quicvarint" ) func encodeVarInt(i uint64) []byte { return quicvarint.Append(nil, i) } func appendVersion(data []byte, v protocol.Version) []byte { offset := len(data) data = append(data, []byte{0, 0, 0, 0}...) 
binary.BigEndian.PutUint32(data[offset:], uint32(v)) return data } func setupLogTest(t *testing.T, buf *bytes.Buffer) utils.Logger { logger := utils.DefaultLogger logger.SetLogLevel(utils.LogLevelDebug) originalOutput := log.Writer() log.SetOutput(buf) t.Cleanup(func() { log.SetOutput(originalOutput) }) return logger } quic-go-0.59.0/internal/wire/transport_parameter_test.go000066400000000000000000001101611513066070600234210ustar00rootroot00000000000000package wire import ( "bytes" "crypto/rand" "fmt" "math" mrand "math/rand/v2" "net/netip" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/quicvarint" "github.com/stretchr/testify/require" ) func getRandomValueUpTo(max uint64) uint64 { maxVals := []uint64{math.MaxUint8 / 4, math.MaxUint16 / 4, math.MaxUint32 / 4, math.MaxUint64 / 4} return mrand.Uint64N(min(max, maxVals[mrand.IntN(4)])) } func getRandomValue() uint64 { return getRandomValueUpTo(quicvarint.Max) } func appendInitialSourceConnectionID(b []byte) []byte { b = quicvarint.Append(b, uint64(initialSourceConnectionIDParameterID)) b = quicvarint.Append(b, 6) return append(b, []byte("foobar")...) 
} func TestTransportParametersStringRepresentation(t *testing.T) { rcid := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}) minAckDelay := 42 * time.Millisecond p := &TransportParameters{ InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, InitialMaxStreamDataUni: 3456, InitialMaxData: 4567, MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42 * time.Second, OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), InitialSourceConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), RetrySourceConnectionID: &rcid, AckDelayExponent: 14, MaxAckDelay: 37 * time.Millisecond, StatelessResetToken: &protocol.StatelessResetToken{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00}, ActiveConnectionIDLimit: 123, MaxDatagramFrameSize: 876, EnableResetStreamAt: true, MinAckDelay: &minAckDelay, } expected := "&wire.TransportParameters{OriginalDestinationConnectionID: deadbeef, InitialSourceConnectionID: decafbad, RetrySourceConnectionID: deadc0de, InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, InitialMaxStreamDataUni: 3456, InitialMaxData: 4567, MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42s, AckDelayExponent: 14, MaxAckDelay: 37ms, ActiveConnectionIDLimit: 123, StatelessResetToken: 0x112233445566778899aabbccddeeff00, MaxDatagramFrameSize: 876, EnableResetStreamAt: true, MinAckDelay: 42ms}" require.Equal(t, expected, p.String()) } func TestTransportParametersStringRepresentationWithoutOptionalFields(t *testing.T) { p := &TransportParameters{ InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, InitialMaxStreamDataUni: 3456, InitialMaxData: 4567, MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42 * time.Second, OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), InitialSourceConnectionID: 
protocol.ParseConnectionID([]byte{}), AckDelayExponent: 14, MaxAckDelay: 37 * time.Second, ActiveConnectionIDLimit: 89, MaxDatagramFrameSize: protocol.InvalidByteCount, } expected := "&wire.TransportParameters{OriginalDestinationConnectionID: deadbeef, InitialSourceConnectionID: (empty), InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, InitialMaxStreamDataUni: 3456, InitialMaxData: 4567, MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42s, AckDelayExponent: 14, MaxAckDelay: 37s, ActiveConnectionIDLimit: 89, EnableResetStreamAt: false}" require.Equal(t, expected, p.String()) } func TestMarshalAndUnmarshalTransportParameters(t *testing.T) { var token protocol.StatelessResetToken rand.Read(token[:]) rcid := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}) minAckDelay := 42 * time.Millisecond params := &TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.ByteCount(getRandomValue()), InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()), InitialMaxStreamDataUni: protocol.ByteCount(getRandomValue()), InitialMaxData: protocol.ByteCount(getRandomValue()), MaxIdleTimeout: 0xcafe * time.Second, MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(uint64(protocol.MaxStreamCount))), MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(uint64(protocol.MaxStreamCount))), DisableActiveMigration: true, StatelessResetToken: &token, OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), InitialSourceConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), RetrySourceConnectionID: &rcid, AckDelayExponent: 13, MaxAckDelay: 42 * time.Millisecond, ActiveConnectionIDLimit: 2 + getRandomValueUpTo(quicvarint.Max-2), MaxUDPPayloadSize: 1200 + protocol.ByteCount(getRandomValueUpTo(quicvarint.Max-1200)), MaxDatagramFrameSize: protocol.ByteCount(getRandomValue()), EnableResetStreamAt: getRandomValue()%2 == 0, MinAckDelay: &minAckDelay, } data := 
params.Marshal(protocol.PerspectiveServer) p := &TransportParameters{} require.NoError(t, p.Unmarshal(data, protocol.PerspectiveServer)) require.Equal(t, params.InitialMaxStreamDataBidiLocal, p.InitialMaxStreamDataBidiLocal) require.Equal(t, params.InitialMaxStreamDataBidiRemote, p.InitialMaxStreamDataBidiRemote) require.Equal(t, params.InitialMaxStreamDataUni, p.InitialMaxStreamDataUni) require.Equal(t, params.InitialMaxData, p.InitialMaxData) require.Equal(t, params.MaxUniStreamNum, p.MaxUniStreamNum) require.Equal(t, params.MaxBidiStreamNum, p.MaxBidiStreamNum) require.Equal(t, params.MaxIdleTimeout, p.MaxIdleTimeout) require.Equal(t, params.DisableActiveMigration, p.DisableActiveMigration) require.Equal(t, params.StatelessResetToken, p.StatelessResetToken) require.Equal(t, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), p.OriginalDestinationConnectionID) require.Equal(t, protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), p.InitialSourceConnectionID) require.Equal(t, &rcid, p.RetrySourceConnectionID) require.Equal(t, uint8(13), p.AckDelayExponent) require.Equal(t, 42*time.Millisecond, p.MaxAckDelay) require.Equal(t, params.ActiveConnectionIDLimit, p.ActiveConnectionIDLimit) require.Equal(t, params.MaxUDPPayloadSize, p.MaxUDPPayloadSize) require.Equal(t, params.MaxDatagramFrameSize, p.MaxDatagramFrameSize) require.Equal(t, params.EnableResetStreamAt, p.EnableResetStreamAt) require.NotNil(t, p.MinAckDelay) require.Equal(t, minAckDelay, *p.MinAckDelay) } func TestMarshalAdditionalTransportParameters(t *testing.T) { origAdditionalTransportParametersClient := AdditionalTransportParametersClient t.Cleanup(func() { AdditionalTransportParametersClient = origAdditionalTransportParametersClient }) AdditionalTransportParametersClient = map[uint64][]byte{1337: []byte("foobar")} result := quicvarint.Append([]byte{}, 1337) result = quicvarint.Append(result, 6) result = append(result, []byte("foobar")...) 
params := &TransportParameters{} require.True(t, bytes.Contains(params.Marshal(protocol.PerspectiveClient), result)) require.False(t, bytes.Contains(params.Marshal(protocol.PerspectiveServer), result)) } func TestMarshalRetrySourceConnectionID(t *testing.T) { // no retry source connection ID data := (&TransportParameters{ StatelessResetToken: &protocol.StatelessResetToken{}, ActiveConnectionIDLimit: 2, }).Marshal(protocol.PerspectiveServer) var p TransportParameters require.NoError(t, p.Unmarshal(data, protocol.PerspectiveServer)) require.Nil(t, p.RetrySourceConnectionID) // zero-length retry source connection ID rcid := protocol.ParseConnectionID([]byte{}) data = (&TransportParameters{ RetrySourceConnectionID: &rcid, StatelessResetToken: &protocol.StatelessResetToken{}, ActiveConnectionIDLimit: 2, }).Marshal(protocol.PerspectiveServer) p = TransportParameters{} require.NoError(t, p.Unmarshal(data, protocol.PerspectiveServer)) require.NotNil(t, p.RetrySourceConnectionID) require.Zero(t, p.RetrySourceConnectionID.Len()) } func TestTransportParameterNoMaxAckDelayIfDefault(t *testing.T) { const num = 1000 var defaultLen, dataLen int maxAckDelay := protocol.DefaultMaxAckDelay + time.Millisecond for range num { dataDefault := (&TransportParameters{ MaxAckDelay: protocol.DefaultMaxAckDelay, StatelessResetToken: &protocol.StatelessResetToken{}, }).Marshal(protocol.PerspectiveServer) defaultLen += len(dataDefault) data := (&TransportParameters{ MaxAckDelay: maxAckDelay, StatelessResetToken: &protocol.StatelessResetToken{}, }).Marshal(protocol.PerspectiveServer) dataLen += len(data) } entryLen := quicvarint.Len(uint64(ackDelayExponentParameterID)) + quicvarint.Len(uint64(quicvarint.Len(uint64(maxAckDelay.Milliseconds())))) + quicvarint.Len(uint64(maxAckDelay.Milliseconds())) require.InDelta(t, float32(defaultLen)/num+float32(entryLen), float32(dataLen)/num, 1) } func TestTransportParameterNoAckDelayExponentIfDefault(t *testing.T) { const num = 1000 var defaultLen, dataLen 
int for range num { dataDefault := (&TransportParameters{ AckDelayExponent: protocol.DefaultAckDelayExponent, StatelessResetToken: &protocol.StatelessResetToken{}, }).Marshal(protocol.PerspectiveServer) defaultLen += len(dataDefault) data := (&TransportParameters{ AckDelayExponent: protocol.DefaultAckDelayExponent + 1, StatelessResetToken: &protocol.StatelessResetToken{}, }).Marshal(protocol.PerspectiveServer) dataLen += len(data) } entryLen := quicvarint.Len(uint64(ackDelayExponentParameterID)) + quicvarint.Len(uint64(quicvarint.Len(protocol.DefaultAckDelayExponent+1))) + quicvarint.Len(protocol.DefaultAckDelayExponent+1) require.InDelta(t, float32(defaultLen)/num+float32(entryLen), float32(dataLen)/num, 1) } func TestTransportParameterSetsDefaultValuesWhenNotSent(t *testing.T) { data := (&TransportParameters{ AckDelayExponent: protocol.DefaultAckDelayExponent, StatelessResetToken: &protocol.StatelessResetToken{}, ActiveConnectionIDLimit: protocol.DefaultActiveConnectionIDLimit, }).Marshal(protocol.PerspectiveServer) p := &TransportParameters{} require.NoError(t, p.Unmarshal(data, protocol.PerspectiveServer)) require.EqualValues(t, protocol.DefaultAckDelayExponent, p.AckDelayExponent) require.EqualValues(t, protocol.DefaultActiveConnectionIDLimit, p.ActiveConnectionIDLimit) } func TestTransportParameterErrors(t *testing.T) { tests := []struct { name string params *TransportParameters perspective protocol.Perspective data []byte expectedErrMsg string }{ { name: "invalid stateless reset token length", data: func() []byte { b := quicvarint.Append(nil, uint64(statelessResetTokenParameterID)) b = quicvarint.Append(b, 15) return append(b, make([]byte, 15)...) 
}(), perspective: protocol.PerspectiveServer, expectedErrMsg: "wrong length for stateless_reset_token: 15 (expected 16)", }, { name: "small max UDP payload size", data: func() []byte { b := quicvarint.Append(nil, uint64(maxUDPPayloadSizeParameterID)) b = quicvarint.Append(b, uint64(quicvarint.Len(1199))) return quicvarint.Append(b, 1199) }(), perspective: protocol.PerspectiveServer, expectedErrMsg: "invalid value for max_udp_payload_size: 1199 (minimum 1200)", }, { name: "active connection ID limit too small", params: &TransportParameters{ ActiveConnectionIDLimit: 1, StatelessResetToken: &protocol.StatelessResetToken{}, }, perspective: protocol.PerspectiveServer, expectedErrMsg: "invalid value for active_connection_id_limit: 1 (minimum 2)", }, { name: "ack delay exponent too large", params: &TransportParameters{ AckDelayExponent: 21, StatelessResetToken: &protocol.StatelessResetToken{}, }, perspective: protocol.PerspectiveServer, expectedErrMsg: "invalid value for ack_delay_exponent: 21 (maximum 20)", }, { name: "disable active migration has content", data: func() []byte { b := quicvarint.Append(nil, uint64(disableActiveMigrationParameterID)) b = quicvarint.Append(b, 6) return append(b, []byte("foobar")...) }(), perspective: protocol.PerspectiveServer, expectedErrMsg: "wrong length for disable_active_migration: 6 (expected empty)", }, { name: "server doesn't set original destination connection ID", data: func() []byte { b := quicvarint.Append(nil, uint64(statelessResetTokenParameterID)) b = quicvarint.Append(b, 16) b = append(b, make([]byte, 16)...) 
return appendInitialSourceConnectionID(b) }(), perspective: protocol.PerspectiveServer, expectedErrMsg: "missing original_destination_connection_id", }, { name: "initial source connection ID is missing", data: []byte{}, perspective: protocol.PerspectiveClient, expectedErrMsg: "missing initial_source_connection_id", }, { name: "max ack delay is too large", params: &TransportParameters{ MaxAckDelay: 1 << 14 * time.Millisecond, StatelessResetToken: &protocol.StatelessResetToken{}, }, perspective: protocol.PerspectiveServer, expectedErrMsg: "invalid value for max_ack_delay: 16384ms (maximum 16383ms)", }, { name: "varint value has wrong length", data: func() []byte { b := quicvarint.Append(nil, uint64(initialMaxStreamDataBidiLocalParameterID)) b = quicvarint.Append(b, 2) val := uint64(0xdeadbeef) b = quicvarint.Append(b, val) return appendInitialSourceConnectionID(b) }(), perspective: protocol.PerspectiveServer, expectedErrMsg: fmt.Sprintf("inconsistent transport parameter length for transport parameter %#x", initialMaxStreamDataBidiLocalParameterID), }, { name: "initial max streams bidi is too large", data: func() []byte { b := quicvarint.Append(nil, uint64(initialMaxStreamsBidiParameterID)) b = quicvarint.Append(b, uint64(quicvarint.Len(uint64(protocol.MaxStreamCount+1)))) b = quicvarint.Append(b, uint64(protocol.MaxStreamCount+1)) return appendInitialSourceConnectionID(b) }(), perspective: protocol.PerspectiveServer, expectedErrMsg: "initial_max_streams_bidi too large: 1152921504606846977 (maximum 1152921504606846976)", }, { name: "initial max streams uni is too large", data: func() []byte { b := quicvarint.Append(nil, uint64(initialMaxStreamsUniParameterID)) b = quicvarint.Append(b, uint64(quicvarint.Len(uint64(protocol.MaxStreamCount+1)))) b = quicvarint.Append(b, uint64(protocol.MaxStreamCount+1)) return appendInitialSourceConnectionID(b) }(), perspective: protocol.PerspectiveServer, expectedErrMsg: "initial_max_streams_uni too large: 1152921504606846977 (maximum 
1152921504606846976)", }, { name: "not enough data to read", data: func() []byte { b := quicvarint.Append(nil, 0x42) b = quicvarint.Append(b, 7) return append(b, []byte("foobar")...) }(), perspective: protocol.PerspectiveServer, expectedErrMsg: "remaining length (6) smaller than parameter length (7)", }, { name: "client sent stateless reset token", data: func() []byte { b := quicvarint.Append(nil, uint64(statelessResetTokenParameterID)) b = quicvarint.Append(b, uint64(quicvarint.Len(16))) return append(b, make([]byte, 16)...) }(), perspective: protocol.PerspectiveClient, expectedErrMsg: "client sent a stateless_reset_token", }, { name: "client sent original destination connection ID", data: func() []byte { b := quicvarint.Append(nil, uint64(originalDestinationConnectionIDParameterID)) b = quicvarint.Append(b, 6) return append(b, []byte("foobar")...) }(), perspective: protocol.PerspectiveClient, expectedErrMsg: "client sent an original_destination_connection_id", }, { name: "huge max ack delay value", data: func() []byte { val := uint64(math.MaxUint64) / 5 b := quicvarint.Append(nil, uint64(maxAckDelayParameterID)) b = quicvarint.Append(b, uint64(quicvarint.Len(val))) b = quicvarint.Append(b, val) return appendInitialSourceConnectionID(b) }(), perspective: protocol.PerspectiveClient, expectedErrMsg: "invalid value for max_ack_delay: 3689348814741910323ms (maximum 16383ms)", }, { name: "invalid value for reset_stream_at", data: func() []byte { b := quicvarint.Append(nil, uint64(resetStreamAtParameterID)) b = quicvarint.Append(b, 1) b = quicvarint.Append(b, 1) return appendInitialSourceConnectionID(b) }(), perspective: protocol.PerspectiveClient, expectedErrMsg: "wrong length for reset_stream_at: 1 (expected empty)", }, { name: "min ack delay is greater than max ack delay", data: func() []byte { b := quicvarint.Append(nil, uint64(minAckDelayParameterID)) b = quicvarint.Append(b, uint64(quicvarint.Len(42001))) b = quicvarint.Append(b, 42001) // 42001 microseconds b = 
quicvarint.Append(b, uint64(maxAckDelayParameterID)) b = quicvarint.Append(b, uint64(quicvarint.Len(42))) b = quicvarint.Append(b, 42) // 42 microseconds return appendInitialSourceConnectionID(b) }(), perspective: protocol.PerspectiveClient, expectedErrMsg: "min_ack_delay (42.001ms) is greater than max_ack_delay (42ms)", }, { name: "huge min ack delay value", data: func() []byte { b := quicvarint.Append(nil, uint64(minAckDelayParameterID)) b = quicvarint.Append(b, uint64(quicvarint.Len(quicvarint.Max))) b = quicvarint.Append(b, quicvarint.Max) b = quicvarint.Append(b, uint64(maxAckDelayParameterID)) b = quicvarint.Append(b, uint64(quicvarint.Len(42))) b = quicvarint.Append(b, 42) // 42 microseconds return appendInitialSourceConnectionID(b) }(), perspective: protocol.PerspectiveClient, expectedErrMsg: "min_ack_delay (2562047h47m16.854775807s) is greater than max_ack_delay (42ms)", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var err error if tt.params != nil { data := tt.params.Marshal(tt.perspective) err = (&TransportParameters{}).Unmarshal(data, tt.perspective) } else { err = (&TransportParameters{}).Unmarshal(tt.data, tt.perspective) } require.Error(t, err) transportErr, ok := err.(*qerr.TransportError) require.True(t, ok) require.Equal(t, qerr.TransportParameterError, transportErr.ErrorCode) require.Equal(t, tt.expectedErrMsg, transportErr.ErrorMessage) }) } } func TestTransportParameterUnknownParameters(t *testing.T) { // write a known parameter b := quicvarint.Append(nil, uint64(initialMaxStreamDataBidiLocalParameterID)) b = quicvarint.Append(b, uint64(quicvarint.Len(0x1337))) b = quicvarint.Append(b, 0x1337) // write an unknown parameter b = quicvarint.Append(b, 0x42) b = quicvarint.Append(b, 6) b = append(b, []byte("foobar")...) 
// write a known parameter b = quicvarint.Append(b, uint64(initialMaxStreamDataBidiRemoteParameterID)) b = quicvarint.Append(b, uint64(quicvarint.Len(0x42))) b = quicvarint.Append(b, 0x42) b = appendInitialSourceConnectionID(b) p := &TransportParameters{} err := p.Unmarshal(b, protocol.PerspectiveClient) require.NoError(t, err) require.Equal(t, protocol.ByteCount(0x1337), p.InitialMaxStreamDataBidiLocal) require.Equal(t, protocol.ByteCount(0x42), p.InitialMaxStreamDataBidiRemote) } func TestTransportParameterRejectsDuplicateParameters(t *testing.T) { // write first parameter b := quicvarint.Append(nil, uint64(initialMaxStreamDataBidiLocalParameterID)) b = quicvarint.Append(b, uint64(quicvarint.Len(0x1337))) b = quicvarint.Append(b, 0x1337) // write a second parameter b = quicvarint.Append(b, uint64(initialMaxStreamDataBidiRemoteParameterID)) b = quicvarint.Append(b, uint64(quicvarint.Len(0x42))) b = quicvarint.Append(b, 0x42) // write first parameter again b = quicvarint.Append(b, uint64(initialMaxStreamDataBidiLocalParameterID)) b = quicvarint.Append(b, uint64(quicvarint.Len(0x1337))) b = quicvarint.Append(b, 0x1337) b = appendInitialSourceConnectionID(b) err := (&TransportParameters{}).Unmarshal(b, protocol.PerspectiveClient) require.Error(t, err) transportErr, ok := err.(*qerr.TransportError) require.True(t, ok) require.Equal(t, qerr.TransportParameterError, transportErr.ErrorCode) require.Equal(t, fmt.Sprintf("received duplicate transport parameter %#x", initialMaxStreamDataBidiLocalParameterID), transportErr.ErrorMessage) } func TestTransportParameterPreferredAddress(t *testing.T) { testCases := []struct { name string hasIPv4 bool hasIPv6 bool }{ {"IPv4 and IPv6", true, true}, {"IPv4 only", true, false}, {"IPv6 only", false, true}, {"neither IPv4 nor IPv6", false, false}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { testTransportParameterPreferredAddress(t, tc.hasIPv4, tc.hasIPv6) }) } } func testTransportParameterPreferredAddress(t 
*testing.T, hasIPv4, hasIPv6 bool) { addr4 := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 42) addr6 := netip.AddrPortFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}), 13) pa := &PreferredAddress{ ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, } if hasIPv4 { pa.IPv4 = addr4 } if hasIPv6 { pa.IPv6 = addr6 } data := (&TransportParameters{ PreferredAddress: pa, StatelessResetToken: &protocol.StatelessResetToken{}, ActiveConnectionIDLimit: 2, }).Marshal(protocol.PerspectiveServer) p := &TransportParameters{} require.NoError(t, p.Unmarshal(data, protocol.PerspectiveServer)) if hasIPv4 { require.True(t, p.PreferredAddress.IPv4.IsValid()) require.Equal(t, addr4, p.PreferredAddress.IPv4) } else { require.False(t, p.PreferredAddress.IPv4.IsValid()) } if hasIPv6 { require.True(t, p.PreferredAddress.IPv6.IsValid()) require.Equal(t, addr6, p.PreferredAddress.IPv6) } else { require.False(t, p.PreferredAddress.IPv6.IsValid()) } require.Equal(t, pa.ConnectionID, p.PreferredAddress.ConnectionID) require.Equal(t, pa.StatelessResetToken, p.PreferredAddress.StatelessResetToken) } func TestTransportParameterPreferredAddressFromClient(t *testing.T) { b := quicvarint.Append(nil, uint64(preferredAddressParameterID)) b = quicvarint.Append(b, 6) b = append(b, []byte("foobar")...) 
p := &TransportParameters{} err := p.Unmarshal(b, protocol.PerspectiveClient) require.Error(t, err) require.IsType(t, &qerr.TransportError{}, err) transportErr := err.(*qerr.TransportError) require.Equal(t, qerr.TransportParameterError, transportErr.ErrorCode) require.Equal(t, "client sent a preferred_address", transportErr.ErrorMessage) } func TestTransportParameterPreferredAddressZeroLengthConnectionID(t *testing.T) { pa := &PreferredAddress{ IPv4: netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 42), IPv6: netip.AddrPortFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}), 13), ConnectionID: protocol.ParseConnectionID([]byte{}), StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, } data := (&TransportParameters{ PreferredAddress: pa, StatelessResetToken: &protocol.StatelessResetToken{}, }).Marshal(protocol.PerspectiveServer) p := &TransportParameters{} err := p.Unmarshal(data, protocol.PerspectiveServer) require.Error(t, err) require.IsType(t, &qerr.TransportError{}, err) transportErr := err.(*qerr.TransportError) require.Equal(t, qerr.TransportParameterError, transportErr.ErrorCode) require.Equal(t, "invalid connection ID length: 0", transportErr.ErrorMessage) } func TestPreferredAddressErrorOnEOF(t *testing.T) { raw := []byte{ 127, 0, 0, 1, // IPv4 0, 42, // IPv4 Port 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, // IPv6 13, 37, // IPv6 Port, 4, // conn ID len 0xde, 0xad, 0xbe, 0xef, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, // stateless reset token } for i := 1; i < len(raw); i++ { b := quicvarint.Append(nil, uint64(preferredAddressParameterID)) b = append(b, raw[:i]...) 
p := &TransportParameters{} err := p.Unmarshal(b, protocol.PerspectiveServer) require.Error(t, err) } } func TestTransportParametersFromSessionTicket(t *testing.T) { params := &TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.ByteCount(getRandomValue()), InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()), InitialMaxStreamDataUni: protocol.ByteCount(getRandomValue()), InitialMaxData: protocol.ByteCount(getRandomValue()), MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(uint64(protocol.MaxStreamCount))), MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(uint64(protocol.MaxStreamCount))), ActiveConnectionIDLimit: 2 + getRandomValueUpTo(quicvarint.Max-2), MaxDatagramFrameSize: protocol.ByteCount(getRandomValueUpTo(uint64(MaxDatagramSize))), EnableResetStreamAt: getRandomValue()%2 == 0, } require.True(t, params.ValidFor0RTT(params)) b := params.MarshalForSessionTicket(nil) var tp TransportParameters require.NoError(t, tp.UnmarshalFromSessionTicket(b)) require.Equal(t, params.InitialMaxStreamDataBidiLocal, tp.InitialMaxStreamDataBidiLocal) require.Equal(t, params.InitialMaxStreamDataBidiRemote, tp.InitialMaxStreamDataBidiRemote) require.Equal(t, params.InitialMaxStreamDataUni, tp.InitialMaxStreamDataUni) require.Equal(t, params.InitialMaxData, tp.InitialMaxData) require.Equal(t, params.MaxBidiStreamNum, tp.MaxBidiStreamNum) require.Equal(t, params.MaxUniStreamNum, tp.MaxUniStreamNum) require.Equal(t, params.ActiveConnectionIDLimit, tp.ActiveConnectionIDLimit) require.Equal(t, params.MaxDatagramFrameSize, tp.MaxDatagramFrameSize) require.Equal(t, params.EnableResetStreamAt, tp.EnableResetStreamAt) } func TestSessionTicketInvalidTransportParameters(t *testing.T) { var p TransportParameters require.Error(t, p.UnmarshalFromSessionTicket([]byte("foobar"))) } func TestSessionTicketTransportParameterVersionMismatch(t *testing.T) { var p TransportParameters data := p.MarshalForSessionTicket(nil) b := quicvarint.Append(nil, 
transportParameterMarshalingVersion+1) b = append(b, data[quicvarint.Len(transportParameterMarshalingVersion):]...) err := p.UnmarshalFromSessionTicket(b) require.EqualError(t, err, fmt.Sprintf("unknown transport parameter marshaling version: %d", transportParameterMarshalingVersion+1)) } func TestTransportParametersValidFor0RTT(t *testing.T) { saved := &TransportParameters{ InitialMaxStreamDataBidiLocal: 1, InitialMaxStreamDataBidiRemote: 2, InitialMaxStreamDataUni: 3, InitialMaxData: 4, MaxBidiStreamNum: 5, MaxUniStreamNum: 6, ActiveConnectionIDLimit: 7, MaxDatagramFrameSize: 1000, } tests := []struct { name string modify func(*TransportParameters) valid bool }{ { name: "No Changes", modify: func(p *TransportParameters) {}, valid: true, }, { name: "InitialMaxStreamDataBidiLocal reduced", modify: func(p *TransportParameters) { p.InitialMaxStreamDataBidiLocal = saved.InitialMaxStreamDataBidiLocal - 1 }, valid: false, }, { name: "InitialMaxStreamDataBidiLocal increased", modify: func(p *TransportParameters) { p.InitialMaxStreamDataBidiLocal = saved.InitialMaxStreamDataBidiLocal + 1 }, valid: true, }, { name: "InitialMaxStreamDataBidiRemote reduced", modify: func(p *TransportParameters) { p.InitialMaxStreamDataBidiRemote = saved.InitialMaxStreamDataBidiRemote - 1 }, valid: false, }, { name: "InitialMaxStreamDataBidiRemote increased", modify: func(p *TransportParameters) { p.InitialMaxStreamDataBidiRemote = saved.InitialMaxStreamDataBidiRemote + 1 }, valid: true, }, { name: "InitialMaxStreamDataUni reduced", modify: func(p *TransportParameters) { p.InitialMaxStreamDataUni = saved.InitialMaxStreamDataUni - 1 }, valid: false, }, { name: "InitialMaxStreamDataUni increased", modify: func(p *TransportParameters) { p.InitialMaxStreamDataUni = saved.InitialMaxStreamDataUni + 1 }, valid: true, }, { name: "InitialMaxData reduced", modify: func(p *TransportParameters) { p.InitialMaxData = saved.InitialMaxData - 1 }, valid: false, }, { name: "InitialMaxData increased", modify: 
func(p *TransportParameters) { p.InitialMaxData = saved.InitialMaxData + 1 }, valid: true, }, { name: "MaxBidiStreamNum reduced", modify: func(p *TransportParameters) { p.MaxBidiStreamNum = saved.MaxBidiStreamNum - 1 }, valid: false, }, { name: "MaxBidiStreamNum increased", modify: func(p *TransportParameters) { p.MaxBidiStreamNum = saved.MaxBidiStreamNum + 1 }, valid: true, }, { name: "MaxUniStreamNum reduced", modify: func(p *TransportParameters) { p.MaxUniStreamNum = saved.MaxUniStreamNum - 1 }, valid: false, }, { name: "MaxUniStreamNum increased", modify: func(p *TransportParameters) { p.MaxUniStreamNum = saved.MaxUniStreamNum + 1 }, valid: true, }, { name: "ActiveConnectionIDLimit changed", modify: func(p *TransportParameters) { p.ActiveConnectionIDLimit = 0 }, valid: false, }, { name: "MaxDatagramFrameSize increased", modify: func(p *TransportParameters) { p.MaxDatagramFrameSize = saved.MaxDatagramFrameSize + 1 }, valid: true, }, { name: "MaxDatagramFrameSize reduced", modify: func(p *TransportParameters) { p.MaxDatagramFrameSize = saved.MaxDatagramFrameSize - 1 }, valid: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := *saved tt.modify(&p) require.Equal(t, tt.valid, p.ValidFor0RTT(saved)) }) } } func TestTransportParametersValidAfter0RTT(t *testing.T) { saved := &TransportParameters{ InitialMaxStreamDataBidiLocal: 1, InitialMaxStreamDataBidiRemote: 2, InitialMaxStreamDataUni: 3, InitialMaxData: 4, MaxBidiStreamNum: 5, MaxUniStreamNum: 6, ActiveConnectionIDLimit: 7, MaxDatagramFrameSize: 1000, } tests := []struct { name string modify func(*TransportParameters) reject bool }{ { name: "no changes", modify: func(p *TransportParameters) {}, reject: false, }, { name: "InitialMaxStreamDataBidiLocal reduced", modify: func(p *TransportParameters) { p.InitialMaxStreamDataBidiLocal = saved.InitialMaxStreamDataBidiLocal - 1 }, reject: true, }, { name: "InitialMaxStreamDataBidiLocal increased", modify: func(p *TransportParameters) { 
p.InitialMaxStreamDataBidiLocal = saved.InitialMaxStreamDataBidiLocal + 1 }, reject: false, }, { name: "InitialMaxStreamDataBidiRemote reduced", modify: func(p *TransportParameters) { p.InitialMaxStreamDataBidiRemote = saved.InitialMaxStreamDataBidiRemote - 1 }, reject: true, }, { name: "InitialMaxStreamDataBidiRemote increased", modify: func(p *TransportParameters) { p.InitialMaxStreamDataBidiRemote = saved.InitialMaxStreamDataBidiRemote + 1 }, reject: false, }, { name: "InitialMaxStreamDataUni reduced", modify: func(p *TransportParameters) { p.InitialMaxStreamDataUni = saved.InitialMaxStreamDataUni - 1 }, reject: true, }, { name: "InitialMaxStreamDataUni increased", modify: func(p *TransportParameters) { p.InitialMaxStreamDataUni = saved.InitialMaxStreamDataUni + 1 }, reject: false, }, { name: "InitialMaxData reduced", modify: func(p *TransportParameters) { p.InitialMaxData = saved.InitialMaxData - 1 }, reject: true, }, { name: "InitialMaxData increased", modify: func(p *TransportParameters) { p.InitialMaxData = saved.InitialMaxData + 1 }, reject: false, }, { name: "MaxBidiStreamNum reduced", modify: func(p *TransportParameters) { p.MaxBidiStreamNum = saved.MaxBidiStreamNum - 1 }, reject: true, }, { name: "MaxBidiStreamNum increased", modify: func(p *TransportParameters) { p.MaxBidiStreamNum = saved.MaxBidiStreamNum + 1 }, reject: false, }, { name: "MaxUniStreamNum reduced", modify: func(p *TransportParameters) { p.MaxUniStreamNum = saved.MaxUniStreamNum - 1 }, reject: true, }, { name: "MaxUniStreamNum increased", modify: func(p *TransportParameters) { p.MaxUniStreamNum = saved.MaxUniStreamNum + 1 }, reject: false, }, { name: "ActiveConnectionIDLimit reduced", modify: func(p *TransportParameters) { p.ActiveConnectionIDLimit = saved.ActiveConnectionIDLimit - 1 }, reject: true, }, { name: "ActiveConnectionIDLimit increased", modify: func(p *TransportParameters) { p.ActiveConnectionIDLimit = saved.ActiveConnectionIDLimit + 1 }, reject: false, }, { name: 
"MaxDatagramFrameSize reduced", modify: func(p *TransportParameters) { p.MaxDatagramFrameSize = saved.MaxDatagramFrameSize - 1 }, reject: true, }, { name: "MaxDatagramFrameSize increased", modify: func(p *TransportParameters) { p.MaxDatagramFrameSize = saved.MaxDatagramFrameSize + 1 }, reject: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := *saved tt.modify(&p) if tt.reject { require.False(t, p.ValidForUpdate(saved)) } else { require.True(t, p.ValidForUpdate(saved)) } }) } } func BenchmarkTransportParameters(b *testing.B) { b.Run("without preferred address", func(b *testing.B) { benchmarkTransportParameters(b, false) }) b.Run("with preferred address", func(b *testing.B) { benchmarkTransportParameters(b, true) }) } func benchmarkTransportParameters(b *testing.B, withPreferredAddress bool) { b.ReportAllocs() var token protocol.StatelessResetToken rand.Read(token[:]) rcid := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}) params := &TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.ByteCount(getRandomValue()), InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()), InitialMaxStreamDataUni: protocol.ByteCount(getRandomValue()), InitialMaxData: protocol.ByteCount(getRandomValue()), MaxIdleTimeout: 0xcafe * time.Second, MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(uint64(protocol.MaxStreamCount))), MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(uint64(protocol.MaxStreamCount))), DisableActiveMigration: true, StatelessResetToken: &token, OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), InitialSourceConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), RetrySourceConnectionID: &rcid, AckDelayExponent: 13, MaxAckDelay: 42 * time.Millisecond, ActiveConnectionIDLimit: 2 + getRandomValueUpTo(quicvarint.Max-2), MaxDatagramFrameSize: protocol.ByteCount(getRandomValue()), } var token2 protocol.StatelessResetToken 
rand.Read(token2[:]) if withPreferredAddress { var ip4 [4]byte var ip6 [16]byte rand.Read(ip4[:]) rand.Read(ip6[:]) params.PreferredAddress = &PreferredAddress{ IPv4: netip.AddrPortFrom(netip.AddrFrom4(ip4), 1234), IPv6: netip.AddrPortFrom(netip.AddrFrom16(ip6), 4321), ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), StatelessResetToken: token2, } } data := params.Marshal(protocol.PerspectiveServer) var p TransportParameters for b.Loop() { if err := p.Unmarshal(data, protocol.PerspectiveServer); err != nil { b.Fatal(err) } // check a few fields if p.DisableActiveMigration != params.DisableActiveMigration || p.InitialMaxStreamDataBidiLocal != params.InitialMaxStreamDataBidiLocal || *p.StatelessResetToken != *params.StatelessResetToken || p.AckDelayExponent != params.AckDelayExponent { b.Fatalf("params mismatch: %v vs %v", p, params) } if withPreferredAddress && *p.PreferredAddress != *params.PreferredAddress { b.Fatalf("preferred address mismatch: %v vs %v", p.PreferredAddress, params.PreferredAddress) } } } quic-go-0.59.0/internal/wire/transport_parameters.go000066400000000000000000000562011513066070600225510ustar00rootroot00000000000000package wire import ( "crypto/rand" "encoding/binary" "errors" "fmt" "io" "math" "net/netip" "slices" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/quicvarint" ) // AdditionalTransportParametersClient are additional transport parameters that will be added // to the client's transport parameters. // This is not intended for production use, but _only_ to increase the size of the ClientHello beyond // the usual size of less than 1 MTU. 
var AdditionalTransportParametersClient map[uint64][]byte const transportParameterMarshalingVersion = 1 type transportParameterID uint64 const ( originalDestinationConnectionIDParameterID transportParameterID = 0x0 maxIdleTimeoutParameterID transportParameterID = 0x1 statelessResetTokenParameterID transportParameterID = 0x2 maxUDPPayloadSizeParameterID transportParameterID = 0x3 initialMaxDataParameterID transportParameterID = 0x4 initialMaxStreamDataBidiLocalParameterID transportParameterID = 0x5 initialMaxStreamDataBidiRemoteParameterID transportParameterID = 0x6 initialMaxStreamDataUniParameterID transportParameterID = 0x7 initialMaxStreamsBidiParameterID transportParameterID = 0x8 initialMaxStreamsUniParameterID transportParameterID = 0x9 ackDelayExponentParameterID transportParameterID = 0xa maxAckDelayParameterID transportParameterID = 0xb disableActiveMigrationParameterID transportParameterID = 0xc preferredAddressParameterID transportParameterID = 0xd activeConnectionIDLimitParameterID transportParameterID = 0xe initialSourceConnectionIDParameterID transportParameterID = 0xf retrySourceConnectionIDParameterID transportParameterID = 0x10 // RFC 9221 maxDatagramFrameSizeParameterID transportParameterID = 0x20 // https://datatracker.ietf.org/doc/draft-ietf-quic-reliable-stream-reset/06/ resetStreamAtParameterID transportParameterID = 0x17f7586d2cb571 // https://datatracker.ietf.org/doc/draft-ietf-quic-ack-frequency/11/ minAckDelayParameterID transportParameterID = 0xff04de1b ) // PreferredAddress is the value encoding in the preferred_address transport parameter type PreferredAddress struct { IPv4, IPv6 netip.AddrPort ConnectionID protocol.ConnectionID StatelessResetToken protocol.StatelessResetToken } // TransportParameters are parameters sent to the peer during the handshake type TransportParameters struct { InitialMaxStreamDataBidiLocal protocol.ByteCount InitialMaxStreamDataBidiRemote protocol.ByteCount InitialMaxStreamDataUni protocol.ByteCount 
InitialMaxData protocol.ByteCount MaxAckDelay time.Duration AckDelayExponent uint8 DisableActiveMigration bool MaxUDPPayloadSize protocol.ByteCount MaxUniStreamNum protocol.StreamNum MaxBidiStreamNum protocol.StreamNum MaxIdleTimeout time.Duration PreferredAddress *PreferredAddress OriginalDestinationConnectionID protocol.ConnectionID InitialSourceConnectionID protocol.ConnectionID RetrySourceConnectionID *protocol.ConnectionID // use a pointer here to distinguish zero-length connection IDs from missing transport parameters StatelessResetToken *protocol.StatelessResetToken ActiveConnectionIDLimit uint64 MaxDatagramFrameSize protocol.ByteCount // RFC 9221 EnableResetStreamAt bool // https://datatracker.ietf.org/doc/draft-ietf-quic-reliable-stream-reset/06/ MinAckDelay *time.Duration } // Unmarshal the transport parameters func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective) error { if err := p.unmarshal(data, sentBy, false); err != nil { return &qerr.TransportError{ ErrorCode: qerr.TransportParameterError, ErrorMessage: err.Error(), } } return nil } func (p *TransportParameters) unmarshal(b []byte, sentBy protocol.Perspective, fromSessionTicket bool) error { // needed to check that every parameter is only sent at most once parameterIDs := make([]transportParameterID, 0, 32) var ( readOriginalDestinationConnectionID bool readInitialSourceConnectionID bool ) p.AckDelayExponent = protocol.DefaultAckDelayExponent p.MaxAckDelay = protocol.DefaultMaxAckDelay p.MaxDatagramFrameSize = protocol.InvalidByteCount p.ActiveConnectionIDLimit = protocol.DefaultActiveConnectionIDLimit for len(b) > 0 { paramIDInt, l, err := quicvarint.Parse(b) if err != nil { return err } paramID := transportParameterID(paramIDInt) b = b[l:] paramLen, l, err := quicvarint.Parse(b) if err != nil { return err } b = b[l:] if uint64(len(b)) < paramLen { return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", len(b), paramLen) } parameterIDs = 
append(parameterIDs, paramID) switch paramID { case maxIdleTimeoutParameterID, maxUDPPayloadSizeParameterID, initialMaxDataParameterID, initialMaxStreamDataBidiLocalParameterID, initialMaxStreamDataBidiRemoteParameterID, initialMaxStreamDataUniParameterID, initialMaxStreamsBidiParameterID, initialMaxStreamsUniParameterID, maxAckDelayParameterID, maxDatagramFrameSizeParameterID, ackDelayExponentParameterID, activeConnectionIDLimitParameterID, minAckDelayParameterID: if err := p.readNumericTransportParameter(b, paramID, int(paramLen)); err != nil { return err } b = b[paramLen:] case preferredAddressParameterID: if sentBy == protocol.PerspectiveClient { return errors.New("client sent a preferred_address") } if err := p.readPreferredAddress(b, int(paramLen)); err != nil { return err } b = b[paramLen:] case disableActiveMigrationParameterID: if paramLen != 0 { return fmt.Errorf("wrong length for disable_active_migration: %d (expected empty)", paramLen) } p.DisableActiveMigration = true case statelessResetTokenParameterID: if sentBy == protocol.PerspectiveClient { return errors.New("client sent a stateless_reset_token") } if paramLen != 16 { return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen) } var token protocol.StatelessResetToken if len(b) < len(token) { return io.EOF } copy(token[:], b) b = b[len(token):] p.StatelessResetToken = &token case originalDestinationConnectionIDParameterID: if sentBy == protocol.PerspectiveClient { return errors.New("client sent an original_destination_connection_id") } if paramLen > protocol.MaxConnIDLen { return protocol.ErrInvalidConnectionIDLen } p.OriginalDestinationConnectionID = protocol.ParseConnectionID(b[:paramLen]) b = b[paramLen:] readOriginalDestinationConnectionID = true case initialSourceConnectionIDParameterID: if paramLen > protocol.MaxConnIDLen { return protocol.ErrInvalidConnectionIDLen } p.InitialSourceConnectionID = protocol.ParseConnectionID(b[:paramLen]) b = b[paramLen:] 
readInitialSourceConnectionID = true case retrySourceConnectionIDParameterID: if sentBy == protocol.PerspectiveClient { return errors.New("client sent a retry_source_connection_id") } if paramLen > protocol.MaxConnIDLen { return protocol.ErrInvalidConnectionIDLen } connID := protocol.ParseConnectionID(b[:paramLen]) b = b[paramLen:] p.RetrySourceConnectionID = &connID case resetStreamAtParameterID: if paramLen != 0 { return fmt.Errorf("wrong length for reset_stream_at: %d (expected empty)", paramLen) } p.EnableResetStreamAt = true default: b = b[paramLen:] } } // min_ack_delay must be less or equal to max_ack_delay if p.MinAckDelay != nil && *p.MinAckDelay > p.MaxAckDelay { return fmt.Errorf("min_ack_delay (%s) is greater than max_ack_delay (%s)", *p.MinAckDelay, p.MaxAckDelay) } if !fromSessionTicket { if sentBy == protocol.PerspectiveServer && !readOriginalDestinationConnectionID { return errors.New("missing original_destination_connection_id") } if p.MaxUDPPayloadSize == 0 { p.MaxUDPPayloadSize = protocol.MaxByteCount } if !readInitialSourceConnectionID { return errors.New("missing initial_source_connection_id") } } // check that every transport parameter was sent at most once slices.SortFunc(parameterIDs, func(a, b transportParameterID) int { if a < b { return -1 } return 1 }) for i := 0; i < len(parameterIDs)-1; i++ { if parameterIDs[i] == parameterIDs[i+1] { return fmt.Errorf("received duplicate transport parameter %#x", parameterIDs[i]) } } return nil } func (p *TransportParameters) readPreferredAddress(b []byte, expectedLen int) error { remainingLen := len(b) pa := &PreferredAddress{} if len(b) < 4+2+16+2+1 { return io.EOF } var ipv4 [4]byte copy(ipv4[:], b[:4]) port4 := binary.BigEndian.Uint16(b[4:]) b = b[4+2:] if port4 != 0 && ipv4 != [4]byte{} { pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port4) } var ipv6 [16]byte copy(ipv6[:], b[:16]) port6 := binary.BigEndian.Uint16(b[16:]) if port6 != 0 && ipv6 != [16]byte{} { pa.IPv6 = 
netip.AddrPortFrom(netip.AddrFrom16(ipv6), port6) } b = b[16+2:] connIDLen := int(b[0]) b = b[1:] if connIDLen == 0 || connIDLen > protocol.MaxConnIDLen { return fmt.Errorf("invalid connection ID length: %d", connIDLen) } if len(b) < connIDLen+len(pa.StatelessResetToken) { return io.EOF } pa.ConnectionID = protocol.ParseConnectionID(b[:connIDLen]) b = b[connIDLen:] copy(pa.StatelessResetToken[:], b) b = b[len(pa.StatelessResetToken):] if bytesRead := remainingLen - len(b); bytesRead != expectedLen { return fmt.Errorf("expected preferred_address to be %d long, read %d bytes", expectedLen, bytesRead) } p.PreferredAddress = pa return nil } func (p *TransportParameters) readNumericTransportParameter(b []byte, paramID transportParameterID, expectedLen int) error { val, l, err := quicvarint.Parse(b) if err != nil { return fmt.Errorf("error while reading transport parameter %d: %s", paramID, err) } if l != expectedLen { return fmt.Errorf("inconsistent transport parameter length for transport parameter %#x", paramID) } //nolint:exhaustive // This only covers the numeric transport parameters. 
switch paramID { case initialMaxStreamDataBidiLocalParameterID: p.InitialMaxStreamDataBidiLocal = protocol.ByteCount(val) case initialMaxStreamDataBidiRemoteParameterID: p.InitialMaxStreamDataBidiRemote = protocol.ByteCount(val) case initialMaxStreamDataUniParameterID: p.InitialMaxStreamDataUni = protocol.ByteCount(val) case initialMaxDataParameterID: p.InitialMaxData = protocol.ByteCount(val) case initialMaxStreamsBidiParameterID: p.MaxBidiStreamNum = protocol.StreamNum(val) if p.MaxBidiStreamNum > protocol.MaxStreamCount { return fmt.Errorf("initial_max_streams_bidi too large: %d (maximum %d)", p.MaxBidiStreamNum, protocol.MaxStreamCount) } case initialMaxStreamsUniParameterID: p.MaxUniStreamNum = protocol.StreamNum(val) if p.MaxUniStreamNum > protocol.MaxStreamCount { return fmt.Errorf("initial_max_streams_uni too large: %d (maximum %d)", p.MaxUniStreamNum, protocol.MaxStreamCount) } case maxIdleTimeoutParameterID: p.MaxIdleTimeout = max(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond) case maxUDPPayloadSizeParameterID: if val < 1200 { return fmt.Errorf("invalid value for max_udp_payload_size: %d (minimum 1200)", val) } p.MaxUDPPayloadSize = protocol.ByteCount(val) case ackDelayExponentParameterID: if val > protocol.MaxAckDelayExponent { return fmt.Errorf("invalid value for ack_delay_exponent: %d (maximum %d)", val, protocol.MaxAckDelayExponent) } p.AckDelayExponent = uint8(val) case maxAckDelayParameterID: if val > uint64(protocol.MaxMaxAckDelay/time.Millisecond) { return fmt.Errorf("invalid value for max_ack_delay: %dms (maximum %dms)", val, protocol.MaxMaxAckDelay/time.Millisecond) } p.MaxAckDelay = time.Duration(val) * time.Millisecond case activeConnectionIDLimitParameterID: if val < 2 { return fmt.Errorf("invalid value for active_connection_id_limit: %d (minimum 2)", val) } p.ActiveConnectionIDLimit = val case maxDatagramFrameSizeParameterID: p.MaxDatagramFrameSize = protocol.ByteCount(val) case minAckDelayParameterID: mad := 
time.Duration(val) * time.Microsecond if mad < 0 { mad = math.MaxInt64 } p.MinAckDelay = &mad default: return fmt.Errorf("TransportParameter BUG: transport parameter %d not found", paramID) } return nil } // Marshal the transport parameters func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte { // Typical Transport Parameters consume around 110 bytes, depending on the exact values, // especially the lengths of the Connection IDs. // Allocate 256 bytes, so we won't have to grow the slice in any case. b := make([]byte, 0, 256) // add a greased value random := make([]byte, 18) rand.Read(random) b = quicvarint.Append(b, 27+31*uint64(random[0])) length := random[1] % 16 b = quicvarint.Append(b, uint64(length)) b = append(b, random[2:2+length]...) // initial_max_stream_data_bidi_local b = p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) // initial_max_stream_data_bidi_remote b = p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) // initial_max_stream_data_uni b = p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) // initial_max_data b = p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) // initial_max_bidi_streams b = p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) // initial_max_uni_streams b = p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) // idle_timeout b = p.marshalVarintParam(b, maxIdleTimeoutParameterID, uint64(p.MaxIdleTimeout/time.Millisecond)) // max_udp_payload_size if p.MaxUDPPayloadSize > 0 { b = p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(p.MaxUDPPayloadSize)) } // max_ack_delay // Only send it if is different from the default value. 
if p.MaxAckDelay != protocol.DefaultMaxAckDelay { b = p.marshalVarintParam(b, maxAckDelayParameterID, uint64(p.MaxAckDelay/time.Millisecond)) } // ack_delay_exponent // Only send it if is different from the default value. if p.AckDelayExponent != protocol.DefaultAckDelayExponent { b = p.marshalVarintParam(b, ackDelayExponentParameterID, uint64(p.AckDelayExponent)) } // disable_active_migration if p.DisableActiveMigration { b = quicvarint.Append(b, uint64(disableActiveMigrationParameterID)) b = quicvarint.Append(b, 0) } if pers == protocol.PerspectiveServer { // stateless_reset_token if p.StatelessResetToken != nil { b = quicvarint.Append(b, uint64(statelessResetTokenParameterID)) b = quicvarint.Append(b, 16) b = append(b, p.StatelessResetToken[:]...) } // original_destination_connection_id b = quicvarint.Append(b, uint64(originalDestinationConnectionIDParameterID)) b = quicvarint.Append(b, uint64(p.OriginalDestinationConnectionID.Len())) b = append(b, p.OriginalDestinationConnectionID.Bytes()...) // preferred_address if p.PreferredAddress != nil { b = quicvarint.Append(b, uint64(preferredAddressParameterID)) b = quicvarint.Append(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16) if p.PreferredAddress.IPv4.IsValid() { ipv4 := p.PreferredAddress.IPv4.Addr().As4() b = append(b, ipv4[:]...) b = binary.BigEndian.AppendUint16(b, p.PreferredAddress.IPv4.Port()) } else { b = append(b, make([]byte, 6)...) } if p.PreferredAddress.IPv6.IsValid() { ipv6 := p.PreferredAddress.IPv6.Addr().As16() b = append(b, ipv6[:]...) b = binary.BigEndian.AppendUint16(b, p.PreferredAddress.IPv6.Port()) } else { b = append(b, make([]byte, 18)...) } b = append(b, uint8(p.PreferredAddress.ConnectionID.Len())) b = append(b, p.PreferredAddress.ConnectionID.Bytes()...) b = append(b, p.PreferredAddress.StatelessResetToken[:]...) 
} } // active_connection_id_limit if p.ActiveConnectionIDLimit != protocol.DefaultActiveConnectionIDLimit { b = p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) } // initial_source_connection_id b = quicvarint.Append(b, uint64(initialSourceConnectionIDParameterID)) b = quicvarint.Append(b, uint64(p.InitialSourceConnectionID.Len())) b = append(b, p.InitialSourceConnectionID.Bytes()...) // retry_source_connection_id if pers == protocol.PerspectiveServer && p.RetrySourceConnectionID != nil { b = quicvarint.Append(b, uint64(retrySourceConnectionIDParameterID)) b = quicvarint.Append(b, uint64(p.RetrySourceConnectionID.Len())) b = append(b, p.RetrySourceConnectionID.Bytes()...) } // QUIC datagrams if p.MaxDatagramFrameSize != protocol.InvalidByteCount { b = p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize)) } // QUIC Stream Resets with Partial Delivery if p.EnableResetStreamAt { b = quicvarint.Append(b, uint64(resetStreamAtParameterID)) b = quicvarint.Append(b, 0) } if p.MinAckDelay != nil { b = p.marshalVarintParam(b, minAckDelayParameterID, uint64(*p.MinAckDelay/time.Microsecond)) } if pers == protocol.PerspectiveClient && len(AdditionalTransportParametersClient) > 0 { for k, v := range AdditionalTransportParametersClient { b = quicvarint.Append(b, k) b = quicvarint.Append(b, uint64(len(v))) b = append(b, v...) } } return b } func (p *TransportParameters) marshalVarintParam(b []byte, id transportParameterID, val uint64) []byte { b = quicvarint.Append(b, uint64(id)) b = quicvarint.Append(b, uint64(quicvarint.Len(val))) return quicvarint.Append(b, val) } // MarshalForSessionTicket marshals the transport parameters we save in the session ticket. // When sending a 0-RTT enabled TLS session tickets, we need to save the transport parameters. // The client will remember the transport parameters used in the last session, // and apply those to the 0-RTT data it sends. 
// Saving the transport parameters in the ticket gives the server the option to reject 0-RTT // if the transport parameters changed. // Since the session ticket is encrypted, the serialization format is defined by the server. // For convenience, we use the same format that we also use for sending the transport parameters. func (p *TransportParameters) MarshalForSessionTicket(b []byte) []byte { b = quicvarint.Append(b, transportParameterMarshalingVersion) // initial_max_stream_data_bidi_local b = p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) // initial_max_stream_data_bidi_remote b = p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) // initial_max_stream_data_uni b = p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) // initial_max_data b = p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) // initial_max_bidi_streams b = p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) // initial_max_uni_streams b = p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) // active_connection_id_limit b = p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) // max_datagram_frame_size if p.MaxDatagramFrameSize != protocol.InvalidByteCount { b = p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize)) } // reset_stream_at if p.EnableResetStreamAt { b = quicvarint.Append(b, uint64(resetStreamAtParameterID)) b = quicvarint.Append(b, 0) } return b } // UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket. 
func (p *TransportParameters) UnmarshalFromSessionTicket(b []byte) error { version, l, err := quicvarint.Parse(b) if err != nil { return err } if version != transportParameterMarshalingVersion { return fmt.Errorf("unknown transport parameter marshaling version: %d", version) } return p.unmarshal(b[l:], protocol.PerspectiveServer, true) } // ValidFor0RTT checks if the transport parameters match those saved in the session ticket. func (p *TransportParameters) ValidFor0RTT(saved *TransportParameters) bool { if saved.MaxDatagramFrameSize != protocol.InvalidByteCount && (p.MaxDatagramFrameSize == protocol.InvalidByteCount || p.MaxDatagramFrameSize < saved.MaxDatagramFrameSize) { return false } return p.InitialMaxStreamDataBidiLocal >= saved.InitialMaxStreamDataBidiLocal && p.InitialMaxStreamDataBidiRemote >= saved.InitialMaxStreamDataBidiRemote && p.InitialMaxStreamDataUni >= saved.InitialMaxStreamDataUni && p.InitialMaxData >= saved.InitialMaxData && p.MaxBidiStreamNum >= saved.MaxBidiStreamNum && p.MaxUniStreamNum >= saved.MaxUniStreamNum && p.ActiveConnectionIDLimit == saved.ActiveConnectionIDLimit } // ValidForUpdate checks that the new transport parameters don't reduce limits after resuming a 0-RTT connection. // It is only used on the client side. 
func (p *TransportParameters) ValidForUpdate(saved *TransportParameters) bool { if saved.MaxDatagramFrameSize != protocol.InvalidByteCount && (p.MaxDatagramFrameSize == protocol.InvalidByteCount || p.MaxDatagramFrameSize < saved.MaxDatagramFrameSize) { return false } return p.ActiveConnectionIDLimit >= saved.ActiveConnectionIDLimit && p.InitialMaxData >= saved.InitialMaxData && p.InitialMaxStreamDataBidiLocal >= saved.InitialMaxStreamDataBidiLocal && p.InitialMaxStreamDataBidiRemote >= saved.InitialMaxStreamDataBidiRemote && p.InitialMaxStreamDataUni >= saved.InitialMaxStreamDataUni && p.MaxBidiStreamNum >= saved.MaxBidiStreamNum && p.MaxUniStreamNum >= saved.MaxUniStreamNum } // String returns a string representation, intended for logging. func (p *TransportParameters) String() string { logString := "&wire.TransportParameters{OriginalDestinationConnectionID: %s, InitialSourceConnectionID: %s, " logParams := []any{p.OriginalDestinationConnectionID, p.InitialSourceConnectionID} if p.RetrySourceConnectionID != nil { logString += "RetrySourceConnectionID: %s, " logParams = append(logParams, p.RetrySourceConnectionID) } logString += "InitialMaxStreamDataBidiLocal: %d, InitialMaxStreamDataBidiRemote: %d, InitialMaxStreamDataUni: %d, InitialMaxData: %d, MaxBidiStreamNum: %d, MaxUniStreamNum: %d, MaxIdleTimeout: %s, AckDelayExponent: %d, MaxAckDelay: %s, ActiveConnectionIDLimit: %d" logParams = append(logParams, []any{p.InitialMaxStreamDataBidiLocal, p.InitialMaxStreamDataBidiRemote, p.InitialMaxStreamDataUni, p.InitialMaxData, p.MaxBidiStreamNum, p.MaxUniStreamNum, p.MaxIdleTimeout, p.AckDelayExponent, p.MaxAckDelay, p.ActiveConnectionIDLimit}...) 
if p.StatelessResetToken != nil { // the client never sends a stateless reset token logString += ", StatelessResetToken: %#x" logParams = append(logParams, *p.StatelessResetToken) } if p.MaxDatagramFrameSize != protocol.InvalidByteCount { logString += ", MaxDatagramFrameSize: %d" logParams = append(logParams, p.MaxDatagramFrameSize) } logString += ", EnableResetStreamAt: %t" logParams = append(logParams, p.EnableResetStreamAt) if p.MinAckDelay != nil { logString += ", MinAckDelay: %s" logParams = append(logParams, *p.MinAckDelay) } logString += "}" return fmt.Sprintf(logString, logParams...) } quic-go-0.59.0/internal/wire/version_negotiation.go000066400000000000000000000043051513066070600223550ustar00rootroot00000000000000package wire import ( "crypto/rand" "encoding/binary" "errors" "github.com/quic-go/quic-go/internal/protocol" ) // ParseVersionNegotiationPacket parses a Version Negotiation packet. func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenConnectionID, _ []protocol.Version, _ error) { n, dest, src, err := ParseArbitraryLenConnectionIDs(b) if err != nil { return nil, nil, nil, err } b = b[n:] if len(b) == 0 { //nolint:staticcheck // SA1021: the packet is called Version Negotiation packet return nil, nil, nil, errors.New("Version Negotiation packet has empty version list") } if len(b)%4 != 0 { //nolint:staticcheck // SA1021: the packet is called Version Negotiation packet return nil, nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length") } versions := make([]protocol.Version, len(b)/4) for i := 0; len(b) > 0; i++ { versions[i] = protocol.Version(binary.BigEndian.Uint32(b[:4])) b = b[4:] } return dest, src, versions, nil } // ComposeVersionNegotiation composes a Version Negotiation func ComposeVersionNegotiation(destConnID, srcConnID protocol.ArbitraryLenConnectionID, versions []protocol.Version) []byte { greasedVersions := protocol.GetGreasedVersions(versions) expectedLen := 1 /* type byte */ 
+ 4 /* version field */ + 1 /* dest connection ID length field */ + destConnID.Len() + 1 /* src connection ID length field */ + srcConnID.Len() + len(greasedVersions)*4 buf := make([]byte, 1+4 /* type byte and version field */, expectedLen) _, _ = rand.Read(buf[:1]) // ignore the error here. It is not critical to have perfect random here. // Setting the "QUIC bit" (0x40) is not required by the RFC, // but it allows clients to demultiplex QUIC with a long list of other protocols. // See RFC 9443 and https://mailarchive.ietf.org/arch/msg/quic/oR4kxGKY6mjtPC1CZegY1ED4beg/ for details. buf[0] |= 0xc0 // The next 4 bytes are left at 0 (version number). buf = append(buf, uint8(destConnID.Len())) buf = append(buf, destConnID.Bytes()...) buf = append(buf, uint8(srcConnID.Len())) buf = append(buf, srcConnID.Bytes()...) for _, v := range greasedVersions { buf = binary.BigEndian.AppendUint32(buf, uint32(v)) } return buf } quic-go-0.59.0/internal/wire/version_negotiation_test.go000066400000000000000000000072621513066070600234210ustar00rootroot00000000000000package wire import ( "crypto/rand" "encoding/binary" mrand "math/rand/v2" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestParseVersionNegotiationPacket(t *testing.T) { randConnID := func(l int) protocol.ArbitraryLenConnectionID { b := make(protocol.ArbitraryLenConnectionID, l) _, err := rand.Read(b) require.NoError(t, err) return b } srcConnID := randConnID(mrand.IntN(255) + 1) destConnID := randConnID(mrand.IntN(255) + 1) versions := []protocol.Version{0x22334455, 0x33445566} data := []byte{0x80, 0, 0, 0, 0} data = append(data, uint8(len(destConnID))) data = append(data, destConnID...) data = append(data, uint8(len(srcConnID))) data = append(data, srcConnID...) for _, v := range versions { data = append(data, []byte{0, 0, 0, 0}...) 
binary.BigEndian.PutUint32(data[len(data)-4:], uint32(v)) } require.True(t, IsVersionNegotiationPacket(data)) dest, src, supportedVersions, err := ParseVersionNegotiationPacket(data) require.NoError(t, err) require.Equal(t, destConnID, dest) require.Equal(t, srcConnID, src) require.Equal(t, versions, supportedVersions) } func TestParseVersionNegotiationPacketWithInvalidLength(t *testing.T) { connID := protocol.ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.Version{0x22334455, 0x33445566} data := ComposeVersionNegotiation(connID, connID, versions) _, _, _, err := ParseVersionNegotiationPacket(data[:len(data)-2]) require.EqualError(t, err, "Version Negotiation packet has a version list with an invalid length") } func TestParseVersionNegotiationPacketEmptyVersions(t *testing.T) { connID := protocol.ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.Version{0x22334455} data := ComposeVersionNegotiation(connID, connID, versions) // remove 8 bytes (two versions), since ComposeVersionNegotiation also added a reserved version number data = data[:len(data)-8] _, _, _, err := ParseVersionNegotiationPacket(data) require.EqualError(t, err, "Version Negotiation packet has empty version list") } func TestComposeVersionNegotiationWithReservedVersion(t *testing.T) { srcConnID := protocol.ArbitraryLenConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} destConnID := protocol.ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.Version{1001, 1003} data := ComposeVersionNegotiation(destConnID, srcConnID, versions) require.True(t, IsLongHeaderPacket(data[0])) require.NotZero(t, data[0]&0x40) v, err := ParseVersion(data) require.NoError(t, err) require.Zero(t, v) dest, src, supportedVersions, err := ParseVersionNegotiationPacket(data) require.NoError(t, err) require.Equal(t, destConnID, dest) require.Equal(t, srcConnID, src) // the supported versions should include one reserved version number require.Len(t, 
supportedVersions, len(versions)+1) for _, v := range versions { require.Contains(t, supportedVersions, v) } var reservedVersion protocol.Version versionLoop: for _, ver := range supportedVersions { for _, v := range versions { if v == ver { continue versionLoop } } reservedVersion = ver } require.NotZero(t, reservedVersion) require.True(t, reservedVersion&0x0f0f0f0f == 0x0a0a0a0a) // check that it's a greased version number } func BenchmarkComposeVersionNegotiationPacket(b *testing.B) { b.ReportAllocs() supportedVersions := []protocol.Version{protocol.Version2, protocol.Version1, 0x1337} destConnID := protocol.ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0xa, 0xb, 0xc, 0xd} srcConnID := protocol.ArbitraryLenConnectionID{10, 9, 8, 7, 6, 5, 4, 3, 2, 1} for b.Loop() { ComposeVersionNegotiation(destConnID, srcConnID, supportedVersions) } } quic-go-0.59.0/interop/000077500000000000000000000000001513066070600146355ustar00rootroot00000000000000quic-go-0.59.0/interop/Dockerfile000066400000000000000000000022341513066070600166300ustar00rootroot00000000000000FROM martenseemann/quic-network-simulator-endpoint:latest AS builder ARG TARGETPLATFORM RUN echo "TARGETPLATFORM: ${TARGETPLATFORM}" RUN apt-get update && apt-get install -y wget tar git && rm -rf /var/lib/apt/lists/* ENV GOVERSION=1.25.0 RUN platform=$(echo ${TARGETPLATFORM} | tr '/' '-') && \ filename="go${GOVERSION}.${platform}.tar.gz" && \ wget https://dl.google.com/go/${filename} && \ tar xfz ${filename} && \ rm ${filename} ENV PATH="/go/bin:${PATH}" # build with --build-arg CACHEBUST=$(date +%s) ARG CACHEBUST=1 COPY . 
/quic-go WORKDIR /quic-go RUN git rev-parse HEAD | tee commit.txt RUN go build -o server -ldflags="-X github.com/quic-go/quic-go/qlog.quicGoVersion=$(git describe --always --long --dirty)" interop/server/main.go RUN go build -o client -ldflags="-X github.com/quic-go/quic-go/qlog.quicGoVersion=$(git describe --always --long --dirty)" interop/client/main.go FROM martenseemann/quic-network-simulator-endpoint:latest WORKDIR /quic-go COPY --from=builder /quic-go/commit.txt /quic-go/server /quic-go/client ./ COPY --from=builder /quic-go/interop/run_endpoint.sh ./ RUN chmod +x run_endpoint.sh ENTRYPOINT [ "./run_endpoint.sh" ] quic-go-0.59.0/interop/client/000077500000000000000000000000001513066070600161135ustar00rootroot00000000000000quic-go-0.59.0/interop/client/main.go000066400000000000000000000113371513066070600173730ustar00rootroot00000000000000package main import ( "crypto/tls" "errors" "flag" "fmt" "io" "log" "net/http" "os" "strings" "time" "golang.org/x/sync/errgroup" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qtls" "github.com/quic-go/quic-go/interop/http09" "github.com/quic-go/quic-go/interop/utils" ) var errUnsupported = errors.New("unsupported test case") var tlsConf *tls.Config func main() { logFile, err := os.Create("/logs/log.txt") if err != nil { fmt.Printf("Could not create log file: %s\n", err.Error()) os.Exit(1) } defer logFile.Close() log.SetOutput(logFile) keyLog, err := utils.GetSSLKeyLog() if err != nil { fmt.Printf("Could not create key log: %s\n", err.Error()) os.Exit(1) } if keyLog != nil { defer keyLog.Close() } tlsConf = &tls.Config{ InsecureSkipVerify: true, KeyLogWriter: keyLog, } testcase := os.Getenv("TESTCASE") if err := runTestcase(testcase); err != nil { if err == errUnsupported { fmt.Printf("unsupported test case: %s\n", testcase) os.Exit(127) } fmt.Printf("Downloading files failed: 
%s\n", err.Error()) os.Exit(1) } } func runTestcase(testcase string) error { flag.Parse() urls := flag.Args() quicConf := &quic.Config{Tracer: utils.NewQLOGConnectionTracer} if testcase == "http3" { r := &http3.Transport{ TLSClientConfig: tlsConf, QUICConfig: quicConf, } defer r.Close() return downloadFiles(r, urls, false) } r := &http09.RoundTripper{ TLSClientConfig: tlsConf, QuicConfig: quicConf, } defer r.Close() switch testcase { case "handshake", "transfer", "retry": case "keyupdate": handshake.FirstKeyUpdateInterval = 100 case "chacha20": reset := qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256) defer reset() case "multiconnect": return runMultiConnectTest(r, urls) case "versionnegotiation": return runVersionNegotiationTest(r, urls) case "resumption": return runResumptionTest(r, urls, false) case "zerortt": return runResumptionTest(r, urls, true) default: return errUnsupported } return downloadFiles(r, urls, false) } func runVersionNegotiationTest(r *http09.RoundTripper, urls []string) error { if len(urls) != 1 { return errors.New("expected at least 2 URLs") } protocol.SupportedVersions = []protocol.Version{0x1a2a3a4a} err := downloadFile(r, urls[0], false) if err == nil { return errors.New("expected version negotiation to fail") } if !strings.Contains(err.Error(), "No compatible QUIC version found") { return fmt.Errorf("expect version negotiation error, got: %s", err.Error()) } return nil } func runMultiConnectTest(r *http09.RoundTripper, urls []string) error { for _, url := range urls { if err := downloadFile(r, url, false); err != nil { return err } if err := r.Close(); err != nil { return err } } return nil } type sessionCache struct { tls.ClientSessionCache put chan<- struct{} } func newSessionCache(c tls.ClientSessionCache) (tls.ClientSessionCache, <-chan struct{}) { put := make(chan struct{}, 100) return &sessionCache{ClientSessionCache: c, put: put}, put } func (c *sessionCache) Put(key string, cs *tls.ClientSessionState) { 
c.ClientSessionCache.Put(key, cs) c.put <- struct{}{} } func runResumptionTest(r *http09.RoundTripper, urls []string, use0RTT bool) error { if len(urls) < 2 { return errors.New("expected at least 2 URLs") } var put <-chan struct{} tlsConf.ClientSessionCache, put = newSessionCache(tls.NewLRUClientSessionCache(1)) // do the first transfer if err := downloadFiles(r, urls[:1], false); err != nil { return err } // wait for the session ticket to arrive select { case <-time.NewTimer(10 * time.Second).C: return errors.New("expected to receive a session ticket within 10 seconds") case <-put: } if err := r.Close(); err != nil { return err } // reestablish the connection, using the session ticket that the server (hopefully provided) defer r.Close() return downloadFiles(r, urls[1:], use0RTT) } func downloadFiles(cl http.RoundTripper, urls []string, use0RTT bool) error { var g errgroup.Group for _, u := range urls { url := u g.Go(func() error { return downloadFile(cl, url, use0RTT) }) } return g.Wait() } func downloadFile(cl http.RoundTripper, url string, use0RTT bool) error { method := http.MethodGet if use0RTT { method = http09.MethodGet0RTT } req, err := http.NewRequest(method, url, nil) if err != nil { return err } rsp, err := cl.RoundTrip(req) if err != nil { return err } defer rsp.Body.Close() file, err := os.Create("/downloads" + req.URL.Path) if err != nil { return err } defer file.Close() _, err = io.Copy(file, rsp.Body) return err } quic-go-0.59.0/interop/http09/000077500000000000000000000000001513066070600157655ustar00rootroot00000000000000quic-go-0.59.0/interop/http09/client.go000066400000000000000000000065631513066070600176040ustar00rootroot00000000000000package http09 import ( "context" "crypto/tls" "errors" "io" "log" "net" "net/http" "strings" "sync" "golang.org/x/net/idna" "github.com/quic-go/quic-go" ) // MethodGet0RTT allows a GET request to be sent using 0-RTT. // Note that 0-RTT data doesn't provide replay protection. 
const MethodGet0RTT = "GET_0RTT" // RoundTripper performs HTTP/0.9 roundtrips over QUIC. type RoundTripper struct { mutex sync.Mutex TLSClientConfig *tls.Config QuicConfig *quic.Config clients map[string]*client } var _ http.RoundTripper = &RoundTripper{} // RoundTrip performs a HTTP/0.9 request. // It only supports GET requests. func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { if req.Method != http.MethodGet && req.Method != MethodGet0RTT { return nil, errors.New("only GET requests supported") } log.Printf("Requesting %s.\n", req.URL) r.mutex.Lock() hostname := authorityAddr("https", hostnameFromRequest(req)) if r.clients == nil { r.clients = make(map[string]*client) } c, ok := r.clients[hostname] if !ok { tlsConf := &tls.Config{} if r.TLSClientConfig != nil { tlsConf = r.TLSClientConfig.Clone() } tlsConf.NextProtos = []string{NextProto} c = &client{ hostname: hostname, tlsConf: tlsConf, quicConf: r.QuicConfig, } r.clients[hostname] = c } r.mutex.Unlock() return c.RoundTrip(req) } // Close closes the roundtripper. 
func (r *RoundTripper) Close() error { r.mutex.Lock() defer r.mutex.Unlock() for id, c := range r.clients { if err := c.Close(); err != nil { return err } delete(r.clients, id) } return nil } type client struct { hostname string tlsConf *tls.Config quicConf *quic.Config once sync.Once conn *quic.Conn dialErr error } func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { c.once.Do(func() { c.conn, c.dialErr = quic.DialAddrEarly(context.Background(), c.hostname, c.tlsConf, c.quicConf) }) if c.dialErr != nil { return nil, c.dialErr } if req.Method != MethodGet0RTT { <-c.conn.HandshakeComplete() } return c.doRequest(req) } func (c *client) doRequest(req *http.Request) (*http.Response, error) { str, err := c.conn.OpenStreamSync(context.Background()) if err != nil { return nil, err } cmd := "GET " + req.URL.Path + "\r\n" if _, err := str.Write([]byte(cmd)); err != nil { return nil, err } if err := str.Close(); err != nil { return nil, err } rsp := &http.Response{ Proto: "HTTP/0.9", ProtoMajor: 0, ProtoMinor: 9, Request: req, Body: io.NopCloser(str), } return rsp, nil } func (c *client) Close() error { if c.conn == nil { return nil } return c.conn.CloseWithError(0, "") } func hostnameFromRequest(req *http.Request) string { if req.URL != nil { return req.URL.Host } return "" } // authorityAddr returns a given authority (a host/IP, or host:port / ip:port) // and returns a host:port. The port 443 is added if needed. 
func authorityAddr(scheme string, authority string) (addr string) { host, port, err := net.SplitHostPort(authority) if err != nil { // authority didn't have a port port = "443" if scheme == "http" { port = "80" } host = authority } if a, err := idna.ToASCII(host); err == nil { host = a } // IPv6 address literal, without a port: if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { return host + ":" + port } return net.JoinHostPort(host, port) } quic-go-0.59.0/interop/http09/http_test.go000066400000000000000000000036701513066070600203400ustar00rootroot00000000000000package http09 import ( "crypto/tls" "fmt" "io" "net" "net/http" "net/http/httptest" "testing" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/testdata" "github.com/stretchr/testify/require" ) func startServer(t *testing.T) net.Addr { t.Helper() server := &Server{} conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) require.NoError(t, err) tr := &quic.Transport{Conn: conn} tlsConf := testdata.GetTLSConfig() tlsConf.NextProtos = []string{NextProto} ln, err := tr.ListenEarly(tlsConf, &quic.Config{}) require.NoError(t, err) done := make(chan struct{}) go func() { defer close(done) _ = server.ServeListener(ln) }() t.Cleanup(func() { require.NoError(t, ln.Close()) <-done }) return ln.Addr() } func TestHTTPRequest(t *testing.T) { http.HandleFunc("/helloworld", func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("Hello World!")) }) addr := startServer(t) rt := &RoundTripper{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} t.Cleanup(func() { rt.Close() }) req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("https://%s/helloworld", addr), nil) rsp, err := rt.RoundTrip(req) require.NoError(t, err) data, err := io.ReadAll(rsp.Body) require.NoError(t, err) require.Equal(t, []byte("Hello World!"), data) } func TestHTTPHeaders(t *testing.T) { http.HandleFunc("/headers", func(w http.ResponseWriter, r *http.Request) { 
w.Header().Add("foo", "bar") w.WriteHeader(1337) _, _ = w.Write([]byte("done")) }) addr := startServer(t) rt := &RoundTripper{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} t.Cleanup(func() { rt.Close() }) req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("https://%s/headers", addr), nil) rsp, err := rt.RoundTrip(req) require.NoError(t, err) data, err := io.ReadAll(rsp.Body) require.NoError(t, err) require.Equal(t, []byte("done"), data) // HTTP/0.9 doesn't support HTTP headers } quic-go-0.59.0/interop/http09/server.go000066400000000000000000000044061513066070600176260ustar00rootroot00000000000000package http09 import ( "context" "io" "log" "net/http" "net/url" "runtime" "strings" "github.com/quic-go/quic-go" ) const NextProto = "hq-interop" type responseWriter struct { io.Writer headers http.Header } var _ http.ResponseWriter = &responseWriter{} func (w *responseWriter) Header() http.Header { if w.headers == nil { w.headers = make(http.Header) } return w.headers } func (w *responseWriter) WriteHeader(int) {} // Server is a HTTP/0.9 server listening for QUIC connections. type Server struct { Handler *http.ServeMux } // ServeListener serves HTTP/0.9 on all connections accepted from a QUIC listener. 
func (s *Server) ServeListener(ln *quic.EarlyListener) error { for { conn, err := ln.Accept(context.Background()) if err != nil { return err } go s.handleConn(conn) } } func (s *Server) handleConn(conn *quic.Conn) { for { str, err := conn.AcceptStream(context.Background()) if err != nil { log.Printf("Error accepting stream: %s\n", err.Error()) return } go func() { if err := s.handleStream(str); err != nil { log.Printf("Handling stream failed: %s\n", err.Error()) } }() } } func (s *Server) handleStream(str *quic.Stream) error { reqBytes, err := io.ReadAll(str) if err != nil { return err } request := string(reqBytes) request = strings.TrimRight(request, "\r\n") request = strings.TrimRight(request, " ") log.Printf("Received request: %s\n", request) if request[:5] != "GET /" { str.CancelWrite(42) return nil } u, err := url.Parse(request[4:]) if err != nil { return err } u.Scheme = "https" req := &http.Request{ Method: http.MethodGet, Proto: "HTTP/0.9", ProtoMajor: 0, ProtoMinor: 9, Body: str, URL: u, } handler := s.Handler if handler == nil { handler = http.DefaultServeMux } var panicked bool func() { defer func() { if p := recover(); p != nil { // Copied from net/http/server.go const size = 64 << 10 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] log.Printf("http: panic serving: %v\n%s", p, buf) panicked = true } }() handler.ServeHTTP(&responseWriter{Writer: str}, req) }() if panicked { if _, err := str.Write([]byte("500")); err != nil { return err } } return str.Close() } quic-go-0.59.0/interop/run_endpoint.sh000066400000000000000000000007441513066070600177020ustar00rootroot00000000000000#!/bin/bash set -e # Set up the routing needed for the simulation. /setup.sh echo "Using commit:" `cat commit.txt` if [ "$ROLE" == "client" ]; then # Wait for the simulator to start up. /wait-for-it.sh sim:57832 -s -t 10 echo "Starting QUIC client..." 
echo "Client params: $CLIENT_PARAMS" echo "Test case: $TESTCASE" QUIC_GO_LOG_LEVEL=debug ./client $CLIENT_PARAMS $REQUESTS else echo "Running QUIC server." QUIC_GO_LOG_LEVEL=debug ./server "$@" fi quic-go-0.59.0/interop/server/000077500000000000000000000000001513066070600161435ustar00rootroot00000000000000quic-go-0.59.0/interop/server/main.go000066400000000000000000000047531513066070600174270ustar00rootroot00000000000000package main import ( "crypto/tls" "fmt" "log" "net" "net/http" "os" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" "github.com/quic-go/quic-go/internal/qtls" "github.com/quic-go/quic-go/interop/http09" "github.com/quic-go/quic-go/interop/utils" ) func main() { logFile, err := os.Create("/logs/log.txt") if err != nil { fmt.Printf("Could not create log file: %s\n", err.Error()) os.Exit(1) } defer logFile.Close() log.SetOutput(logFile) keyLog, err := utils.GetSSLKeyLog() if err != nil { fmt.Printf("Could not create key log: %s\n", err.Error()) os.Exit(1) } if keyLog != nil { defer keyLog.Close() } testcase := os.Getenv("TESTCASE") quicConf := &quic.Config{ Allow0RTT: testcase == "zerortt", Tracer: utils.NewQLOGConnectionTracer, } cert, err := tls.LoadX509KeyPair("/certs/cert.pem", "/certs/priv.key") if err != nil { fmt.Println(err) os.Exit(1) } tlsConf := &tls.Config{ Certificates: []tls.Certificate{cert}, KeyLogWriter: keyLog, NextProtos: []string{http09.NextProto}, } switch testcase { case "versionnegotiation", "handshake", "retry", "transfer", "resumption", "multiconnect", "zerortt": err = runHTTP09Server(tlsConf, quicConf, testcase == "retry") case "chacha20": reset := qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256) defer reset() err = runHTTP09Server(tlsConf, quicConf, false) case "http3": tlsConf.NextProtos = []string{http3.NextProtoH3} err = runHTTP3Server(tlsConf, quicConf) default: fmt.Printf("unsupported test case: %s\n", testcase) os.Exit(127) } if err != nil { fmt.Printf("Error running server: %s\n", err.Error()) 
os.Exit(1) } } func runHTTP09Server(tlsConf *tls.Config, quicConf *quic.Config, forceRetry bool) error { http.DefaultServeMux.Handle("/", http.FileServer(http.Dir("/www"))) server := http09.Server{} udpAddr, err := net.ResolveUDPAddr("udp", ":443") if err != nil { return err } conn, err := net.ListenUDP("udp", udpAddr) if err != nil { return err } tr := &quic.Transport{ Conn: conn, VerifySourceAddress: func(net.Addr) bool { return forceRetry }, } ln, err := tr.ListenEarly(tlsConf, quicConf) if err != nil { return err } return server.ServeListener(ln) } func runHTTP3Server(tlsConf *tls.Config, quicConf *quic.Config) error { server := http3.Server{ Addr: ":443", TLSConfig: tlsConf, QUICConfig: quicConf, } http.DefaultServeMux.Handle("/", http.FileServer(http.Dir("/www"))) return server.ListenAndServe() } quic-go-0.59.0/interop/utils/000077500000000000000000000000001513066070600157755ustar00rootroot00000000000000quic-go-0.59.0/interop/utils/logging.go000066400000000000000000000026301513066070600177530ustar00rootroot00000000000000package utils import ( "bufio" "context" "fmt" "io" "log" "os" "strings" "github.com/quic-go/quic-go" h3qlog "github.com/quic-go/quic-go/http3/qlog" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/qlogwriter" ) // GetSSLKeyLog creates a file for the TLS key log func GetSSLKeyLog() (io.WriteCloser, error) { filename := os.Getenv("SSLKEYLOGFILE") if len(filename) == 0 { return nil, nil } f, err := os.Create(filename) if err != nil { return nil, err } return f, nil } // NewQLOGConnectionTracer create a qlog file in QLOGDIR func NewQLOGConnectionTracer(_ context.Context, isClient bool, connID quic.ConnectionID) qlogwriter.Trace { qlogDir := os.Getenv("QLOGDIR") if len(qlogDir) == 0 { return nil } if _, err := os.Stat(qlogDir); os.IsNotExist(err) { if err := os.MkdirAll(qlogDir, 0o666); err != nil { log.Fatalf("failed to create qlog dir %s: %v", qlogDir, err) } } path := 
fmt.Sprintf("%s/%s.sqlog", strings.TrimRight(qlogDir, "/"), connID) f, err := os.Create(path) if err != nil { log.Printf("Failed to create qlog file %s: %s", path, err.Error()) return nil } log.Printf("Created qlog file: %s\n", path) fileSeq := qlogwriter.NewConnectionFileSeq( utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), isClient, connID, []string{qlog.EventSchema, h3qlog.EventSchema}, ) go fileSeq.Run() return fileSeq } quic-go-0.59.0/metrics/000077500000000000000000000000001513066070600146235ustar00rootroot00000000000000quic-go-0.59.0/metrics/dashboards/000077500000000000000000000000001513066070600167355ustar00rootroot00000000000000quic-go-0.59.0/metrics/dashboards/README.md000066400000000000000000000015001513066070600202100ustar00rootroot00000000000000# quic-go Prometheus / Grafana Local Development Setup For local development and debugging, it can be useful to spin up a local Prometheus and Grafana instance. Please refer to the [documentation](https://quic-go.net/docs/quic/metrics/) for how to configure quic-go to expose Prometheus metrics. The configuration files in this directory assume that the application exposes the Prometheus endpoint at `http://localhost:5001/prometheus`: ```go import "github.com/prometheus/client_golang/prometheus/promhttp" go func() { http.Handle("/prometheus", promhttp.Handler()) log.Fatal(http.ListenAndServe("localhost:5001", nil)) }() ``` Prometheus and Grafana can be started using Docker Compose: Running: ```shell docker compose up ``` [quic-go.json](./quic-go.json) contains the JSON model of an example Grafana dashboard. 
quic-go-0.59.0/metrics/dashboards/datasources.yml000066400000000000000000000003171513066070600217760ustar00rootroot00000000000000apiVersion: 1 deleteDatasources: - name: Prometheus orgId: 1 datasources: - name: Prometheus orgId: 1 type: prometheus access: proxy url: http://prometheus:9090 editable: false quic-go-0.59.0/metrics/dashboards/docker-compose.yml000066400000000000000000000011061513066070600223700ustar00rootroot00000000000000version: '3.8' volumes: prometheus_data: {} grafana_data: {} services: prometheus: image: prom/prometheus:latest container_name: prometheus volumes: - ./prometheus.yml:/etc/prometheus/prometheus.yml - prometheus_data:/prometheus command: - '--config.file=/etc/prometheus/prometheus.yml' expose: - 9090 grafana: image: grafana/grafana:latest container_name: grafana volumes: - grafana_data:/var/lib/grafana - ./datasources.yml:/etc/grafana/provisioning/datasources/prom.yml ports: - "3000:3000" quic-go-0.59.0/metrics/dashboards/prometheus.yml000066400000000000000000000003041513066070600216500ustar00rootroot00000000000000global: scrape_interval: 15s scrape_configs: - job_name: 'quic-go' scrape_interval: 15s static_configs: - targets: ['host.docker.internal:5001'] metrics_path: '/prometheus' quic-go-0.59.0/metrics/dashboards/quic-go.json000066400000000000000000000550551513066070600212060ustar00rootroot00000000000000{ "__inputs": [ { "name": "DS_PROMETHEUS", "label": "Prometheus", "description": "", "type": "datasource", "pluginId": "prometheus", "pluginName": "Prometheus" } ], "__elements": {}, "__requires": [ { "type": "grafana", "id": "grafana", "name": "Grafana", "version": "10.2.3" }, { "type": "datasource", "id": "prometheus", "name": "Prometheus", "version": "1.0.0" }, { "type": "panel", "id": "stat", "name": "Stat", "version": "" }, { "type": "panel", "id": "timeseries", "name": "Time series", "version": "" } ], "annotations": { "list": [ { "builtIn": 1, "datasource": { "type": "grafana", "uid": "-- Grafana --" }, "enable": true, 
"hide": true, "iconColor": "rgba(0, 211, 255, 1)", "name": "Annotations & Alerts", "type": "dashboard" } ] }, "editable": true, "fiscalYearStartMonth": 0, "graphTooltip": 0, "id": null, "links": [], "liveNow": false, "panels": [ { "collapsed": false, "gridPos": { "h": 1, "w": 24, "x": 0, "y": 0 }, "id": 7, "panels": [], "title": "Transport", "type": "row" }, { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "fieldConfig": { "defaults": { "color": { "mode": "palette-classic" }, "custom": { "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", "axisPlacement": "auto", "barAlignment": 0, "drawStyle": "line", "fillOpacity": 0, "gradientMode": "none", "hideFrom": { "legend": false, "tooltip": false, "viz": false }, "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, "scaleDistribution": { "type": "linear" }, "showPoints": "auto", "spanNulls": false, "stacking": { "group": "A", "mode": "none" }, "thresholdsStyle": { "mode": "off" } }, "mappings": [], "thresholds": { "mode": "absolute", "steps": [ { "color": "green", "value": null }, { "color": "red", "value": 80 } ] } }, "overrides": [] }, "gridPos": { "h": 8, "w": 12, "x": 0, "y": 1 }, "id": 3, "options": { "legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true }, "tooltip": { "mode": "single", "sort": "none" } }, "targets": [ { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "sum(rate(quicgo_server_received_packets_dropped_total{instance=~\"$instance\"}[$__rate_interval])) by (reason)", "instant": false, "legendFormat": "__auto", "range": true, "refId": "A" } ], "title": "Server Dropped Packets", "type": "timeseries" }, { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "fieldConfig": { "defaults": { "color": { "mode": "palette-classic" }, "custom": { "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", 
"axisLabel": "", "axisPlacement": "auto", "barAlignment": 0, "drawStyle": "line", "fillOpacity": 0, "gradientMode": "none", "hideFrom": { "legend": false, "tooltip": false, "viz": false }, "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, "scaleDistribution": { "type": "linear" }, "showPoints": "auto", "spanNulls": false, "stacking": { "group": "A", "mode": "none" }, "thresholdsStyle": { "mode": "off" } }, "mappings": [], "thresholds": { "mode": "absolute", "steps": [ { "color": "green", "value": null }, { "color": "red", "value": 80 } ] } }, "overrides": [] }, "gridPos": { "h": 8, "w": 12, "x": 12, "y": 1 }, "id": 12, "options": { "legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true }, "tooltip": { "mode": "single", "sort": "none" } }, "targets": [ { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "sum(rate(quicgo_server_connections_rejected_total{instance=~\"$instance\"}[$__rate_interval])) by (reason)", "hide": true, "instant": false, "legendFormat": "__auto", "range": true, "refId": "A" } ], "title": "Rejected Connections", "type": "timeseries" }, { "collapsed": false, "gridPos": { "h": 1, "w": 24, "x": 0, "y": 9 }, "id": 6, "panels": [], "title": "Connection", "type": "row" }, { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "fieldConfig": { "defaults": { "color": { "mode": "thresholds" }, "mappings": [], "thresholds": { "mode": "absolute", "steps": [ { "color": "green", "value": null }, { "color": "red", "value": 80 } ] } }, "overrides": [] }, "gridPos": { "h": 8, "w": 12, "x": 0, "y": 10 }, "id": 1, "options": { "colorMode": "value", "graphMode": "area", "justifyMode": "auto", "orientation": "auto", "reduceOptions": { "calcs": [ "lastNotNull" ], "fields": "", "values": false }, "showPercentChange": false, "textMode": "auto", "wideLayout": true }, "pluginVersion": "10.2.3", "targets": [ { "datasource": { "type": 
"prometheus", "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "sum (quicgo_connections_started_total{instance=~\"$instance\"}) by (dir) - sum (quicgo_connections_closed_total{instance=~\"$instance\"}) by (dir)", "instant": false, "legendFormat": "{{dir}}", "range": true, "refId": "A" } ], "title": "Currently Active Connections", "type": "stat" }, { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "fieldConfig": { "defaults": { "color": { "mode": "palette-classic" }, "custom": { "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", "axisPlacement": "auto", "barAlignment": 0, "drawStyle": "line", "fillOpacity": 0, "gradientMode": "none", "hideFrom": { "legend": false, "tooltip": false, "viz": false }, "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, "scaleDistribution": { "type": "linear" }, "showPoints": "auto", "spanNulls": false, "stacking": { "group": "A", "mode": "none" }, "thresholdsStyle": { "mode": "off" } }, "mappings": [], "thresholds": { "mode": "absolute", "steps": [ { "color": "green", "value": null }, { "color": "red", "value": 80 } ] }, "unit": "s" }, "overrides": [] }, "gridPos": { "h": 8, "w": 12, "x": 12, "y": 10 }, "id": 5, "options": { "legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true }, "tooltip": { "mode": "multi", "sort": "none" } }, "targets": [ { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "histogram_quantile(0.5, sum(rate(quicgo_handshake_duration_seconds_bucket{instance=~\"$instance\"}[$__rate_interval])) by (le))", "instant": false, "legendFormat": "50th percentile", "range": true, "refId": "A" }, { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "histogram_quantile(0.9, sum(rate(quicgo_handshake_duration_seconds_bucket{instance=~\"$instance\"}[$__rate_interval])) by (le))", "hide": false, 
"instant": false, "legendFormat": "90th percentile", "range": true, "refId": "B" }, { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "histogram_quantile(0.95, sum(rate(quicgo_handshake_duration_seconds_bucket{instance=~\"$instance\"}[$__rate_interval])) by (le))", "hide": false, "instant": false, "legendFormat": "95th percentile", "range": true, "refId": "C" } ], "title": "Handshake Duration", "type": "timeseries" }, { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "fieldConfig": { "defaults": { "color": { "mode": "palette-classic" }, "custom": { "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", "axisPlacement": "auto", "barAlignment": 0, "drawStyle": "line", "fillOpacity": 0, "gradientMode": "none", "hideFrom": { "legend": false, "tooltip": false, "viz": false }, "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, "scaleDistribution": { "type": "linear" }, "showPoints": "auto", "spanNulls": false, "stacking": { "group": "A", "mode": "none" }, "thresholdsStyle": { "mode": "off" } }, "mappings": [], "thresholds": { "mode": "absolute", "steps": [ { "color": "green", "value": null }, { "color": "red", "value": 80 } ] } }, "overrides": [] }, "gridPos": { "h": 8, "w": 12, "x": 0, "y": 18 }, "id": 11, "options": { "legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true }, "tooltip": { "mode": "single", "sort": "none" } }, "targets": [ { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "sum(rate(quicgo_connections_closed_total{instance=~\"$instance\"}[$__rate_interval])) by (reason)", "instant": false, "legendFormat": "__auto", "range": true, "refId": "A" } ], "title": "Close Reason", "type": "timeseries" }, { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "fieldConfig": { "defaults": { "color": { "mode": "palette-classic" }, 
"custom": { "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", "axisPlacement": "auto", "barAlignment": 0, "drawStyle": "line", "fillOpacity": 0, "gradientMode": "none", "hideFrom": { "legend": false, "tooltip": false, "viz": false }, "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, "scaleDistribution": { "type": "linear" }, "showPoints": "auto", "spanNulls": false, "stacking": { "group": "A", "mode": "none" }, "thresholdsStyle": { "mode": "off" } }, "mappings": [], "thresholds": { "mode": "absolute", "steps": [ { "color": "green", "value": null }, { "color": "red", "value": 80 } ] }, "unit": "s" }, "overrides": [] }, "gridPos": { "h": 8, "w": 12, "x": 12, "y": 18 }, "id": 2, "options": { "legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true }, "tooltip": { "mode": "multi", "sort": "none" } }, "targets": [ { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "histogram_quantile(0.5, sum(rate(quicgo_connection_duration_seconds_bucket{instance=~\"$instance\"}[$__rate_interval])) by (le))\n", "hide": false, "instant": false, "legendFormat": "50th percentile", "range": true, "refId": "A" }, { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "histogram_quantile(0.9, sum(rate(quicgo_connection_duration_seconds_bucket{instance=~\"$instance\"}[$__rate_interval])) by (le))\n", "hide": false, "instant": false, "legendFormat": "90th percentile", "range": true, "refId": "B" }, { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "histogram_quantile(0.95, sum(rate(quicgo_connection_duration_seconds_bucket{instance=~\"$instance\"}[$__rate_interval])) by (le))\n", "hide": false, "instant": false, "legendFormat": "95th percentile", "range": true, "refId": "C" } ], "title": "Connection Durations", "type": "timeseries" }, { 
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "fieldConfig": { "defaults": { "color": { "mode": "palette-classic" }, "custom": { "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", "axisPlacement": "auto", "barAlignment": 0, "drawStyle": "line", "fillOpacity": 0, "gradientMode": "none", "hideFrom": { "legend": false, "tooltip": false, "viz": false }, "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, "scaleDistribution": { "type": "linear" }, "showPoints": "auto", "spanNulls": false, "stacking": { "group": "A", "mode": "none" }, "thresholdsStyle": { "mode": "off" } }, "mappings": [], "thresholds": { "mode": "absolute", "steps": [ { "color": "green", "value": null }, { "color": "red", "value": 80 } ] } }, "overrides": [] }, "gridPos": { "h": 8, "w": 12, "x": 0, "y": 26 }, "id": 13, "options": { "legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true }, "tooltip": { "mode": "single", "sort": "none" } }, "targets": [ { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "code", "expr": "sum(rate(quicgo_packets_received_total{instance=~\"$instance\"}[$__rate_interval])) by (type)", "fullMetaSearch": false, "includeNullMetadata": false, "instant": false, "legendFormat": "{{type}}", "range": true, "refId": "A", "useBackend": false } ], "title": "Packets Received", "type": "timeseries" }, { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "fieldConfig": { "defaults": { "color": { "mode": "palette-classic" }, "custom": { "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", "axisPlacement": "auto", "barAlignment": 0, "drawStyle": "line", "fillOpacity": 0, "gradientMode": "none", "hideFrom": { "legend": false, "tooltip": false, "viz": false }, "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, 
"scaleDistribution": { "type": "linear" }, "showPoints": "auto", "spanNulls": false, "stacking": { "group": "A", "mode": "none" }, "thresholdsStyle": { "mode": "off" } }, "mappings": [], "thresholds": { "mode": "absolute", "steps": [ { "color": "green", "value": null }, { "color": "red", "value": 80 } ] } }, "overrides": [] }, "gridPos": { "h": 8, "w": 12, "x": 12, "y": 26 }, "id": 15, "options": { "legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true }, "tooltip": { "mode": "single", "sort": "none" } }, "targets": [ { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "code", "expr": "sum(rate(quicgo_packets_sent_total{instance=~\"$instance\"}[$__rate_interval])) by (type)", "fullMetaSearch": false, "includeNullMetadata": false, "instant": false, "legendFormat": "{{type}}", "range": true, "refId": "A", "useBackend": false } ], "title": "Packets Sent", "type": "timeseries" } ], "refresh": "", "schemaVersion": 39, "tags": [], "templating": { "list": [ { "current": {}, "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "definition": "label_values(up,instance)", "hide": 0, "includeAll": true, "multi": true, "name": "instance", "options": [], "query": { "qryType": 1, "query": "label_values(up,instance)", "refId": "PrometheusVariableQueryEditor-VariableQuery" }, "refresh": 1, "regex": "", "skipUrlSync": false, "sort": 0, "type": "query" } ] }, "time": { "from": "now-30m", "to": "now" }, "timepicker": {}, "timezone": "", "title": "quic-go", "uid": "afd27180-618a-42ab-99fd-0508776d9c29", "version": 17, "weekStart": "" } quic-go-0.59.0/mock_ack_frame_source_test.go000066400000000000000000000054641513066070600210550ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT. 
// Source: github.com/quic-go/quic-go (interfaces: AckFrameSource) // // Generated by this command: // // mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_ack_frame_source_test.go github.com/quic-go/quic-go AckFrameSource // // Package quic is a generated GoMock package. package quic import ( reflect "reflect" monotime "github.com/quic-go/quic-go/internal/monotime" protocol "github.com/quic-go/quic-go/internal/protocol" wire "github.com/quic-go/quic-go/internal/wire" gomock "go.uber.org/mock/gomock" ) // MockAckFrameSource is a mock of AckFrameSource interface. type MockAckFrameSource struct { ctrl *gomock.Controller recorder *MockAckFrameSourceMockRecorder isgomock struct{} } // MockAckFrameSourceMockRecorder is the mock recorder for MockAckFrameSource. type MockAckFrameSourceMockRecorder struct { mock *MockAckFrameSource } // NewMockAckFrameSource creates a new mock instance. func NewMockAckFrameSource(ctrl *gomock.Controller) *MockAckFrameSource { mock := &MockAckFrameSource{ctrl: ctrl} mock.recorder = &MockAckFrameSourceMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockAckFrameSource) EXPECT() *MockAckFrameSourceMockRecorder { return m.recorder } // GetAckFrame mocks base method. func (m *MockAckFrameSource) GetAckFrame(arg0 protocol.EncryptionLevel, now monotime.Time, onlyIfQueued bool) *wire.AckFrame { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetAckFrame", arg0, now, onlyIfQueued) ret0, _ := ret[0].(*wire.AckFrame) return ret0 } // GetAckFrame indicates an expected call of GetAckFrame. 
func (mr *MockAckFrameSourceMockRecorder) GetAckFrame(arg0, now, onlyIfQueued any) *MockAckFrameSourceGetAckFrameCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockAckFrameSource)(nil).GetAckFrame), arg0, now, onlyIfQueued) return &MockAckFrameSourceGetAckFrameCall{Call: call} } // MockAckFrameSourceGetAckFrameCall wrap *gomock.Call type MockAckFrameSourceGetAckFrameCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockAckFrameSourceGetAckFrameCall) Return(arg0 *wire.AckFrame) *MockAckFrameSourceGetAckFrameCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockAckFrameSourceGetAckFrameCall) Do(f func(protocol.EncryptionLevel, monotime.Time, bool) *wire.AckFrame) *MockAckFrameSourceGetAckFrameCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockAckFrameSourceGetAckFrameCall) DoAndReturn(f func(protocol.EncryptionLevel, monotime.Time, bool) *wire.AckFrame) *MockAckFrameSourceGetAckFrameCall { c.Call = c.Call.DoAndReturn(f) return c } quic-go-0.59.0/mock_conn_runner_test.go000066400000000000000000000164211513066070600201060ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: ConnRunner) // // Generated by this command: // // mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_conn_runner_test.go github.com/quic-go/quic-go ConnRunner // // Package quic is a generated GoMock package. package quic import ( reflect "reflect" time "time" protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" ) // MockConnRunner is a mock of ConnRunner interface. type MockConnRunner struct { ctrl *gomock.Controller recorder *MockConnRunnerMockRecorder isgomock struct{} } // MockConnRunnerMockRecorder is the mock recorder for MockConnRunner. 
type MockConnRunnerMockRecorder struct { mock *MockConnRunner } // NewMockConnRunner creates a new mock instance. func NewMockConnRunner(ctrl *gomock.Controller) *MockConnRunner { mock := &MockConnRunner{ctrl: ctrl} mock.recorder = &MockConnRunnerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockConnRunner) EXPECT() *MockConnRunnerMockRecorder { return m.recorder } // Add mocks base method. func (m *MockConnRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Add", arg0, arg1) ret0, _ := ret[0].(bool) return ret0 } // Add indicates an expected call of Add. func (mr *MockConnRunnerMockRecorder) Add(arg0, arg1 any) *MockConnRunnerAddCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockConnRunner)(nil).Add), arg0, arg1) return &MockConnRunnerAddCall{Call: call} } // MockConnRunnerAddCall wrap *gomock.Call type MockConnRunnerAddCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockConnRunnerAddCall) Return(arg0 bool) *MockConnRunnerAddCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockConnRunnerAddCall) Do(f func(protocol.ConnectionID, packetHandler) bool) *MockConnRunnerAddCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockConnRunnerAddCall) DoAndReturn(f func(protocol.ConnectionID, packetHandler) bool) *MockConnRunnerAddCall { c.Call = c.Call.DoAndReturn(f) return c } // AddResetToken mocks base method. func (m *MockConnRunner) AddResetToken(arg0 protocol.StatelessResetToken, arg1 packetHandler) { m.ctrl.T.Helper() m.ctrl.Call(m, "AddResetToken", arg0, arg1) } // AddResetToken indicates an expected call of AddResetToken. 
func (mr *MockConnRunnerMockRecorder) AddResetToken(arg0, arg1 any) *MockConnRunnerAddResetTokenCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockConnRunner)(nil).AddResetToken), arg0, arg1) return &MockConnRunnerAddResetTokenCall{Call: call} } // MockConnRunnerAddResetTokenCall wrap *gomock.Call type MockConnRunnerAddResetTokenCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockConnRunnerAddResetTokenCall) Return() *MockConnRunnerAddResetTokenCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockConnRunnerAddResetTokenCall) Do(f func(protocol.StatelessResetToken, packetHandler)) *MockConnRunnerAddResetTokenCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockConnRunnerAddResetTokenCall) DoAndReturn(f func(protocol.StatelessResetToken, packetHandler)) *MockConnRunnerAddResetTokenCall { c.Call = c.Call.DoAndReturn(f) return c } // Remove mocks base method. func (m *MockConnRunner) Remove(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() m.ctrl.Call(m, "Remove", arg0) } // Remove indicates an expected call of Remove. 
func (mr *MockConnRunnerMockRecorder) Remove(arg0 any) *MockConnRunnerRemoveCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockConnRunner)(nil).Remove), arg0) return &MockConnRunnerRemoveCall{Call: call} } // MockConnRunnerRemoveCall wrap *gomock.Call type MockConnRunnerRemoveCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockConnRunnerRemoveCall) Return() *MockConnRunnerRemoveCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockConnRunnerRemoveCall) Do(f func(protocol.ConnectionID)) *MockConnRunnerRemoveCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockConnRunnerRemoveCall) DoAndReturn(f func(protocol.ConnectionID)) *MockConnRunnerRemoveCall { c.Call = c.Call.DoAndReturn(f) return c } // RemoveResetToken mocks base method. func (m *MockConnRunner) RemoveResetToken(arg0 protocol.StatelessResetToken) { m.ctrl.T.Helper() m.ctrl.Call(m, "RemoveResetToken", arg0) } // RemoveResetToken indicates an expected call of RemoveResetToken. 
func (mr *MockConnRunnerMockRecorder) RemoveResetToken(arg0 any) *MockConnRunnerRemoveResetTokenCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockConnRunner)(nil).RemoveResetToken), arg0) return &MockConnRunnerRemoveResetTokenCall{Call: call} } // MockConnRunnerRemoveResetTokenCall wrap *gomock.Call type MockConnRunnerRemoveResetTokenCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockConnRunnerRemoveResetTokenCall) Return() *MockConnRunnerRemoveResetTokenCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockConnRunnerRemoveResetTokenCall) Do(f func(protocol.StatelessResetToken)) *MockConnRunnerRemoveResetTokenCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockConnRunnerRemoveResetTokenCall) DoAndReturn(f func(protocol.StatelessResetToken)) *MockConnRunnerRemoveResetTokenCall { c.Call = c.Call.DoAndReturn(f) return c } // ReplaceWithClosed mocks base method. func (m *MockConnRunner) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 []byte, arg2 time.Duration) { m.ctrl.T.Helper() m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1, arg2) } // ReplaceWithClosed indicates an expected call of ReplaceWithClosed. 
func (mr *MockConnRunnerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 any) *MockConnRunnerReplaceWithClosedCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockConnRunner)(nil).ReplaceWithClosed), arg0, arg1, arg2) return &MockConnRunnerReplaceWithClosedCall{Call: call} } // MockConnRunnerReplaceWithClosedCall wrap *gomock.Call type MockConnRunnerReplaceWithClosedCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockConnRunnerReplaceWithClosedCall) Return() *MockConnRunnerReplaceWithClosedCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockConnRunnerReplaceWithClosedCall) Do(f func([]protocol.ConnectionID, []byte, time.Duration)) *MockConnRunnerReplaceWithClosedCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockConnRunnerReplaceWithClosedCall) DoAndReturn(f func([]protocol.ConnectionID, []byte, time.Duration)) *MockConnRunnerReplaceWithClosedCall { c.Call = c.Call.DoAndReturn(f) return c } quic-go-0.59.0/mock_frame_source_test.go000066400000000000000000000103131513066070600202240ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: FrameSource) // // Generated by this command: // // mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_frame_source_test.go github.com/quic-go/quic-go FrameSource // // Package quic is a generated GoMock package. package quic import ( reflect "reflect" ackhandler "github.com/quic-go/quic-go/internal/ackhandler" monotime "github.com/quic-go/quic-go/internal/monotime" protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" ) // MockFrameSource is a mock of FrameSource interface. 
type MockFrameSource struct { ctrl *gomock.Controller recorder *MockFrameSourceMockRecorder isgomock struct{} } // MockFrameSourceMockRecorder is the mock recorder for MockFrameSource. type MockFrameSourceMockRecorder struct { mock *MockFrameSource } // NewMockFrameSource creates a new mock instance. func NewMockFrameSource(ctrl *gomock.Controller) *MockFrameSource { mock := &MockFrameSource{ctrl: ctrl} mock.recorder = &MockFrameSourceMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockFrameSource) EXPECT() *MockFrameSourceMockRecorder { return m.recorder } // Append mocks base method. func (m *MockFrameSource) Append(arg0 []ackhandler.Frame, arg1 []ackhandler.StreamFrame, arg2 protocol.ByteCount, arg3 monotime.Time, arg4 protocol.Version) ([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Append", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].([]ackhandler.Frame) ret1, _ := ret[1].([]ackhandler.StreamFrame) ret2, _ := ret[2].(protocol.ByteCount) return ret0, ret1, ret2 } // Append indicates an expected call of Append. 
func (mr *MockFrameSourceMockRecorder) Append(arg0, arg1, arg2, arg3, arg4 any) *MockFrameSourceAppendCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Append", reflect.TypeOf((*MockFrameSource)(nil).Append), arg0, arg1, arg2, arg3, arg4) return &MockFrameSourceAppendCall{Call: call} } // MockFrameSourceAppendCall wrap *gomock.Call type MockFrameSourceAppendCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockFrameSourceAppendCall) Return(arg0 []ackhandler.Frame, arg1 []ackhandler.StreamFrame, arg2 protocol.ByteCount) *MockFrameSourceAppendCall { c.Call = c.Call.Return(arg0, arg1, arg2) return c } // Do rewrite *gomock.Call.Do func (c *MockFrameSourceAppendCall) Do(f func([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount, monotime.Time, protocol.Version) ([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount)) *MockFrameSourceAppendCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockFrameSourceAppendCall) DoAndReturn(f func([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount, monotime.Time, protocol.Version) ([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount)) *MockFrameSourceAppendCall { c.Call = c.Call.DoAndReturn(f) return c } // HasData mocks base method. func (m *MockFrameSource) HasData() bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "HasData") ret0, _ := ret[0].(bool) return ret0 } // HasData indicates an expected call of HasData. 
func (mr *MockFrameSourceMockRecorder) HasData() *MockFrameSourceHasDataCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasData", reflect.TypeOf((*MockFrameSource)(nil).HasData)) return &MockFrameSourceHasDataCall{Call: call} } // MockFrameSourceHasDataCall wrap *gomock.Call type MockFrameSourceHasDataCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockFrameSourceHasDataCall) Return(arg0 bool) *MockFrameSourceHasDataCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockFrameSourceHasDataCall) Do(f func() bool) *MockFrameSourceHasDataCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockFrameSourceHasDataCall) DoAndReturn(f func() bool) *MockFrameSourceHasDataCall { c.Call = c.Call.DoAndReturn(f) return c } quic-go-0.59.0/mock_mtu_discoverer_test.go000066400000000000000000000167601513066070600206200ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: MTUDiscoverer) // // Generated by this command: // // mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_mtu_discoverer_test.go github.com/quic-go/quic-go MTUDiscoverer // // Package quic is a generated GoMock package. package quic import ( reflect "reflect" ackhandler "github.com/quic-go/quic-go/internal/ackhandler" monotime "github.com/quic-go/quic-go/internal/monotime" protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" ) // MockMTUDiscoverer is a mock of MTUDiscoverer interface. type MockMTUDiscoverer struct { ctrl *gomock.Controller recorder *MockMTUDiscovererMockRecorder isgomock struct{} } // MockMTUDiscovererMockRecorder is the mock recorder for MockMTUDiscoverer. type MockMTUDiscovererMockRecorder struct { mock *MockMTUDiscoverer } // NewMockMTUDiscoverer creates a new mock instance. 
func NewMockMTUDiscoverer(ctrl *gomock.Controller) *MockMTUDiscoverer { mock := &MockMTUDiscoverer{ctrl: ctrl} mock.recorder = &MockMTUDiscovererMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockMTUDiscoverer) EXPECT() *MockMTUDiscovererMockRecorder { return m.recorder } // CurrentSize mocks base method. func (m *MockMTUDiscoverer) CurrentSize() protocol.ByteCount { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CurrentSize") ret0, _ := ret[0].(protocol.ByteCount) return ret0 } // CurrentSize indicates an expected call of CurrentSize. func (mr *MockMTUDiscovererMockRecorder) CurrentSize() *MockMTUDiscovererCurrentSizeCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CurrentSize", reflect.TypeOf((*MockMTUDiscoverer)(nil).CurrentSize)) return &MockMTUDiscovererCurrentSizeCall{Call: call} } // MockMTUDiscovererCurrentSizeCall wrap *gomock.Call type MockMTUDiscovererCurrentSizeCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockMTUDiscovererCurrentSizeCall) Return(arg0 protocol.ByteCount) *MockMTUDiscovererCurrentSizeCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockMTUDiscovererCurrentSizeCall) Do(f func() protocol.ByteCount) *MockMTUDiscovererCurrentSizeCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockMTUDiscovererCurrentSizeCall) DoAndReturn(f func() protocol.ByteCount) *MockMTUDiscovererCurrentSizeCall { c.Call = c.Call.DoAndReturn(f) return c } // GetPing mocks base method. func (m *MockMTUDiscoverer) GetPing(now monotime.Time) (ackhandler.Frame, protocol.ByteCount) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetPing", now) ret0, _ := ret[0].(ackhandler.Frame) ret1, _ := ret[1].(protocol.ByteCount) return ret0, ret1 } // GetPing indicates an expected call of GetPing. 
func (mr *MockMTUDiscovererMockRecorder) GetPing(now any) *MockMTUDiscovererGetPingCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPing", reflect.TypeOf((*MockMTUDiscoverer)(nil).GetPing), now) return &MockMTUDiscovererGetPingCall{Call: call} } // MockMTUDiscovererGetPingCall wrap *gomock.Call type MockMTUDiscovererGetPingCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockMTUDiscovererGetPingCall) Return(ping ackhandler.Frame, datagramSize protocol.ByteCount) *MockMTUDiscovererGetPingCall { c.Call = c.Call.Return(ping, datagramSize) return c } // Do rewrite *gomock.Call.Do func (c *MockMTUDiscovererGetPingCall) Do(f func(monotime.Time) (ackhandler.Frame, protocol.ByteCount)) *MockMTUDiscovererGetPingCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockMTUDiscovererGetPingCall) DoAndReturn(f func(monotime.Time) (ackhandler.Frame, protocol.ByteCount)) *MockMTUDiscovererGetPingCall { c.Call = c.Call.DoAndReturn(f) return c } // Reset mocks base method. func (m *MockMTUDiscoverer) Reset(now monotime.Time, start, max protocol.ByteCount) { m.ctrl.T.Helper() m.ctrl.Call(m, "Reset", now, start, max) } // Reset indicates an expected call of Reset. 
func (mr *MockMTUDiscovererMockRecorder) Reset(now, start, max any) *MockMTUDiscovererResetCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Reset", reflect.TypeOf((*MockMTUDiscoverer)(nil).Reset), now, start, max) return &MockMTUDiscovererResetCall{Call: call} } // MockMTUDiscovererResetCall wrap *gomock.Call type MockMTUDiscovererResetCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockMTUDiscovererResetCall) Return() *MockMTUDiscovererResetCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockMTUDiscovererResetCall) Do(f func(monotime.Time, protocol.ByteCount, protocol.ByteCount)) *MockMTUDiscovererResetCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockMTUDiscovererResetCall) DoAndReturn(f func(monotime.Time, protocol.ByteCount, protocol.ByteCount)) *MockMTUDiscovererResetCall { c.Call = c.Call.DoAndReturn(f) return c } // ShouldSendProbe mocks base method. func (m *MockMTUDiscoverer) ShouldSendProbe(now monotime.Time) bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ShouldSendProbe", now) ret0, _ := ret[0].(bool) return ret0 } // ShouldSendProbe indicates an expected call of ShouldSendProbe. 
func (mr *MockMTUDiscovererMockRecorder) ShouldSendProbe(now any) *MockMTUDiscovererShouldSendProbeCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShouldSendProbe", reflect.TypeOf((*MockMTUDiscoverer)(nil).ShouldSendProbe), now) return &MockMTUDiscovererShouldSendProbeCall{Call: call} } // MockMTUDiscovererShouldSendProbeCall wrap *gomock.Call type MockMTUDiscovererShouldSendProbeCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockMTUDiscovererShouldSendProbeCall) Return(arg0 bool) *MockMTUDiscovererShouldSendProbeCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockMTUDiscovererShouldSendProbeCall) Do(f func(monotime.Time) bool) *MockMTUDiscovererShouldSendProbeCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockMTUDiscovererShouldSendProbeCall) DoAndReturn(f func(monotime.Time) bool) *MockMTUDiscovererShouldSendProbeCall { c.Call = c.Call.DoAndReturn(f) return c } // Start mocks base method. func (m *MockMTUDiscoverer) Start(now monotime.Time) { m.ctrl.T.Helper() m.ctrl.Call(m, "Start", now) } // Start indicates an expected call of Start. 
func (mr *MockMTUDiscovererMockRecorder) Start(now any) *MockMTUDiscovererStartCall {
	mr.mock.ctrl.T.Helper()
	call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockMTUDiscoverer)(nil).Start), now)
	return &MockMTUDiscovererStartCall{Call: call}
}

// MockMTUDiscovererStartCall wraps the *gomock.Call recorded for Start so that
// Return, Do and DoAndReturn accept Start's concrete types.
type MockMTUDiscovererStartCall struct {
	*gomock.Call
}

// Return forwards to (*gomock.Call).Return and returns c for chaining.
// Start has no results, so Return takes no arguments.
func (c *MockMTUDiscovererStartCall) Return() *MockMTUDiscovererStartCall {
	c.Call = c.Call.Return()
	return c
}

// Do forwards f, typed with Start's signature, to (*gomock.Call).Do and
// returns c for chaining.
func (c *MockMTUDiscovererStartCall) Do(f func(monotime.Time)) *MockMTUDiscovererStartCall {
	c.Call = c.Call.Do(f)
	return c
}

// DoAndReturn forwards f, typed with Start's signature, to
// (*gomock.Call).DoAndReturn and returns c for chaining.
func (c *MockMTUDiscovererStartCall) DoAndReturn(f func(monotime.Time)) *MockMTUDiscovererStartCall {
	c.Call = c.Call.DoAndReturn(f)
	return c
}
quic-go-0.59.0/mock_packer_test.go000066400000000000000000000366661513066070600170420ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/quic-go/quic-go (interfaces: Packer)
//
// Generated by this command:
//
//	mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_packer_test.go github.com/quic-go/quic-go Packer
//

// Package quic is a generated GoMock package.
package quic

import (
	reflect "reflect"

	ackhandler "github.com/quic-go/quic-go/internal/ackhandler"
	monotime "github.com/quic-go/quic-go/internal/monotime"
	protocol "github.com/quic-go/quic-go/internal/protocol"
	qerr "github.com/quic-go/quic-go/internal/qerr"
	gomock "go.uber.org/mock/gomock"
)

// MockPacker is a gomock-based mock of the Packer interface. Every mocked
// method reports to ctrl; expectations are declared through EXPECT().
type MockPacker struct {
	ctrl     *gomock.Controller      // controller that mocked methods call into
	recorder *MockPackerMockRecorder // returned by EXPECT
	// NOTE(review): zero-size marker field emitted by mockgen; presumably used
	// to identify generated mocks — confirm against the gomock release notes.
	isgomock struct{}
}

// MockPackerMockRecorder is the mock recorder for MockPacker: its methods
// register expected calls on the wrapped mock's controller.
type MockPackerMockRecorder struct {
	mock *MockPacker
}

// NewMockPacker creates a new mock instance.
func NewMockPacker(ctrl *gomock.Controller) *MockPacker { mock := &MockPacker{ctrl: ctrl} mock.recorder = &MockPackerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockPacker) EXPECT() *MockPackerMockRecorder { return m.recorder } // AppendPacket mocks base method. func (m *MockPacker) AppendPacket(arg0 *packetBuffer, maxPacketSize protocol.ByteCount, now monotime.Time, v protocol.Version) (shortHeaderPacket, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AppendPacket", arg0, maxPacketSize, now, v) ret0, _ := ret[0].(shortHeaderPacket) ret1, _ := ret[1].(error) return ret0, ret1 } // AppendPacket indicates an expected call of AppendPacket. func (mr *MockPackerMockRecorder) AppendPacket(arg0, maxPacketSize, now, v any) *MockPackerAppendPacketCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendPacket", reflect.TypeOf((*MockPacker)(nil).AppendPacket), arg0, maxPacketSize, now, v) return &MockPackerAppendPacketCall{Call: call} } // MockPackerAppendPacketCall wrap *gomock.Call type MockPackerAppendPacketCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockPackerAppendPacketCall) Return(arg0 shortHeaderPacket, arg1 error) *MockPackerAppendPacketCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockPackerAppendPacketCall) Do(f func(*packetBuffer, protocol.ByteCount, monotime.Time, protocol.Version) (shortHeaderPacket, error)) *MockPackerAppendPacketCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockPackerAppendPacketCall) DoAndReturn(f func(*packetBuffer, protocol.ByteCount, monotime.Time, protocol.Version) (shortHeaderPacket, error)) *MockPackerAppendPacketCall { c.Call = c.Call.DoAndReturn(f) return c } // PackAckOnlyPacket mocks base method. 
func (m *MockPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, now monotime.Time, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PackAckOnlyPacket", maxPacketSize, now, v) ret0, _ := ret[0].(shortHeaderPacket) ret1, _ := ret[1].(*packetBuffer) ret2, _ := ret[2].(error) return ret0, ret1, ret2 } // PackAckOnlyPacket indicates an expected call of PackAckOnlyPacket. func (mr *MockPackerMockRecorder) PackAckOnlyPacket(maxPacketSize, now, v any) *MockPackerPackAckOnlyPacketCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackAckOnlyPacket", reflect.TypeOf((*MockPacker)(nil).PackAckOnlyPacket), maxPacketSize, now, v) return &MockPackerPackAckOnlyPacketCall{Call: call} } // MockPackerPackAckOnlyPacketCall wrap *gomock.Call type MockPackerPackAckOnlyPacketCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockPackerPackAckOnlyPacketCall) Return(arg0 shortHeaderPacket, arg1 *packetBuffer, arg2 error) *MockPackerPackAckOnlyPacketCall { c.Call = c.Call.Return(arg0, arg1, arg2) return c } // Do rewrite *gomock.Call.Do func (c *MockPackerPackAckOnlyPacketCall) Do(f func(protocol.ByteCount, monotime.Time, protocol.Version) (shortHeaderPacket, *packetBuffer, error)) *MockPackerPackAckOnlyPacketCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockPackerPackAckOnlyPacketCall) DoAndReturn(f func(protocol.ByteCount, monotime.Time, protocol.Version) (shortHeaderPacket, *packetBuffer, error)) *MockPackerPackAckOnlyPacketCall { c.Call = c.Call.DoAndReturn(f) return c } // PackApplicationClose mocks base method. 
func (m *MockPacker) PackApplicationClose(arg0 *qerr.ApplicationError, arg1 protocol.ByteCount, arg2 protocol.Version) (*coalescedPacket, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PackApplicationClose", arg0, arg1, arg2) ret0, _ := ret[0].(*coalescedPacket) ret1, _ := ret[1].(error) return ret0, ret1 } // PackApplicationClose indicates an expected call of PackApplicationClose. func (mr *MockPackerMockRecorder) PackApplicationClose(arg0, arg1, arg2 any) *MockPackerPackApplicationCloseCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackApplicationClose", reflect.TypeOf((*MockPacker)(nil).PackApplicationClose), arg0, arg1, arg2) return &MockPackerPackApplicationCloseCall{Call: call} } // MockPackerPackApplicationCloseCall wrap *gomock.Call type MockPackerPackApplicationCloseCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockPackerPackApplicationCloseCall) Return(arg0 *coalescedPacket, arg1 error) *MockPackerPackApplicationCloseCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockPackerPackApplicationCloseCall) Do(f func(*qerr.ApplicationError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)) *MockPackerPackApplicationCloseCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockPackerPackApplicationCloseCall) DoAndReturn(f func(*qerr.ApplicationError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)) *MockPackerPackApplicationCloseCall { c.Call = c.Call.DoAndReturn(f) return c } // PackCoalescedPacket mocks base method. 
func (m *MockPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, now monotime.Time, v protocol.Version) (*coalescedPacket, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PackCoalescedPacket", onlyAck, maxPacketSize, now, v) ret0, _ := ret[0].(*coalescedPacket) ret1, _ := ret[1].(error) return ret0, ret1 } // PackCoalescedPacket indicates an expected call of PackCoalescedPacket. func (mr *MockPackerMockRecorder) PackCoalescedPacket(onlyAck, maxPacketSize, now, v any) *MockPackerPackCoalescedPacketCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackCoalescedPacket", reflect.TypeOf((*MockPacker)(nil).PackCoalescedPacket), onlyAck, maxPacketSize, now, v) return &MockPackerPackCoalescedPacketCall{Call: call} } // MockPackerPackCoalescedPacketCall wrap *gomock.Call type MockPackerPackCoalescedPacketCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockPackerPackCoalescedPacketCall) Return(arg0 *coalescedPacket, arg1 error) *MockPackerPackCoalescedPacketCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockPackerPackCoalescedPacketCall) Do(f func(bool, protocol.ByteCount, monotime.Time, protocol.Version) (*coalescedPacket, error)) *MockPackerPackCoalescedPacketCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockPackerPackCoalescedPacketCall) DoAndReturn(f func(bool, protocol.ByteCount, monotime.Time, protocol.Version) (*coalescedPacket, error)) *MockPackerPackCoalescedPacketCall { c.Call = c.Call.DoAndReturn(f) return c } // PackConnectionClose mocks base method. 
func (m *MockPacker) PackConnectionClose(arg0 *qerr.TransportError, arg1 protocol.ByteCount, arg2 protocol.Version) (*coalescedPacket, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PackConnectionClose", arg0, arg1, arg2) ret0, _ := ret[0].(*coalescedPacket) ret1, _ := ret[1].(error) return ret0, ret1 } // PackConnectionClose indicates an expected call of PackConnectionClose. func (mr *MockPackerMockRecorder) PackConnectionClose(arg0, arg1, arg2 any) *MockPackerPackConnectionCloseCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackConnectionClose", reflect.TypeOf((*MockPacker)(nil).PackConnectionClose), arg0, arg1, arg2) return &MockPackerPackConnectionCloseCall{Call: call} } // MockPackerPackConnectionCloseCall wrap *gomock.Call type MockPackerPackConnectionCloseCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockPackerPackConnectionCloseCall) Return(arg0 *coalescedPacket, arg1 error) *MockPackerPackConnectionCloseCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockPackerPackConnectionCloseCall) Do(f func(*qerr.TransportError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)) *MockPackerPackConnectionCloseCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockPackerPackConnectionCloseCall) DoAndReturn(f func(*qerr.TransportError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)) *MockPackerPackConnectionCloseCall { c.Call = c.Call.DoAndReturn(f) return c } // PackMTUProbePacket mocks base method. 
func (m *MockPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PackMTUProbePacket", ping, size, v) ret0, _ := ret[0].(shortHeaderPacket) ret1, _ := ret[1].(*packetBuffer) ret2, _ := ret[2].(error) return ret0, ret1, ret2 } // PackMTUProbePacket indicates an expected call of PackMTUProbePacket. func (mr *MockPackerMockRecorder) PackMTUProbePacket(ping, size, v any) *MockPackerPackMTUProbePacketCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackMTUProbePacket", reflect.TypeOf((*MockPacker)(nil).PackMTUProbePacket), ping, size, v) return &MockPackerPackMTUProbePacketCall{Call: call} } // MockPackerPackMTUProbePacketCall wrap *gomock.Call type MockPackerPackMTUProbePacketCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockPackerPackMTUProbePacketCall) Return(arg0 shortHeaderPacket, arg1 *packetBuffer, arg2 error) *MockPackerPackMTUProbePacketCall { c.Call = c.Call.Return(arg0, arg1, arg2) return c } // Do rewrite *gomock.Call.Do func (c *MockPackerPackMTUProbePacketCall) Do(f func(ackhandler.Frame, protocol.ByteCount, protocol.Version) (shortHeaderPacket, *packetBuffer, error)) *MockPackerPackMTUProbePacketCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockPackerPackMTUProbePacketCall) DoAndReturn(f func(ackhandler.Frame, protocol.ByteCount, protocol.Version) (shortHeaderPacket, *packetBuffer, error)) *MockPackerPackMTUProbePacketCall { c.Call = c.Call.DoAndReturn(f) return c } // PackPTOProbePacket mocks base method. 
func (m *MockPacker) PackPTOProbePacket(arg0 protocol.EncryptionLevel, arg1 protocol.ByteCount, addPingIfEmpty bool, now monotime.Time, v protocol.Version) (*coalescedPacket, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PackPTOProbePacket", arg0, arg1, addPingIfEmpty, now, v) ret0, _ := ret[0].(*coalescedPacket) ret1, _ := ret[1].(error) return ret0, ret1 } // PackPTOProbePacket indicates an expected call of PackPTOProbePacket. func (mr *MockPackerMockRecorder) PackPTOProbePacket(arg0, arg1, addPingIfEmpty, now, v any) *MockPackerPackPTOProbePacketCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackPTOProbePacket", reflect.TypeOf((*MockPacker)(nil).PackPTOProbePacket), arg0, arg1, addPingIfEmpty, now, v) return &MockPackerPackPTOProbePacketCall{Call: call} } // MockPackerPackPTOProbePacketCall wrap *gomock.Call type MockPackerPackPTOProbePacketCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockPackerPackPTOProbePacketCall) Return(arg0 *coalescedPacket, arg1 error) *MockPackerPackPTOProbePacketCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockPackerPackPTOProbePacketCall) Do(f func(protocol.EncryptionLevel, protocol.ByteCount, bool, monotime.Time, protocol.Version) (*coalescedPacket, error)) *MockPackerPackPTOProbePacketCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockPackerPackPTOProbePacketCall) DoAndReturn(f func(protocol.EncryptionLevel, protocol.ByteCount, bool, monotime.Time, protocol.Version) (*coalescedPacket, error)) *MockPackerPackPTOProbePacketCall { c.Call = c.Call.DoAndReturn(f) return c } // PackPathProbePacket mocks base method. 
func (m *MockPacker) PackPathProbePacket(arg0 protocol.ConnectionID, arg1 []ackhandler.Frame, arg2 protocol.Version) (shortHeaderPacket, *packetBuffer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PackPathProbePacket", arg0, arg1, arg2) ret0, _ := ret[0].(shortHeaderPacket) ret1, _ := ret[1].(*packetBuffer) ret2, _ := ret[2].(error) return ret0, ret1, ret2 } // PackPathProbePacket indicates an expected call of PackPathProbePacket. func (mr *MockPackerMockRecorder) PackPathProbePacket(arg0, arg1, arg2 any) *MockPackerPackPathProbePacketCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackPathProbePacket", reflect.TypeOf((*MockPacker)(nil).PackPathProbePacket), arg0, arg1, arg2) return &MockPackerPackPathProbePacketCall{Call: call} } // MockPackerPackPathProbePacketCall wrap *gomock.Call type MockPackerPackPathProbePacketCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockPackerPackPathProbePacketCall) Return(arg0 shortHeaderPacket, arg1 *packetBuffer, arg2 error) *MockPackerPackPathProbePacketCall { c.Call = c.Call.Return(arg0, arg1, arg2) return c } // Do rewrite *gomock.Call.Do func (c *MockPackerPackPathProbePacketCall) Do(f func(protocol.ConnectionID, []ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error)) *MockPackerPackPathProbePacketCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockPackerPackPathProbePacketCall) DoAndReturn(f func(protocol.ConnectionID, []ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error)) *MockPackerPackPathProbePacketCall { c.Call = c.Call.DoAndReturn(f) return c } // SetToken mocks base method. func (m *MockPacker) SetToken(arg0 []byte) { m.ctrl.T.Helper() m.ctrl.Call(m, "SetToken", arg0) } // SetToken indicates an expected call of SetToken. 
func (mr *MockPackerMockRecorder) SetToken(arg0 any) *MockPackerSetTokenCall {
	mr.mock.ctrl.T.Helper()
	call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetToken", reflect.TypeOf((*MockPacker)(nil).SetToken), arg0)
	return &MockPackerSetTokenCall{Call: call}
}

// MockPackerSetTokenCall wraps the *gomock.Call recorded for SetToken so that
// Return, Do and DoAndReturn accept SetToken's concrete types.
type MockPackerSetTokenCall struct {
	*gomock.Call
}

// Return forwards to (*gomock.Call).Return and returns c for chaining.
// SetToken has no results, so Return takes no arguments.
func (c *MockPackerSetTokenCall) Return() *MockPackerSetTokenCall {
	c.Call = c.Call.Return()
	return c
}

// Do forwards f, typed with SetToken's signature, to (*gomock.Call).Do and
// returns c for chaining.
func (c *MockPackerSetTokenCall) Do(f func([]byte)) *MockPackerSetTokenCall {
	c.Call = c.Call.Do(f)
	return c
}

// DoAndReturn forwards f, typed with SetToken's signature, to
// (*gomock.Call).DoAndReturn and returns c for chaining.
func (c *MockPackerSetTokenCall) DoAndReturn(f func([]byte)) *MockPackerSetTokenCall {
	c.Call = c.Call.DoAndReturn(f)
	return c
}
quic-go-0.59.0/mock_packet_handler_test.go000066400000000000000000000115551513066070600205270ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/quic-go/quic-go (interfaces: PacketHandler)
//
// Generated by this command:
//
//	mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_test.go github.com/quic-go/quic-go PacketHandler
//

// Package quic is a generated GoMock package.
package quic

import (
	reflect "reflect"

	qerr "github.com/quic-go/quic-go/internal/qerr"
	gomock "go.uber.org/mock/gomock"
)

// MockPacketHandler is a gomock-based mock of the PacketHandler interface.
// Every mocked method reports to ctrl; expectations are declared through EXPECT().
type MockPacketHandler struct {
	ctrl     *gomock.Controller             // controller that mocked methods call into
	recorder *MockPacketHandlerMockRecorder // returned by EXPECT
	// NOTE(review): zero-size marker field emitted by mockgen; presumably used
	// to identify generated mocks — confirm against the gomock release notes.
	isgomock struct{}
}

// MockPacketHandlerMockRecorder is the mock recorder for MockPacketHandler:
// its methods register expected calls on the wrapped mock's controller.
type MockPacketHandlerMockRecorder struct {
	mock *MockPacketHandler
}

// NewMockPacketHandler creates a new mock instance.
func NewMockPacketHandler(ctrl *gomock.Controller) *MockPacketHandler { mock := &MockPacketHandler{ctrl: ctrl} mock.recorder = &MockPacketHandlerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockPacketHandler) EXPECT() *MockPacketHandlerMockRecorder { return m.recorder } // closeWithTransportError mocks base method. func (m *MockPacketHandler) closeWithTransportError(arg0 qerr.TransportErrorCode) { m.ctrl.T.Helper() m.ctrl.Call(m, "closeWithTransportError", arg0) } // closeWithTransportError indicates an expected call of closeWithTransportError. func (mr *MockPacketHandlerMockRecorder) closeWithTransportError(arg0 any) *MockPacketHandlercloseWithTransportErrorCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeWithTransportError", reflect.TypeOf((*MockPacketHandler)(nil).closeWithTransportError), arg0) return &MockPacketHandlercloseWithTransportErrorCall{Call: call} } // MockPacketHandlercloseWithTransportErrorCall wrap *gomock.Call type MockPacketHandlercloseWithTransportErrorCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockPacketHandlercloseWithTransportErrorCall) Return() *MockPacketHandlercloseWithTransportErrorCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockPacketHandlercloseWithTransportErrorCall) Do(f func(qerr.TransportErrorCode)) *MockPacketHandlercloseWithTransportErrorCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockPacketHandlercloseWithTransportErrorCall) DoAndReturn(f func(qerr.TransportErrorCode)) *MockPacketHandlercloseWithTransportErrorCall { c.Call = c.Call.DoAndReturn(f) return c } // destroy mocks base method. func (m *MockPacketHandler) destroy(arg0 error) { m.ctrl.T.Helper() m.ctrl.Call(m, "destroy", arg0) } // destroy indicates an expected call of destroy. 
func (mr *MockPacketHandlerMockRecorder) destroy(arg0 any) *MockPacketHandlerdestroyCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockPacketHandler)(nil).destroy), arg0) return &MockPacketHandlerdestroyCall{Call: call} } // MockPacketHandlerdestroyCall wrap *gomock.Call type MockPacketHandlerdestroyCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockPacketHandlerdestroyCall) Return() *MockPacketHandlerdestroyCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockPacketHandlerdestroyCall) Do(f func(error)) *MockPacketHandlerdestroyCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockPacketHandlerdestroyCall) DoAndReturn(f func(error)) *MockPacketHandlerdestroyCall { c.Call = c.Call.DoAndReturn(f) return c } // handlePacket mocks base method. func (m *MockPacketHandler) handlePacket(arg0 receivedPacket) { m.ctrl.T.Helper() m.ctrl.Call(m, "handlePacket", arg0) } // handlePacket indicates an expected call of handlePacket. 
func (mr *MockPacketHandlerMockRecorder) handlePacket(arg0 any) *MockPacketHandlerhandlePacketCall {
	mr.mock.ctrl.T.Helper()
	call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockPacketHandler)(nil).handlePacket), arg0)
	return &MockPacketHandlerhandlePacketCall{Call: call}
}

// MockPacketHandlerhandlePacketCall wraps the *gomock.Call recorded for
// handlePacket so that Return, Do and DoAndReturn accept its concrete types.
type MockPacketHandlerhandlePacketCall struct {
	*gomock.Call
}

// Return forwards to (*gomock.Call).Return and returns c for chaining.
// handlePacket has no results, so Return takes no arguments.
func (c *MockPacketHandlerhandlePacketCall) Return() *MockPacketHandlerhandlePacketCall {
	c.Call = c.Call.Return()
	return c
}

// Do forwards f, typed with handlePacket's signature, to (*gomock.Call).Do and
// returns c for chaining.
func (c *MockPacketHandlerhandlePacketCall) Do(f func(receivedPacket)) *MockPacketHandlerhandlePacketCall {
	c.Call = c.Call.Do(f)
	return c
}

// DoAndReturn forwards f, typed with handlePacket's signature, to
// (*gomock.Call).DoAndReturn and returns c for chaining.
func (c *MockPacketHandlerhandlePacketCall) DoAndReturn(f func(receivedPacket)) *MockPacketHandlerhandlePacketCall {
	c.Call = c.Call.DoAndReturn(f)
	return c
}
quic-go-0.59.0/mock_packetconn_test.go000066400000000000000000000226241513066070600177070ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT.
// Source: net (interfaces: PacketConn)
//
// Generated by this command:
//
//	mockgen -typed -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_packetconn_test.go net PacketConn
//

// Package quic is a generated GoMock package.
package quic

// NOTE(review): the recorded mockgen command above passes -self_package twice
// with the same value; harmless, but likely a typo in the go:generate
// directive — confirm and dedupe at the source before regenerating.
import (
	net "net"
	reflect "reflect"
	time "time"

	gomock "go.uber.org/mock/gomock"
)

// MockPacketConn is a gomock-based mock of the net.PacketConn interface.
// Every mocked method reports to ctrl; expectations are declared through EXPECT().
type MockPacketConn struct {
	ctrl     *gomock.Controller          // controller that mocked methods call into
	recorder *MockPacketConnMockRecorder // returned by EXPECT
	// NOTE(review): zero-size marker field emitted by mockgen; presumably used
	// to identify generated mocks — confirm against the gomock release notes.
	isgomock struct{}
}

// MockPacketConnMockRecorder is the mock recorder for MockPacketConn: its
// methods register expected calls on the wrapped mock's controller.
type MockPacketConnMockRecorder struct {
	mock *MockPacketConn
}

// NewMockPacketConn creates a new mock instance.
func NewMockPacketConn(ctrl *gomock.Controller) *MockPacketConn { mock := &MockPacketConn{ctrl: ctrl} mock.recorder = &MockPacketConnMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockPacketConn) EXPECT() *MockPacketConnMockRecorder { return m.recorder } // Close mocks base method. func (m *MockPacketConn) Close() error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Close") ret0, _ := ret[0].(error) return ret0 } // Close indicates an expected call of Close. func (mr *MockPacketConnMockRecorder) Close() *MockPacketConnCloseCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketConn)(nil).Close)) return &MockPacketConnCloseCall{Call: call} } // MockPacketConnCloseCall wrap *gomock.Call type MockPacketConnCloseCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockPacketConnCloseCall) Return(arg0 error) *MockPacketConnCloseCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockPacketConnCloseCall) Do(f func() error) *MockPacketConnCloseCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockPacketConnCloseCall) DoAndReturn(f func() error) *MockPacketConnCloseCall { c.Call = c.Call.DoAndReturn(f) return c } // LocalAddr mocks base method. func (m *MockPacketConn) LocalAddr() net.Addr { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LocalAddr") ret0, _ := ret[0].(net.Addr) return ret0 } // LocalAddr indicates an expected call of LocalAddr. 
func (mr *MockPacketConnMockRecorder) LocalAddr() *MockPacketConnLocalAddrCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockPacketConn)(nil).LocalAddr)) return &MockPacketConnLocalAddrCall{Call: call} } // MockPacketConnLocalAddrCall wrap *gomock.Call type MockPacketConnLocalAddrCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockPacketConnLocalAddrCall) Return(arg0 net.Addr) *MockPacketConnLocalAddrCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockPacketConnLocalAddrCall) Do(f func() net.Addr) *MockPacketConnLocalAddrCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockPacketConnLocalAddrCall) DoAndReturn(f func() net.Addr) *MockPacketConnLocalAddrCall { c.Call = c.Call.DoAndReturn(f) return c } // ReadFrom mocks base method. func (m *MockPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ReadFrom", p) ret0, _ := ret[0].(int) ret1, _ := ret[1].(net.Addr) ret2, _ := ret[2].(error) return ret0, ret1, ret2 } // ReadFrom indicates an expected call of ReadFrom. 
func (mr *MockPacketConnMockRecorder) ReadFrom(p any) *MockPacketConnReadFromCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadFrom", reflect.TypeOf((*MockPacketConn)(nil).ReadFrom), p) return &MockPacketConnReadFromCall{Call: call} } // MockPacketConnReadFromCall wrap *gomock.Call type MockPacketConnReadFromCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockPacketConnReadFromCall) Return(n int, addr net.Addr, err error) *MockPacketConnReadFromCall { c.Call = c.Call.Return(n, addr, err) return c } // Do rewrite *gomock.Call.Do func (c *MockPacketConnReadFromCall) Do(f func([]byte) (int, net.Addr, error)) *MockPacketConnReadFromCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockPacketConnReadFromCall) DoAndReturn(f func([]byte) (int, net.Addr, error)) *MockPacketConnReadFromCall { c.Call = c.Call.DoAndReturn(f) return c } // SetDeadline mocks base method. func (m *MockPacketConn) SetDeadline(t time.Time) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetDeadline", t) ret0, _ := ret[0].(error) return ret0 } // SetDeadline indicates an expected call of SetDeadline. 
func (mr *MockPacketConnMockRecorder) SetDeadline(t any) *MockPacketConnSetDeadlineCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetDeadline), t) return &MockPacketConnSetDeadlineCall{Call: call} } // MockPacketConnSetDeadlineCall wrap *gomock.Call type MockPacketConnSetDeadlineCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockPacketConnSetDeadlineCall) Return(arg0 error) *MockPacketConnSetDeadlineCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockPacketConnSetDeadlineCall) Do(f func(time.Time) error) *MockPacketConnSetDeadlineCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockPacketConnSetDeadlineCall) DoAndReturn(f func(time.Time) error) *MockPacketConnSetDeadlineCall { c.Call = c.Call.DoAndReturn(f) return c } // SetReadDeadline mocks base method. func (m *MockPacketConn) SetReadDeadline(t time.Time) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetReadDeadline", t) ret0, _ := ret[0].(error) return ret0 } // SetReadDeadline indicates an expected call of SetReadDeadline. 
func (mr *MockPacketConnMockRecorder) SetReadDeadline(t any) *MockPacketConnSetReadDeadlineCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetReadDeadline), t) return &MockPacketConnSetReadDeadlineCall{Call: call} } // MockPacketConnSetReadDeadlineCall wrap *gomock.Call type MockPacketConnSetReadDeadlineCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockPacketConnSetReadDeadlineCall) Return(arg0 error) *MockPacketConnSetReadDeadlineCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockPacketConnSetReadDeadlineCall) Do(f func(time.Time) error) *MockPacketConnSetReadDeadlineCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockPacketConnSetReadDeadlineCall) DoAndReturn(f func(time.Time) error) *MockPacketConnSetReadDeadlineCall { c.Call = c.Call.DoAndReturn(f) return c } // SetWriteDeadline mocks base method. func (m *MockPacketConn) SetWriteDeadline(t time.Time) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetWriteDeadline", t) ret0, _ := ret[0].(error) return ret0 } // SetWriteDeadline indicates an expected call of SetWriteDeadline. 
func (mr *MockPacketConnMockRecorder) SetWriteDeadline(t any) *MockPacketConnSetWriteDeadlineCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetWriteDeadline), t) return &MockPacketConnSetWriteDeadlineCall{Call: call} } // MockPacketConnSetWriteDeadlineCall wrap *gomock.Call type MockPacketConnSetWriteDeadlineCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockPacketConnSetWriteDeadlineCall) Return(arg0 error) *MockPacketConnSetWriteDeadlineCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockPacketConnSetWriteDeadlineCall) Do(f func(time.Time) error) *MockPacketConnSetWriteDeadlineCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockPacketConnSetWriteDeadlineCall) DoAndReturn(f func(time.Time) error) *MockPacketConnSetWriteDeadlineCall { c.Call = c.Call.DoAndReturn(f) return c } // WriteTo mocks base method. func (m *MockPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "WriteTo", p, addr) ret0, _ := ret[0].(int) ret1, _ := ret[1].(error) return ret0, ret1 } // WriteTo indicates an expected call of WriteTo. 
func (mr *MockPacketConnMockRecorder) WriteTo(p, addr any) *MockPacketConnWriteToCall {
	mr.mock.ctrl.T.Helper()
	call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteTo", reflect.TypeOf((*MockPacketConn)(nil).WriteTo), p, addr)
	return &MockPacketConnWriteToCall{Call: call}
}

// MockPacketConnWriteToCall wraps the *gomock.Call recorded for WriteTo so that
// Return, Do and DoAndReturn accept WriteTo's concrete types.
type MockPacketConnWriteToCall struct {
	*gomock.Call
}

// Return forwards the results to (*gomock.Call).Return and returns c for
// chaining.
func (c *MockPacketConnWriteToCall) Return(n int, err error) *MockPacketConnWriteToCall {
	c.Call = c.Call.Return(n, err)
	return c
}

// Do forwards f, typed with WriteTo's signature, to (*gomock.Call).Do and
// returns c for chaining.
func (c *MockPacketConnWriteToCall) Do(f func([]byte, net.Addr) (int, error)) *MockPacketConnWriteToCall {
	c.Call = c.Call.Do(f)
	return c
}

// DoAndReturn forwards f, typed with WriteTo's signature, to
// (*gomock.Call).DoAndReturn and returns c for chaining.
func (c *MockPacketConnWriteToCall) DoAndReturn(f func([]byte, net.Addr) (int, error)) *MockPacketConnWriteToCall {
	c.Call = c.Call.DoAndReturn(f)
	return c
}
quic-go-0.59.0/mock_raw_conn_test.go000066400000000000000000000203651513066070600173700ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/quic-go/quic-go (interfaces: RawConn)
//
// Generated by this command:
//
//	mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_raw_conn_test.go github.com/quic-go/quic-go RawConn
//

// Package quic is a generated GoMock package.
package quic

import (
	net "net"
	reflect "reflect"
	time "time"

	protocol "github.com/quic-go/quic-go/internal/protocol"
	gomock "go.uber.org/mock/gomock"
)

// MockRawConn is a gomock-based mock of the RawConn interface. Every mocked
// method reports to ctrl; expectations are declared through EXPECT().
type MockRawConn struct {
	ctrl     *gomock.Controller       // controller that mocked methods call into
	recorder *MockRawConnMockRecorder // returned by EXPECT
	// NOTE(review): zero-size marker field emitted by mockgen; presumably used
	// to identify generated mocks — confirm against the gomock release notes.
	isgomock struct{}
}

// MockRawConnMockRecorder is the mock recorder for MockRawConn: its methods
// register expected calls on the wrapped mock's controller.
type MockRawConnMockRecorder struct {
	mock *MockRawConn
}

// NewMockRawConn creates a new mock instance.
func NewMockRawConn(ctrl *gomock.Controller) *MockRawConn { mock := &MockRawConn{ctrl: ctrl} mock.recorder = &MockRawConnMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockRawConn) EXPECT() *MockRawConnMockRecorder { return m.recorder } // Close mocks base method. func (m *MockRawConn) Close() error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Close") ret0, _ := ret[0].(error) return ret0 } // Close indicates an expected call of Close. func (mr *MockRawConnMockRecorder) Close() *MockRawConnCloseCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRawConn)(nil).Close)) return &MockRawConnCloseCall{Call: call} } // MockRawConnCloseCall wrap *gomock.Call type MockRawConnCloseCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockRawConnCloseCall) Return(arg0 error) *MockRawConnCloseCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockRawConnCloseCall) Do(f func() error) *MockRawConnCloseCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockRawConnCloseCall) DoAndReturn(f func() error) *MockRawConnCloseCall { c.Call = c.Call.DoAndReturn(f) return c } // LocalAddr mocks base method. func (m *MockRawConn) LocalAddr() net.Addr { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LocalAddr") ret0, _ := ret[0].(net.Addr) return ret0 } // LocalAddr indicates an expected call of LocalAddr. 
func (mr *MockRawConnMockRecorder) LocalAddr() *MockRawConnLocalAddrCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockRawConn)(nil).LocalAddr)) return &MockRawConnLocalAddrCall{Call: call} } // MockRawConnLocalAddrCall wrap *gomock.Call type MockRawConnLocalAddrCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockRawConnLocalAddrCall) Return(arg0 net.Addr) *MockRawConnLocalAddrCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockRawConnLocalAddrCall) Do(f func() net.Addr) *MockRawConnLocalAddrCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockRawConnLocalAddrCall) DoAndReturn(f func() net.Addr) *MockRawConnLocalAddrCall { c.Call = c.Call.DoAndReturn(f) return c } // ReadPacket mocks base method. func (m *MockRawConn) ReadPacket() (receivedPacket, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ReadPacket") ret0, _ := ret[0].(receivedPacket) ret1, _ := ret[1].(error) return ret0, ret1 } // ReadPacket indicates an expected call of ReadPacket. 
func (mr *MockRawConnMockRecorder) ReadPacket() *MockRawConnReadPacketCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadPacket", reflect.TypeOf((*MockRawConn)(nil).ReadPacket)) return &MockRawConnReadPacketCall{Call: call} } // MockRawConnReadPacketCall wrap *gomock.Call type MockRawConnReadPacketCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockRawConnReadPacketCall) Return(arg0 receivedPacket, arg1 error) *MockRawConnReadPacketCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockRawConnReadPacketCall) Do(f func() (receivedPacket, error)) *MockRawConnReadPacketCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockRawConnReadPacketCall) DoAndReturn(f func() (receivedPacket, error)) *MockRawConnReadPacketCall { c.Call = c.Call.DoAndReturn(f) return c } // SetReadDeadline mocks base method. func (m *MockRawConn) SetReadDeadline(arg0 time.Time) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetReadDeadline", arg0) ret0, _ := ret[0].(error) return ret0 } // SetReadDeadline indicates an expected call of SetReadDeadline. 
func (mr *MockRawConnMockRecorder) SetReadDeadline(arg0 any) *MockRawConnSetReadDeadlineCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockRawConn)(nil).SetReadDeadline), arg0) return &MockRawConnSetReadDeadlineCall{Call: call} } // MockRawConnSetReadDeadlineCall wrap *gomock.Call type MockRawConnSetReadDeadlineCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockRawConnSetReadDeadlineCall) Return(arg0 error) *MockRawConnSetReadDeadlineCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockRawConnSetReadDeadlineCall) Do(f func(time.Time) error) *MockRawConnSetReadDeadlineCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockRawConnSetReadDeadlineCall) DoAndReturn(f func(time.Time) error) *MockRawConnSetReadDeadlineCall { c.Call = c.Call.DoAndReturn(f) return c } // WritePacket mocks base method. func (m *MockRawConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "WritePacket", b, addr, packetInfoOOB, gsoSize, ecn) ret0, _ := ret[0].(int) ret1, _ := ret[1].(error) return ret0, ret1 } // WritePacket indicates an expected call of WritePacket. 
func (mr *MockRawConnMockRecorder) WritePacket(b, addr, packetInfoOOB, gsoSize, ecn any) *MockRawConnWritePacketCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), b, addr, packetInfoOOB, gsoSize, ecn) return &MockRawConnWritePacketCall{Call: call} } // MockRawConnWritePacketCall wrap *gomock.Call type MockRawConnWritePacketCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockRawConnWritePacketCall) Return(arg0 int, arg1 error) *MockRawConnWritePacketCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockRawConnWritePacketCall) Do(f func([]byte, net.Addr, []byte, uint16, protocol.ECN) (int, error)) *MockRawConnWritePacketCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockRawConnWritePacketCall) DoAndReturn(f func([]byte, net.Addr, []byte, uint16, protocol.ECN) (int, error)) *MockRawConnWritePacketCall { c.Call = c.Call.DoAndReturn(f) return c } // capabilities mocks base method. func (m *MockRawConn) capabilities() connCapabilities { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "capabilities") ret0, _ := ret[0].(connCapabilities) return ret0 } // capabilities indicates an expected call of capabilities. 
func (mr *MockRawConnMockRecorder) capabilities() *MockRawConncapabilitiesCall {
	mr.mock.ctrl.T.Helper()
	call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "capabilities", reflect.TypeOf((*MockRawConn)(nil).capabilities))
	return &MockRawConncapabilitiesCall{Call: call}
}

// MockRawConncapabilitiesCall wraps *gomock.Call so that Return, Do and
// DoAndReturn accept arguments typed like MockRawConn.capabilities.
type MockRawConncapabilitiesCall struct {
	*gomock.Call
}

// Return forwards to gomock.Call.Return with capabilities' connCapabilities result.
func (c *MockRawConncapabilitiesCall) Return(arg0 connCapabilities) *MockRawConncapabilitiesCall {
	c.Call = c.Call.Return(arg0)
	return c
}

// Do forwards to gomock.Call.Do with a callback matching capabilities' signature.
func (c *MockRawConncapabilitiesCall) Do(f func() connCapabilities) *MockRawConncapabilitiesCall {
	c.Call = c.Call.Do(f)
	return c
}

// DoAndReturn forwards to gomock.Call.DoAndReturn with a callback matching
// capabilities' signature.
func (c *MockRawConncapabilitiesCall) DoAndReturn(f func() connCapabilities) *MockRawConncapabilitiesCall {
	c.Call = c.Call.DoAndReturn(f)
	return c
}
quic-go-0.59.0/mock_sealing_manager_test.go000066400000000000000000000157621513066070600207030ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/quic-go/quic-go (interfaces: SealingManager)
//
// Generated by this command:
//
//	mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_sealing_manager_test.go github.com/quic-go/quic-go SealingManager
//

// Package quic is a generated GoMock package.
package quic

import (
	reflect "reflect"

	handshake "github.com/quic-go/quic-go/internal/handshake"
	gomock "go.uber.org/mock/gomock"
)

// MockSealingManager is a mock of SealingManager interface.
type MockSealingManager struct {
	ctrl     *gomock.Controller
	recorder *MockSealingManagerMockRecorder
	isgomock struct{}
}

// MockSealingManagerMockRecorder is the mock recorder for MockSealingManager.
type MockSealingManagerMockRecorder struct {
	mock *MockSealingManager
}

// NewMockSealingManager creates a new mock instance.
func NewMockSealingManager(ctrl *gomock.Controller) *MockSealingManager { mock := &MockSealingManager{ctrl: ctrl} mock.recorder = &MockSealingManagerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockSealingManager) EXPECT() *MockSealingManagerMockRecorder { return m.recorder } // Get0RTTSealer mocks base method. func (m *MockSealingManager) Get0RTTSealer() (handshake.LongHeaderSealer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Get0RTTSealer") ret0, _ := ret[0].(handshake.LongHeaderSealer) ret1, _ := ret[1].(error) return ret0, ret1 } // Get0RTTSealer indicates an expected call of Get0RTTSealer. func (mr *MockSealingManagerMockRecorder) Get0RTTSealer() *MockSealingManagerGet0RTTSealerCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTSealer", reflect.TypeOf((*MockSealingManager)(nil).Get0RTTSealer)) return &MockSealingManagerGet0RTTSealerCall{Call: call} } // MockSealingManagerGet0RTTSealerCall wrap *gomock.Call type MockSealingManagerGet0RTTSealerCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSealingManagerGet0RTTSealerCall) Return(arg0 handshake.LongHeaderSealer, arg1 error) *MockSealingManagerGet0RTTSealerCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockSealingManagerGet0RTTSealerCall) Do(f func() (handshake.LongHeaderSealer, error)) *MockSealingManagerGet0RTTSealerCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSealingManagerGet0RTTSealerCall) DoAndReturn(f func() (handshake.LongHeaderSealer, error)) *MockSealingManagerGet0RTTSealerCall { c.Call = c.Call.DoAndReturn(f) return c } // Get1RTTSealer mocks base method. 
func (m *MockSealingManager) Get1RTTSealer() (handshake.ShortHeaderSealer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Get1RTTSealer") ret0, _ := ret[0].(handshake.ShortHeaderSealer) ret1, _ := ret[1].(error) return ret0, ret1 } // Get1RTTSealer indicates an expected call of Get1RTTSealer. func (mr *MockSealingManagerMockRecorder) Get1RTTSealer() *MockSealingManagerGet1RTTSealerCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTSealer", reflect.TypeOf((*MockSealingManager)(nil).Get1RTTSealer)) return &MockSealingManagerGet1RTTSealerCall{Call: call} } // MockSealingManagerGet1RTTSealerCall wrap *gomock.Call type MockSealingManagerGet1RTTSealerCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSealingManagerGet1RTTSealerCall) Return(arg0 handshake.ShortHeaderSealer, arg1 error) *MockSealingManagerGet1RTTSealerCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockSealingManagerGet1RTTSealerCall) Do(f func() (handshake.ShortHeaderSealer, error)) *MockSealingManagerGet1RTTSealerCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSealingManagerGet1RTTSealerCall) DoAndReturn(f func() (handshake.ShortHeaderSealer, error)) *MockSealingManagerGet1RTTSealerCall { c.Call = c.Call.DoAndReturn(f) return c } // GetHandshakeSealer mocks base method. func (m *MockSealingManager) GetHandshakeSealer() (handshake.LongHeaderSealer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetHandshakeSealer") ret0, _ := ret[0].(handshake.LongHeaderSealer) ret1, _ := ret[1].(error) return ret0, ret1 } // GetHandshakeSealer indicates an expected call of GetHandshakeSealer. 
func (mr *MockSealingManagerMockRecorder) GetHandshakeSealer() *MockSealingManagerGetHandshakeSealerCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeSealer", reflect.TypeOf((*MockSealingManager)(nil).GetHandshakeSealer)) return &MockSealingManagerGetHandshakeSealerCall{Call: call} } // MockSealingManagerGetHandshakeSealerCall wrap *gomock.Call type MockSealingManagerGetHandshakeSealerCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSealingManagerGetHandshakeSealerCall) Return(arg0 handshake.LongHeaderSealer, arg1 error) *MockSealingManagerGetHandshakeSealerCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockSealingManagerGetHandshakeSealerCall) Do(f func() (handshake.LongHeaderSealer, error)) *MockSealingManagerGetHandshakeSealerCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSealingManagerGetHandshakeSealerCall) DoAndReturn(f func() (handshake.LongHeaderSealer, error)) *MockSealingManagerGetHandshakeSealerCall { c.Call = c.Call.DoAndReturn(f) return c } // GetInitialSealer mocks base method. func (m *MockSealingManager) GetInitialSealer() (handshake.LongHeaderSealer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetInitialSealer") ret0, _ := ret[0].(handshake.LongHeaderSealer) ret1, _ := ret[1].(error) return ret0, ret1 } // GetInitialSealer indicates an expected call of GetInitialSealer. 
func (mr *MockSealingManagerMockRecorder) GetInitialSealer() *MockSealingManagerGetInitialSealerCall {
	mr.mock.ctrl.T.Helper()
	call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialSealer", reflect.TypeOf((*MockSealingManager)(nil).GetInitialSealer))
	return &MockSealingManagerGetInitialSealerCall{Call: call}
}

// MockSealingManagerGetInitialSealerCall wraps *gomock.Call so that Return, Do
// and DoAndReturn accept arguments typed like MockSealingManager.GetInitialSealer.
type MockSealingManagerGetInitialSealerCall struct {
	*gomock.Call
}

// Return forwards to gomock.Call.Return with GetInitialSealer's
// (handshake.LongHeaderSealer, error) results.
func (c *MockSealingManagerGetInitialSealerCall) Return(arg0 handshake.LongHeaderSealer, arg1 error) *MockSealingManagerGetInitialSealerCall {
	c.Call = c.Call.Return(arg0, arg1)
	return c
}

// Do forwards to gomock.Call.Do with a callback matching GetInitialSealer's signature.
func (c *MockSealingManagerGetInitialSealerCall) Do(f func() (handshake.LongHeaderSealer, error)) *MockSealingManagerGetInitialSealerCall {
	c.Call = c.Call.Do(f)
	return c
}

// DoAndReturn forwards to gomock.Call.DoAndReturn with a callback matching
// GetInitialSealer's signature.
func (c *MockSealingManagerGetInitialSealerCall) DoAndReturn(f func() (handshake.LongHeaderSealer, error)) *MockSealingManagerGetInitialSealerCall {
	c.Call = c.Call.DoAndReturn(f)
	return c
}
quic-go-0.59.0/mock_send_conn_test.go000066400000000000000000000220501513066070600175210ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/quic-go/quic-go (interfaces: SendConn)
//
// Generated by this command:
//
//	mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_send_conn_test.go github.com/quic-go/quic-go SendConn
//

// Package quic is a generated GoMock package.
package quic

import (
	net "net"
	reflect "reflect"

	protocol "github.com/quic-go/quic-go/internal/protocol"
	gomock "go.uber.org/mock/gomock"
)

// MockSendConn is a mock of SendConn interface.
type MockSendConn struct {
	ctrl     *gomock.Controller
	recorder *MockSendConnMockRecorder
	isgomock struct{}
}

// MockSendConnMockRecorder is the mock recorder for MockSendConn.
type MockSendConnMockRecorder struct {
	mock *MockSendConn
}

// NewMockSendConn builds a MockSendConn bound to ctrl, together with the
// recorder returned by EXPECT.
func NewMockSendConn(ctrl *gomock.Controller) *MockSendConn {
	m := &MockSendConn{ctrl: ctrl}
	m.recorder = &MockSendConnMockRecorder{mock: m}
	return m
}

// EXPECT exposes the recorder through which expected calls are registered.
func (m *MockSendConn) EXPECT() *MockSendConnMockRecorder {
	return m.recorder
}

// ChangeRemoteAddr hands the invocation to the gomock controller; it has no
// results to return.
func (m *MockSendConn) ChangeRemoteAddr(addr net.Addr, info packetInfo) {
	m.ctrl.T.Helper()
	m.ctrl.Call(m, "ChangeRemoteAddr", addr, info)
}

// ChangeRemoteAddr registers an expected call of ChangeRemoteAddr and returns
// a typed handle for configuring it.
func (mr *MockSendConnMockRecorder) ChangeRemoteAddr(addr, info any) *MockSendConnChangeRemoteAddrCall {
	mr.mock.ctrl.T.Helper()
	return &MockSendConnChangeRemoteAddrCall{
		Call: mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChangeRemoteAddr", reflect.TypeOf((*MockSendConn)(nil).ChangeRemoteAddr), addr, info),
	}
}

// MockSendConnChangeRemoteAddrCall embeds *gomock.Call and narrows the
// Return/Do/DoAndReturn helpers to the ChangeRemoteAddr signature.
type MockSendConnChangeRemoteAddrCall struct {
	*gomock.Call
}

// Return forwards to gomock.Call.Return; ChangeRemoteAddr has no results.
func (c *MockSendConnChangeRemoteAddrCall) Return() *MockSendConnChangeRemoteAddrCall {
	c.Call = c.Call.Return()
	return c
}

// Do forwards to gomock.Call.Do with a callback typed like ChangeRemoteAddr.
func (c *MockSendConnChangeRemoteAddrCall) Do(f func(net.Addr, packetInfo)) *MockSendConnChangeRemoteAddrCall {
	c.Call = c.Call.Do(f)
	return c
}

// DoAndReturn forwards to gomock.Call.DoAndReturn with a callback typed like
// ChangeRemoteAddr.
func (c *MockSendConnChangeRemoteAddrCall) DoAndReturn(f func(net.Addr, packetInfo)) *MockSendConnChangeRemoteAddrCall {
	c.Call = c.Call.DoAndReturn(f)
	return c
}

// Close hands the invocation to the gomock controller and returns the error
// it produces (nil when the recorded result is not an error).
func (m *MockSendConn) Close() error {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Close")
	err, _ := results[0].(error)
	return err
}

// Close indicates an expected call of Close.
func (mr *MockSendConnMockRecorder) Close() *MockSendConnCloseCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSendConn)(nil).Close)) return &MockSendConnCloseCall{Call: call} } // MockSendConnCloseCall wrap *gomock.Call type MockSendConnCloseCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSendConnCloseCall) Return(arg0 error) *MockSendConnCloseCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockSendConnCloseCall) Do(f func() error) *MockSendConnCloseCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSendConnCloseCall) DoAndReturn(f func() error) *MockSendConnCloseCall { c.Call = c.Call.DoAndReturn(f) return c } // LocalAddr mocks base method. func (m *MockSendConn) LocalAddr() net.Addr { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LocalAddr") ret0, _ := ret[0].(net.Addr) return ret0 } // LocalAddr indicates an expected call of LocalAddr. func (mr *MockSendConnMockRecorder) LocalAddr() *MockSendConnLocalAddrCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockSendConn)(nil).LocalAddr)) return &MockSendConnLocalAddrCall{Call: call} } // MockSendConnLocalAddrCall wrap *gomock.Call type MockSendConnLocalAddrCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSendConnLocalAddrCall) Return(arg0 net.Addr) *MockSendConnLocalAddrCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockSendConnLocalAddrCall) Do(f func() net.Addr) *MockSendConnLocalAddrCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSendConnLocalAddrCall) DoAndReturn(f func() net.Addr) *MockSendConnLocalAddrCall { c.Call = c.Call.DoAndReturn(f) return c } // RemoteAddr mocks base method. 
func (m *MockSendConn) RemoteAddr() net.Addr { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RemoteAddr") ret0, _ := ret[0].(net.Addr) return ret0 } // RemoteAddr indicates an expected call of RemoteAddr. func (mr *MockSendConnMockRecorder) RemoteAddr() *MockSendConnRemoteAddrCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockSendConn)(nil).RemoteAddr)) return &MockSendConnRemoteAddrCall{Call: call} } // MockSendConnRemoteAddrCall wrap *gomock.Call type MockSendConnRemoteAddrCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSendConnRemoteAddrCall) Return(arg0 net.Addr) *MockSendConnRemoteAddrCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockSendConnRemoteAddrCall) Do(f func() net.Addr) *MockSendConnRemoteAddrCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSendConnRemoteAddrCall) DoAndReturn(f func() net.Addr) *MockSendConnRemoteAddrCall { c.Call = c.Call.DoAndReturn(f) return c } // Write mocks base method. func (m *MockSendConn) Write(b []byte, gsoSize uint16, ecn protocol.ECN) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Write", b, gsoSize, ecn) ret0, _ := ret[0].(error) return ret0 } // Write indicates an expected call of Write. 
func (mr *MockSendConnMockRecorder) Write(b, gsoSize, ecn any) *MockSendConnWriteCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendConn)(nil).Write), b, gsoSize, ecn) return &MockSendConnWriteCall{Call: call} } // MockSendConnWriteCall wrap *gomock.Call type MockSendConnWriteCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSendConnWriteCall) Return(arg0 error) *MockSendConnWriteCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockSendConnWriteCall) Do(f func([]byte, uint16, protocol.ECN) error) *MockSendConnWriteCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSendConnWriteCall) DoAndReturn(f func([]byte, uint16, protocol.ECN) error) *MockSendConnWriteCall { c.Call = c.Call.DoAndReturn(f) return c } // WriteTo mocks base method. func (m *MockSendConn) WriteTo(arg0 []byte, arg1 net.Addr) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "WriteTo", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // WriteTo indicates an expected call of WriteTo. 
func (mr *MockSendConnMockRecorder) WriteTo(arg0, arg1 any) *MockSendConnWriteToCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteTo", reflect.TypeOf((*MockSendConn)(nil).WriteTo), arg0, arg1) return &MockSendConnWriteToCall{Call: call} } // MockSendConnWriteToCall wrap *gomock.Call type MockSendConnWriteToCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSendConnWriteToCall) Return(arg0 error) *MockSendConnWriteToCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockSendConnWriteToCall) Do(f func([]byte, net.Addr) error) *MockSendConnWriteToCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSendConnWriteToCall) DoAndReturn(f func([]byte, net.Addr) error) *MockSendConnWriteToCall { c.Call = c.Call.DoAndReturn(f) return c } // capabilities mocks base method. func (m *MockSendConn) capabilities() connCapabilities { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "capabilities") ret0, _ := ret[0].(connCapabilities) return ret0 } // capabilities indicates an expected call of capabilities. 
func (mr *MockSendConnMockRecorder) capabilities() *MockSendConncapabilitiesCall {
	mr.mock.ctrl.T.Helper()
	call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "capabilities", reflect.TypeOf((*MockSendConn)(nil).capabilities))
	return &MockSendConncapabilitiesCall{Call: call}
}

// MockSendConncapabilitiesCall wraps *gomock.Call so that Return, Do and
// DoAndReturn accept arguments typed like MockSendConn.capabilities.
type MockSendConncapabilitiesCall struct {
	*gomock.Call
}

// Return forwards to gomock.Call.Return with capabilities' connCapabilities result.
func (c *MockSendConncapabilitiesCall) Return(arg0 connCapabilities) *MockSendConncapabilitiesCall {
	c.Call = c.Call.Return(arg0)
	return c
}

// Do forwards to gomock.Call.Do with a callback matching capabilities' signature.
func (c *MockSendConncapabilitiesCall) Do(f func() connCapabilities) *MockSendConncapabilitiesCall {
	c.Call = c.Call.Do(f)
	return c
}

// DoAndReturn forwards to gomock.Call.DoAndReturn with a callback matching
// capabilities' signature.
func (c *MockSendConncapabilitiesCall) DoAndReturn(f func() connCapabilities) *MockSendConncapabilitiesCall {
	c.Call = c.Call.DoAndReturn(f)
	return c
}
quic-go-0.59.0/mock_sender_test.go000066400000000000000000000164041513066070600170410ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/quic-go/quic-go (interfaces: Sender)
//
// Generated by this command:
//
//	mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_sender_test.go github.com/quic-go/quic-go Sender
//

// Package quic is a generated GoMock package.
package quic

import (
	net "net"
	reflect "reflect"

	protocol "github.com/quic-go/quic-go/internal/protocol"
	gomock "go.uber.org/mock/gomock"
)

// MockSender is a mock of Sender interface.
type MockSender struct {
	ctrl     *gomock.Controller
	recorder *MockSenderMockRecorder
	isgomock struct{}
}

// MockSenderMockRecorder is the mock recorder for MockSender.
type MockSenderMockRecorder struct {
	mock *MockSender
}

// NewMockSender creates a new mock instance.
func NewMockSender(ctrl *gomock.Controller) *MockSender { mock := &MockSender{ctrl: ctrl} mock.recorder = &MockSenderMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockSender) EXPECT() *MockSenderMockRecorder { return m.recorder } // Available mocks base method. func (m *MockSender) Available() <-chan struct{} { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Available") ret0, _ := ret[0].(<-chan struct{}) return ret0 } // Available indicates an expected call of Available. func (mr *MockSenderMockRecorder) Available() *MockSenderAvailableCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Available", reflect.TypeOf((*MockSender)(nil).Available)) return &MockSenderAvailableCall{Call: call} } // MockSenderAvailableCall wrap *gomock.Call type MockSenderAvailableCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSenderAvailableCall) Return(arg0 <-chan struct{}) *MockSenderAvailableCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockSenderAvailableCall) Do(f func() <-chan struct{}) *MockSenderAvailableCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSenderAvailableCall) DoAndReturn(f func() <-chan struct{}) *MockSenderAvailableCall { c.Call = c.Call.DoAndReturn(f) return c } // Close mocks base method. func (m *MockSender) Close() { m.ctrl.T.Helper() m.ctrl.Call(m, "Close") } // Close indicates an expected call of Close. 
func (mr *MockSenderMockRecorder) Close() *MockSenderCloseCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSender)(nil).Close)) return &MockSenderCloseCall{Call: call} } // MockSenderCloseCall wrap *gomock.Call type MockSenderCloseCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSenderCloseCall) Return() *MockSenderCloseCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockSenderCloseCall) Do(f func()) *MockSenderCloseCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSenderCloseCall) DoAndReturn(f func()) *MockSenderCloseCall { c.Call = c.Call.DoAndReturn(f) return c } // Run mocks base method. func (m *MockSender) Run() error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Run") ret0, _ := ret[0].(error) return ret0 } // Run indicates an expected call of Run. func (mr *MockSenderMockRecorder) Run() *MockSenderRunCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockSender)(nil).Run)) return &MockSenderRunCall{Call: call} } // MockSenderRunCall wrap *gomock.Call type MockSenderRunCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSenderRunCall) Return(arg0 error) *MockSenderRunCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do func (c *MockSenderRunCall) Do(f func() error) *MockSenderRunCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSenderRunCall) DoAndReturn(f func() error) *MockSenderRunCall { c.Call = c.Call.DoAndReturn(f) return c } // Send mocks base method. func (m *MockSender) Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN) { m.ctrl.T.Helper() m.ctrl.Call(m, "Send", p, gsoSize, ecn) } // Send indicates an expected call of Send. 
func (mr *MockSenderMockRecorder) Send(p, gsoSize, ecn any) *MockSenderSendCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockSender)(nil).Send), p, gsoSize, ecn) return &MockSenderSendCall{Call: call} } // MockSenderSendCall wrap *gomock.Call type MockSenderSendCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSenderSendCall) Return() *MockSenderSendCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockSenderSendCall) Do(f func(*packetBuffer, uint16, protocol.ECN)) *MockSenderSendCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSenderSendCall) DoAndReturn(f func(*packetBuffer, uint16, protocol.ECN)) *MockSenderSendCall { c.Call = c.Call.DoAndReturn(f) return c } // SendProbe mocks base method. func (m *MockSender) SendProbe(arg0 *packetBuffer, arg1 net.Addr) { m.ctrl.T.Helper() m.ctrl.Call(m, "SendProbe", arg0, arg1) } // SendProbe indicates an expected call of SendProbe. func (mr *MockSenderMockRecorder) SendProbe(arg0, arg1 any) *MockSenderSendProbeCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendProbe", reflect.TypeOf((*MockSender)(nil).SendProbe), arg0, arg1) return &MockSenderSendProbeCall{Call: call} } // MockSenderSendProbeCall wrap *gomock.Call type MockSenderSendProbeCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockSenderSendProbeCall) Return() *MockSenderSendProbeCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockSenderSendProbeCall) Do(f func(*packetBuffer, net.Addr)) *MockSenderSendProbeCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockSenderSendProbeCall) DoAndReturn(f func(*packetBuffer, net.Addr)) *MockSenderSendProbeCall { c.Call = c.Call.DoAndReturn(f) return c } // WouldBlock mocks base method. 
func (m *MockSender) WouldBlock() bool {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "WouldBlock")
	// comma-ok assertion: a missing or non-bool result yields false
	ret0, _ := ret[0].(bool)
	return ret0
}

// WouldBlock indicates an expected call of WouldBlock.
func (mr *MockSenderMockRecorder) WouldBlock() *MockSenderWouldBlockCall {
	mr.mock.ctrl.T.Helper()
	call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WouldBlock", reflect.TypeOf((*MockSender)(nil).WouldBlock))
	return &MockSenderWouldBlockCall{Call: call}
}

// MockSenderWouldBlockCall wraps *gomock.Call so that Return, Do and
// DoAndReturn accept arguments typed like MockSender.WouldBlock.
type MockSenderWouldBlockCall struct {
	*gomock.Call
}

// Return forwards to gomock.Call.Return with WouldBlock's bool result.
func (c *MockSenderWouldBlockCall) Return(arg0 bool) *MockSenderWouldBlockCall {
	c.Call = c.Call.Return(arg0)
	return c
}

// Do forwards to gomock.Call.Do with a callback matching WouldBlock's signature.
func (c *MockSenderWouldBlockCall) Do(f func() bool) *MockSenderWouldBlockCall {
	c.Call = c.Call.Do(f)
	return c
}

// DoAndReturn forwards to gomock.Call.DoAndReturn with a callback matching
// WouldBlock's signature.
func (c *MockSenderWouldBlockCall) DoAndReturn(f func() bool) *MockSenderWouldBlockCall {
	c.Call = c.Call.DoAndReturn(f)
	return c
}
quic-go-0.59.0/mock_stream_control_frame_getter_test.go000066400000000000000000000061571513066070600233440ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/quic-go/quic-go (interfaces: StreamControlFrameGetter)
//
// Generated by this command:
//
//	mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_control_frame_getter_test.go github.com/quic-go/quic-go StreamControlFrameGetter
//

// Package quic is a generated GoMock package.
package quic

import (
	reflect "reflect"

	ackhandler "github.com/quic-go/quic-go/internal/ackhandler"
	monotime "github.com/quic-go/quic-go/internal/monotime"
	gomock "go.uber.org/mock/gomock"
)

// MockStreamControlFrameGetter is a mock of StreamControlFrameGetter interface.
type MockStreamControlFrameGetter struct { ctrl *gomock.Controller recorder *MockStreamControlFrameGetterMockRecorder isgomock struct{} } // MockStreamControlFrameGetterMockRecorder is the mock recorder for MockStreamControlFrameGetter. type MockStreamControlFrameGetterMockRecorder struct { mock *MockStreamControlFrameGetter } // NewMockStreamControlFrameGetter creates a new mock instance. func NewMockStreamControlFrameGetter(ctrl *gomock.Controller) *MockStreamControlFrameGetter { mock := &MockStreamControlFrameGetter{ctrl: ctrl} mock.recorder = &MockStreamControlFrameGetterMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockStreamControlFrameGetter) EXPECT() *MockStreamControlFrameGetterMockRecorder { return m.recorder } // getControlFrame mocks base method. func (m *MockStreamControlFrameGetter) getControlFrame(arg0 monotime.Time) (ackhandler.Frame, bool, bool) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "getControlFrame", arg0) ret0, _ := ret[0].(ackhandler.Frame) ret1, _ := ret[1].(bool) ret2, _ := ret[2].(bool) return ret0, ret1, ret2 } // getControlFrame indicates an expected call of getControlFrame. 
func (mr *MockStreamControlFrameGetterMockRecorder) getControlFrame(arg0 any) *MockStreamControlFrameGettergetControlFrameCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getControlFrame", reflect.TypeOf((*MockStreamControlFrameGetter)(nil).getControlFrame), arg0) return &MockStreamControlFrameGettergetControlFrameCall{Call: call} } // MockStreamControlFrameGettergetControlFrameCall wrap *gomock.Call type MockStreamControlFrameGettergetControlFrameCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockStreamControlFrameGettergetControlFrameCall) Return(arg0 ackhandler.Frame, ok, hasMore bool) *MockStreamControlFrameGettergetControlFrameCall { c.Call = c.Call.Return(arg0, ok, hasMore) return c } // Do rewrite *gomock.Call.Do func (c *MockStreamControlFrameGettergetControlFrameCall) Do(f func(monotime.Time) (ackhandler.Frame, bool, bool)) *MockStreamControlFrameGettergetControlFrameCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockStreamControlFrameGettergetControlFrameCall) DoAndReturn(f func(monotime.Time) (ackhandler.Frame, bool, bool)) *MockStreamControlFrameGettergetControlFrameCall { c.Call = c.Call.DoAndReturn(f) return c } quic-go-0.59.0/mock_stream_frame_getter_test.go000066400000000000000000000062511513066070600215770ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: StreamFrameGetter) // // Generated by this command: // // mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_frame_getter_test.go github.com/quic-go/quic-go StreamFrameGetter // // Package quic is a generated GoMock package. 
package quic import ( reflect "reflect" ackhandler "github.com/quic-go/quic-go/internal/ackhandler" protocol "github.com/quic-go/quic-go/internal/protocol" wire "github.com/quic-go/quic-go/internal/wire" gomock "go.uber.org/mock/gomock" ) // MockStreamFrameGetter is a mock of StreamFrameGetter interface. type MockStreamFrameGetter struct { ctrl *gomock.Controller recorder *MockStreamFrameGetterMockRecorder isgomock struct{} } // MockStreamFrameGetterMockRecorder is the mock recorder for MockStreamFrameGetter. type MockStreamFrameGetterMockRecorder struct { mock *MockStreamFrameGetter } // NewMockStreamFrameGetter creates a new mock instance. func NewMockStreamFrameGetter(ctrl *gomock.Controller) *MockStreamFrameGetter { mock := &MockStreamFrameGetter{ctrl: ctrl} mock.recorder = &MockStreamFrameGetterMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockStreamFrameGetter) EXPECT() *MockStreamFrameGetterMockRecorder { return m.recorder } // popStreamFrame mocks base method. func (m *MockStreamFrameGetter) popStreamFrame(arg0 protocol.ByteCount, arg1 protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame, bool) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "popStreamFrame", arg0, arg1) ret0, _ := ret[0].(ackhandler.StreamFrame) ret1, _ := ret[1].(*wire.StreamDataBlockedFrame) ret2, _ := ret[2].(bool) return ret0, ret1, ret2 } // popStreamFrame indicates an expected call of popStreamFrame. 
func (mr *MockStreamFrameGetterMockRecorder) popStreamFrame(arg0, arg1 any) *MockStreamFrameGetterpopStreamFrameCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockStreamFrameGetter)(nil).popStreamFrame), arg0, arg1) return &MockStreamFrameGetterpopStreamFrameCall{Call: call} } // MockStreamFrameGetterpopStreamFrameCall wrap *gomock.Call type MockStreamFrameGetterpopStreamFrameCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockStreamFrameGetterpopStreamFrameCall) Return(arg0 ackhandler.StreamFrame, arg1 *wire.StreamDataBlockedFrame, arg2 bool) *MockStreamFrameGetterpopStreamFrameCall { c.Call = c.Call.Return(arg0, arg1, arg2) return c } // Do rewrite *gomock.Call.Do func (c *MockStreamFrameGetterpopStreamFrameCall) Do(f func(protocol.ByteCount, protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame, bool)) *MockStreamFrameGetterpopStreamFrameCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockStreamFrameGetterpopStreamFrameCall) DoAndReturn(f func(protocol.ByteCount, protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame, bool)) *MockStreamFrameGetterpopStreamFrameCall { c.Call = c.Call.DoAndReturn(f) return c } quic-go-0.59.0/mock_stream_sender_test.go000066400000000000000000000147121513066070600204140ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: StreamSender) // // Generated by this command: // // mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_sender_test.go github.com/quic-go/quic-go StreamSender // // Package quic is a generated GoMock package. package quic import ( reflect "reflect" protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" ) // MockStreamSender is a mock of StreamSender interface. 
type MockStreamSender struct { ctrl *gomock.Controller recorder *MockStreamSenderMockRecorder isgomock struct{} } // MockStreamSenderMockRecorder is the mock recorder for MockStreamSender. type MockStreamSenderMockRecorder struct { mock *MockStreamSender } // NewMockStreamSender creates a new mock instance. func NewMockStreamSender(ctrl *gomock.Controller) *MockStreamSender { mock := &MockStreamSender{ctrl: ctrl} mock.recorder = &MockStreamSenderMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockStreamSender) EXPECT() *MockStreamSenderMockRecorder { return m.recorder } // onHasConnectionData mocks base method. func (m *MockStreamSender) onHasConnectionData() { m.ctrl.T.Helper() m.ctrl.Call(m, "onHasConnectionData") } // onHasConnectionData indicates an expected call of onHasConnectionData. func (mr *MockStreamSenderMockRecorder) onHasConnectionData() *MockStreamSenderonHasConnectionDataCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasConnectionData", reflect.TypeOf((*MockStreamSender)(nil).onHasConnectionData)) return &MockStreamSenderonHasConnectionDataCall{Call: call} } // MockStreamSenderonHasConnectionDataCall wrap *gomock.Call type MockStreamSenderonHasConnectionDataCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockStreamSenderonHasConnectionDataCall) Return() *MockStreamSenderonHasConnectionDataCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockStreamSenderonHasConnectionDataCall) Do(f func()) *MockStreamSenderonHasConnectionDataCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockStreamSenderonHasConnectionDataCall) DoAndReturn(f func()) *MockStreamSenderonHasConnectionDataCall { c.Call = c.Call.DoAndReturn(f) return c } // onHasStreamControlFrame mocks base method. 
func (m *MockStreamSender) onHasStreamControlFrame(arg0 protocol.StreamID, arg1 streamControlFrameGetter) { m.ctrl.T.Helper() m.ctrl.Call(m, "onHasStreamControlFrame", arg0, arg1) } // onHasStreamControlFrame indicates an expected call of onHasStreamControlFrame. func (mr *MockStreamSenderMockRecorder) onHasStreamControlFrame(arg0, arg1 any) *MockStreamSenderonHasStreamControlFrameCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamControlFrame", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamControlFrame), arg0, arg1) return &MockStreamSenderonHasStreamControlFrameCall{Call: call} } // MockStreamSenderonHasStreamControlFrameCall wrap *gomock.Call type MockStreamSenderonHasStreamControlFrameCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockStreamSenderonHasStreamControlFrameCall) Return() *MockStreamSenderonHasStreamControlFrameCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockStreamSenderonHasStreamControlFrameCall) Do(f func(protocol.StreamID, streamControlFrameGetter)) *MockStreamSenderonHasStreamControlFrameCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockStreamSenderonHasStreamControlFrameCall) DoAndReturn(f func(protocol.StreamID, streamControlFrameGetter)) *MockStreamSenderonHasStreamControlFrameCall { c.Call = c.Call.DoAndReturn(f) return c } // onHasStreamData mocks base method. func (m *MockStreamSender) onHasStreamData(arg0 protocol.StreamID, arg1 *SendStream) { m.ctrl.T.Helper() m.ctrl.Call(m, "onHasStreamData", arg0, arg1) } // onHasStreamData indicates an expected call of onHasStreamData. 
func (mr *MockStreamSenderMockRecorder) onHasStreamData(arg0, arg1 any) *MockStreamSenderonHasStreamDataCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamData", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamData), arg0, arg1) return &MockStreamSenderonHasStreamDataCall{Call: call} } // MockStreamSenderonHasStreamDataCall wrap *gomock.Call type MockStreamSenderonHasStreamDataCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockStreamSenderonHasStreamDataCall) Return() *MockStreamSenderonHasStreamDataCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockStreamSenderonHasStreamDataCall) Do(f func(protocol.StreamID, *SendStream)) *MockStreamSenderonHasStreamDataCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockStreamSenderonHasStreamDataCall) DoAndReturn(f func(protocol.StreamID, *SendStream)) *MockStreamSenderonHasStreamDataCall { c.Call = c.Call.DoAndReturn(f) return c } // onStreamCompleted mocks base method. func (m *MockStreamSender) onStreamCompleted(arg0 protocol.StreamID) { m.ctrl.T.Helper() m.ctrl.Call(m, "onStreamCompleted", arg0) } // onStreamCompleted indicates an expected call of onStreamCompleted. 
func (mr *MockStreamSenderMockRecorder) onStreamCompleted(arg0 any) *MockStreamSenderonStreamCompletedCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onStreamCompleted", reflect.TypeOf((*MockStreamSender)(nil).onStreamCompleted), arg0) return &MockStreamSenderonStreamCompletedCall{Call: call} } // MockStreamSenderonStreamCompletedCall wrap *gomock.Call type MockStreamSenderonStreamCompletedCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockStreamSenderonStreamCompletedCall) Return() *MockStreamSenderonStreamCompletedCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do func (c *MockStreamSenderonStreamCompletedCall) Do(f func(protocol.StreamID)) *MockStreamSenderonStreamCompletedCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockStreamSenderonStreamCompletedCall) DoAndReturn(f func(protocol.StreamID)) *MockStreamSenderonStreamCompletedCall { c.Call = c.Call.DoAndReturn(f) return c } quic-go-0.59.0/mock_unpacker_test.go000066400000000000000000000110401513066070600173600ustar00rootroot00000000000000// Code generated by MockGen. DO NOT EDIT. // Source: github.com/quic-go/quic-go (interfaces: Unpacker) // // Generated by this command: // // mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_unpacker_test.go github.com/quic-go/quic-go Unpacker // // Package quic is a generated GoMock package. package quic import ( reflect "reflect" monotime "github.com/quic-go/quic-go/internal/monotime" protocol "github.com/quic-go/quic-go/internal/protocol" wire "github.com/quic-go/quic-go/internal/wire" gomock "go.uber.org/mock/gomock" ) // MockUnpacker is a mock of Unpacker interface. type MockUnpacker struct { ctrl *gomock.Controller recorder *MockUnpackerMockRecorder isgomock struct{} } // MockUnpackerMockRecorder is the mock recorder for MockUnpacker. 
type MockUnpackerMockRecorder struct { mock *MockUnpacker } // NewMockUnpacker creates a new mock instance. func NewMockUnpacker(ctrl *gomock.Controller) *MockUnpacker { mock := &MockUnpacker{ctrl: ctrl} mock.recorder = &MockUnpackerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockUnpacker) EXPECT() *MockUnpackerMockRecorder { return m.recorder } // UnpackLongHeader mocks base method. func (m *MockUnpacker) UnpackLongHeader(hdr *wire.Header, data []byte) (*unpackedPacket, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UnpackLongHeader", hdr, data) ret0, _ := ret[0].(*unpackedPacket) ret1, _ := ret[1].(error) return ret0, ret1 } // UnpackLongHeader indicates an expected call of UnpackLongHeader. func (mr *MockUnpackerMockRecorder) UnpackLongHeader(hdr, data any) *MockUnpackerUnpackLongHeaderCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackLongHeader", reflect.TypeOf((*MockUnpacker)(nil).UnpackLongHeader), hdr, data) return &MockUnpackerUnpackLongHeaderCall{Call: call} } // MockUnpackerUnpackLongHeaderCall wrap *gomock.Call type MockUnpackerUnpackLongHeaderCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockUnpackerUnpackLongHeaderCall) Return(arg0 *unpackedPacket, arg1 error) *MockUnpackerUnpackLongHeaderCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do func (c *MockUnpackerUnpackLongHeaderCall) Do(f func(*wire.Header, []byte) (*unpackedPacket, error)) *MockUnpackerUnpackLongHeaderCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockUnpackerUnpackLongHeaderCall) DoAndReturn(f func(*wire.Header, []byte) (*unpackedPacket, error)) *MockUnpackerUnpackLongHeaderCall { c.Call = c.Call.DoAndReturn(f) return c } // UnpackShortHeader mocks base method. 
func (m *MockUnpacker) UnpackShortHeader(rcvTime monotime.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UnpackShortHeader", rcvTime, data) ret0, _ := ret[0].(protocol.PacketNumber) ret1, _ := ret[1].(protocol.PacketNumberLen) ret2, _ := ret[2].(protocol.KeyPhaseBit) ret3, _ := ret[3].([]byte) ret4, _ := ret[4].(error) return ret0, ret1, ret2, ret3, ret4 } // UnpackShortHeader indicates an expected call of UnpackShortHeader. func (mr *MockUnpackerMockRecorder) UnpackShortHeader(rcvTime, data any) *MockUnpackerUnpackShortHeaderCall { mr.mock.ctrl.T.Helper() call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackShortHeader", reflect.TypeOf((*MockUnpacker)(nil).UnpackShortHeader), rcvTime, data) return &MockUnpackerUnpackShortHeaderCall{Call: call} } // MockUnpackerUnpackShortHeaderCall wrap *gomock.Call type MockUnpackerUnpackShortHeaderCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return func (c *MockUnpackerUnpackShortHeaderCall) Return(arg0 protocol.PacketNumber, arg1 protocol.PacketNumberLen, arg2 protocol.KeyPhaseBit, arg3 []byte, arg4 error) *MockUnpackerUnpackShortHeaderCall { c.Call = c.Call.Return(arg0, arg1, arg2, arg3, arg4) return c } // Do rewrite *gomock.Call.Do func (c *MockUnpackerUnpackShortHeaderCall) Do(f func(monotime.Time, []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)) *MockUnpackerUnpackShortHeaderCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn func (c *MockUnpackerUnpackShortHeaderCall) DoAndReturn(f func(monotime.Time, []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)) *MockUnpackerUnpackShortHeaderCall { c.Call = c.Call.DoAndReturn(f) return c } quic-go-0.59.0/mockgen.go000066400000000000000000000070631513066070600151350ustar00rootroot00000000000000//go:build gomock || generate package 
quic //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_send_conn_test.go github.com/quic-go/quic-go SendConn" type SendConn = sendConn //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_raw_conn_test.go github.com/quic-go/quic-go RawConn" type RawConn = rawConn //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_sender_test.go github.com/quic-go/quic-go Sender" type Sender = sender //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_sender_test.go github.com/quic-go/quic-go StreamSender" type StreamSender = streamSender //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_control_frame_getter_test.go github.com/quic-go/quic-go StreamControlFrameGetter" type StreamControlFrameGetter = streamControlFrameGetter //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_frame_getter_test.go github.com/quic-go/quic-go StreamFrameGetter" type StreamFrameGetter = streamFrameGetter //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_frame_source_test.go github.com/quic-go/quic-go FrameSource" type FrameSource = frameSource //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_ack_frame_source_test.go github.com/quic-go/quic-go AckFrameSource" type AckFrameSource = ackFrameSource //go:generate sh -c "go tool 
mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_sealing_manager_test.go github.com/quic-go/quic-go SealingManager" type SealingManager = sealingManager //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_unpacker_test.go github.com/quic-go/quic-go Unpacker" type Unpacker = unpacker //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packer_test.go github.com/quic-go/quic-go Packer" type Packer = packer //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_mtu_discoverer_test.go github.com/quic-go/quic-go MTUDiscoverer" type MTUDiscoverer = mtuDiscoverer //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_conn_runner_test.go github.com/quic-go/quic-go ConnRunner" type ConnRunner = connRunner //go:generate sh -c "go tool mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_test.go github.com/quic-go/quic-go PacketHandler" type PacketHandler = packetHandler //go:generate sh -c "go tool mockgen -typed -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_packetconn_test.go net PacketConn" quic-go-0.59.0/mtu_discoverer.go000066400000000000000000000177031513066070600165460ustar00rootroot00000000000000package quic import ( "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/qlog" 
"github.com/quic-go/quic-go/qlogwriter" ) type mtuDiscoverer interface { // Start starts the MTU discovery process. // It's unnecessary to call ShouldSendProbe before that. Start(now monotime.Time) ShouldSendProbe(now monotime.Time) bool CurrentSize() protocol.ByteCount GetPing(now monotime.Time) (ping ackhandler.Frame, datagramSize protocol.ByteCount) Reset(now monotime.Time, start, max protocol.ByteCount) } const ( // At some point, we have to stop searching for a higher MTU. // We're happy to send a packet that's 10 bytes smaller than the actual MTU. maxMTUDiff protocol.ByteCount = 20 // send a probe packet every mtuProbeDelay RTTs mtuProbeDelay = 5 // Once maxLostMTUProbes MTU probe packets larger than a certain size are lost, // MTU discovery won't probe for larger MTUs than this size. // The algorithm used here is resilient to packet loss of (maxLostMTUProbes - 1) packets. maxLostMTUProbes = 3 ) // The Path MTU is found by sending a larger packet every now and then. // If the packet is acknowledged, we conclude that the path supports this larger packet size. // If the packet is lost, this can mean one of two things: // 1. The path doesn't support this larger packet size, or // 2. The packet was lost due to packet loss, independent of its size. // The algorithm used here is resilient to packet loss of (maxLostMTUProbes - 1) packets. // For simplicty, the following example use maxLostMTUProbes = 2. // // Initialization: // |------------------------------------------------------------------------------| // min max // // The first MTU probe packet will have size (min+max)/2. // Assume that this packet is acknowledged. We can now move the min marker, // and continue the search in the resulting interval. 
// // If 1st probe packet acknowledged: // |---------------------------------------|--------------------------------------| // min max // // If 1st probe packet lost: // |---------------------------------------|--------------------------------------| // min lost[0] max // // We can't conclude that the path doesn't support this packet size, since the loss of the probe // packet could have been unrelated to the packet size. A larger probe packet will be sent later on. // After a loss, the next probe packet has size (min+lost[0])/2. // Now assume this probe packet is acknowledged: // // 2nd probe packet acknowledged: // |------------------|--------------------|--------------------------------------| // min lost[0] max // // First of all, we conclude that the path supports at least this MTU. That's progress! // Second, we probe a bit more aggressively with the next probe packet: // After an acknowledgement, the next probe packet has size (min+max)/2. // This means we'll send a packet larger than the first probe packet (which was lost). // // If 3rd probe packet acknowledged: // |-------------------------------------------------|----------------------------| // min max // // We can conclude that the loss of the 1st probe packet was not due to its size, and // continue searching in a much smaller interval now. // // If 3rd probe packet lost: // |------------------|--------------------|---------|----------------------------| // min lost[0] max // // Since in our example numPTOProbes = 2, and we lost 2 packets smaller than max, we // conclude that this packet size is not supported on the path, and reduce the maximum // value of the search interval. // // MTU discovery concludes once the interval min and max has been narrowed down to maxMTUDiff. type mtuFinder struct { lastProbeTime monotime.Time rttStats *utils.RTTStats inFlight protocol.ByteCount // the size of the probe packet currently in flight. 
InvalidByteCount if none is in flight min protocol.ByteCount // on initialization, we treat the maximum size as the first "lost" packet lost [maxLostMTUProbes]protocol.ByteCount lastProbeWasLost bool // The generation is used to ignore ACKs / losses for probe packets sent before a reset. // Resets happen when the connection is migrated to a new path. // We're therefore not concerned about overflows of this counter. generation uint8 qlogger qlogwriter.Recorder } var _ mtuDiscoverer = &mtuFinder{} func newMTUDiscoverer( rttStats *utils.RTTStats, start, max protocol.ByteCount, qlogger qlogwriter.Recorder, ) *mtuFinder { f := &mtuFinder{ inFlight: protocol.InvalidByteCount, rttStats: rttStats, qlogger: qlogger, } f.init(start, max) return f } func (f *mtuFinder) init(start, max protocol.ByteCount) { f.min = start for i := range f.lost { if i == 0 { f.lost[i] = max continue } f.lost[i] = protocol.InvalidByteCount } } func (f *mtuFinder) done() bool { return f.max()-f.min <= maxMTUDiff+1 } func (f *mtuFinder) max() protocol.ByteCount { for i, v := range f.lost { if v == protocol.InvalidByteCount { return f.lost[i-1] } } return f.lost[len(f.lost)-1] } func (f *mtuFinder) Start(now monotime.Time) { f.lastProbeTime = now // makes sure the first probe packet is not sent immediately } func (f *mtuFinder) ShouldSendProbe(now monotime.Time) bool { if f.lastProbeTime.IsZero() { return false } if f.inFlight != protocol.InvalidByteCount || f.done() { return false } return !now.Before(f.lastProbeTime.Add(mtuProbeDelay * f.rttStats.SmoothedRTT())) } func (f *mtuFinder) GetPing(now monotime.Time) (ackhandler.Frame, protocol.ByteCount) { var size protocol.ByteCount if f.lastProbeWasLost { size = (f.min + f.lost[0]) / 2 } else { size = (f.min + f.max()) / 2 } f.lastProbeTime = now f.inFlight = size return ackhandler.Frame{ Frame: &wire.PingFrame{}, Handler: &mtuFinderAckHandler{mtuFinder: f, generation: f.generation}, }, size } func (f *mtuFinder) CurrentSize() protocol.ByteCount { 
return f.min } func (f *mtuFinder) Reset(now monotime.Time, start, max protocol.ByteCount) { f.generation++ f.lastProbeTime = now f.lastProbeWasLost = false f.inFlight = protocol.InvalidByteCount f.init(start, max) } type mtuFinderAckHandler struct { *mtuFinder generation uint8 } var _ ackhandler.FrameHandler = &mtuFinderAckHandler{} func (h *mtuFinderAckHandler) OnAcked(wire.Frame) { if h.generation != h.mtuFinder.generation { // ACK for probe sent before reset return } size := h.inFlight if size == protocol.InvalidByteCount { panic("OnAcked callback called although there's no MTU probe packet in flight") } h.inFlight = protocol.InvalidByteCount h.min = size h.lastProbeWasLost = false // remove all values smaller than size from the lost array var j int for i, v := range h.lost { if size < v { j = i break } } if j > 0 { for i := 0; i < len(h.lost); i++ { if i+j < len(h.lost) { h.lost[i] = h.lost[i+j] } else { h.lost[i] = protocol.InvalidByteCount } } } if h.qlogger != nil { h.qlogger.RecordEvent(qlog.MTUUpdated{ Value: int(size), Done: h.done(), }) } } func (h *mtuFinderAckHandler) OnLost(wire.Frame) { if h.generation != h.mtuFinder.generation { // probe sent before reset received return } size := h.inFlight if size == protocol.InvalidByteCount { panic("OnLost callback called although there's no MTU probe packet in flight") } h.lastProbeWasLost = true h.inFlight = protocol.InvalidByteCount for i, v := range h.lost { if size < v { copy(h.lost[i+1:], h.lost[i:]) h.lost[i] = size break } } } quic-go-0.59.0/mtu_discoverer_test.go000066400000000000000000000166261513066070600176100ustar00rootroot00000000000000package quic import ( "fmt" "math/rand/v2" "testing" "time" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/testutils/events" "github.com/stretchr/testify/require" ) func TestMTUDiscovererTiming(t *testing.T) { const 
rtt = 100 * time.Millisecond rttStats := utils.NewRTTStats() rttStats.UpdateRTT(rtt, 0) d := newMTUDiscoverer(rttStats, 1000, 2000, nil) now := monotime.Now() require.False(t, d.ShouldSendProbe(now)) d.Start(now) require.False(t, d.ShouldSendProbe(now)) require.False(t, d.ShouldSendProbe(now.Add(rtt*9/2))) now = now.Add(5 * rtt) require.True(t, d.ShouldSendProbe(now)) // only a single outstanding probe packet is permitted ping, _ := d.GetPing(now) require.False(t, d.ShouldSendProbe(now)) now = now.Add(5 * rtt) require.False(t, d.ShouldSendProbe(now)) ping.Handler.OnLost(ping.Frame) require.True(t, d.ShouldSendProbe(now)) } func TestMTUDiscovererAckAndLoss(t *testing.T) { const rtt = 200 * time.Millisecond rttStats := utils.NewRTTStats() rttStats.UpdateRTT(rtt, 0) d := newMTUDiscoverer(rttStats, 1000, 2000, nil) now := monotime.Now() ping, size := d.GetPing(now) require.Equal(t, protocol.ByteCount(1500), size) // the MTU is reduced if the frame is lost ping.Handler.OnLost(ping.Frame) require.Equal(t, protocol.ByteCount(1000), d.CurrentSize()) // no change to the MTU yet now = now.Add(5 * rtt) require.True(t, d.ShouldSendProbe(now)) ping, size = d.GetPing(now) require.Equal(t, protocol.ByteCount(1250), size) ping.Handler.OnAcked(ping.Frame) require.Equal(t, protocol.ByteCount(1250), d.CurrentSize()) // the MTU is increased // Even though the 1500 byte MTU probe packet was lost, we try again with a higher MTU. // This protects against regular (non-MTU-related) packet loss. now = now.Add(5 * rtt) require.True(t, d.ShouldSendProbe(now)) ping, size = d.GetPing(now) require.Greater(t, size, protocol.ByteCount(1500)) ping.Handler.OnAcked(ping.Frame) require.Equal(t, size, d.CurrentSize()) // We continue probing until the MTU is close to the maximum. 
var steps int oldSize := size now = now.Add(5 * rtt) for d.ShouldSendProbe(now) { ping, size = d.GetPing(now) require.Greater(t, size, oldSize) oldSize = size ping.Handler.OnAcked(ping.Frame) steps++ require.Less(t, steps, 10) now = now.Add(5 * rtt) } require.Less(t, 2000-maxMTUDiff, size) } func TestMTUDiscovererMTUDiscovery(t *testing.T) { for i := range 5 { t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { testMTUDiscovererMTUDiscovery(t) }) } } func testMTUDiscovererMTUDiscovery(t *testing.T) { const rtt = 100 * time.Millisecond const startMTU protocol.ByteCount = 1000 rttStats := utils.NewRTTStats() rttStats.UpdateRTT(rtt, 0) maxMTU := protocol.ByteCount(rand.IntN(int(3000-startMTU))) + startMTU + 1 var eventRecorder events.Recorder d := newMTUDiscoverer(rttStats, startMTU, maxMTU, &eventRecorder) now := monotime.Now() d.Start(now) realMTU := protocol.ByteCount(rand.IntN(int(maxMTU-startMTU))) + startMTU t.Logf("MTU: %d, max: %d", realMTU, maxMTU) now = now.Add(mtuProbeDelay * rtt) var probes []protocol.ByteCount for d.ShouldSendProbe(now) { require.Less(t, len(probes), 25, fmt.Sprintf("too many iterations: %v", probes)) ping, size := d.GetPing(now) probes = append(probes, size) if size <= realMTU { ping.Handler.OnAcked(ping.Frame) } else { ping.Handler.OnLost(ping.Frame) } now = now.Add(mtuProbeDelay * rtt) } currentMTU := d.CurrentSize() diff := realMTU - currentMTU require.GreaterOrEqual(t, diff, protocol.ByteCount(0)) if maxMTU > currentMTU+maxMTU { events := eventRecorder.Events(qlog.MTUUpdated{}) require.NotEmpty(t, events) require.Equal(t, qlog.MTUUpdated{Value: int(currentMTU), Done: true}, events[0]) } t.Logf("MTU discovered: %d (diff: %d)", currentMTU, diff) t.Logf("probes sent (%d): %v", len(probes), probes) require.LessOrEqual(t, diff, maxMTUDiff) } func TestMTUDiscovererWithRandomLoss(t *testing.T) { for i := range 5 { t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { testMTUDiscovererWithRandomLoss(t) }) } } func 
testMTUDiscovererWithRandomLoss(t *testing.T) { const rtt = 100 * time.Millisecond const startMTU protocol.ByteCount = 1000 const maxRandomLoss = maxLostMTUProbes - 1 rttStats := utils.NewRTTStats() rttStats.SetInitialRTT(rtt) require.Equal(t, rtt, rttStats.SmoothedRTT()) maxMTU := protocol.ByteCount(rand.IntN(int(3000-startMTU))) + startMTU + 1 var eventRecorder events.Recorder d := newMTUDiscoverer(rttStats, startMTU, maxMTU, &eventRecorder) d.Start(monotime.Now()) now := monotime.Now() realMTU := protocol.ByteCount(rand.IntN(int(maxMTU-startMTU))) + startMTU t.Logf("MTU: %d, max: %d", realMTU, maxMTU) now = now.Add(mtuProbeDelay * rtt) var probes, randomLosses []protocol.ByteCount for d.ShouldSendProbe(now) { require.Less(t, len(probes), 32, fmt.Sprintf("too many iterations: %v", probes)) ping, size := d.GetPing(now) probes = append(probes, size) packetFits := size <= realMTU var acked bool if packetFits { randomLoss := rand.IntN(maxLostMTUProbes) == 0 && len(randomLosses) < maxRandomLoss if randomLoss { randomLosses = append(randomLosses, size) } else { ping.Handler.OnAcked(ping.Frame) acked = true } } if !acked { ping.Handler.OnLost(ping.Frame) } now = now.Add(mtuProbeDelay * rtt) } currentMTU := d.CurrentSize() diff := realMTU - currentMTU require.GreaterOrEqual(t, diff, protocol.ByteCount(0)) if maxMTU > currentMTU+maxMTU { events := eventRecorder.Events(qlog.MTUUpdated{}) require.NotEmpty(t, events) require.Equal(t, qlog.MTUUpdated{Value: int(currentMTU), Done: true}, events[0]) } t.Logf("MTU discovered with random losses %v: %d (diff: %d)", randomLosses, currentMTU, diff) t.Logf("probes sent (%d): %v", len(probes), probes) require.LessOrEqual(t, diff, maxMTUDiff) } func TestMTUDiscovererReset(t *testing.T) { t.Run("probe on old path acknowledged", func(t *testing.T) { testMTUDiscovererReset(t, true) }) t.Run("probe on old path lost", func(t *testing.T) { testMTUDiscovererReset(t, false) }) } func testMTUDiscovererReset(t *testing.T, ackLastProbe bool) { 
const startMTU protocol.ByteCount = 1000 const maxMTU = 1400 const rtt = 100 * time.Millisecond rttStats := utils.NewRTTStats() rttStats.SetInitialRTT(rtt) now := monotime.Now() d := newMTUDiscoverer(rttStats, startMTU, maxMTU, nil) d.Start(now) ping, _ := d.GetPing(now.Add(5 * rtt)) ping.Handler.OnAcked(ping.Frame) require.Greater(t, d.CurrentSize(), startMTU) now = now.Add(5 * rtt) // send another probe packet, but neither acknowledge nor lose it before resetting ping, _ = d.GetPing(now.Add(5 * rtt)) now = now.Add(2 * rtt) // advance the timer by an arbitrary amount const newStartMTU protocol.ByteCount = 900 const newMaxMTU = 1500 d.Reset(now, newStartMTU, newMaxMTU) require.Equal(t, d.CurrentSize(), newStartMTU) // Now acknowledge / lose the probe packet. // This should be ignored, since it's on the old path. if ackLastProbe { ping.Handler.OnAcked(ping.Frame) } else { ping.Handler.OnLost(ping.Frame) } // the MTU should not have changed require.Equal(t, d.CurrentSize(), newStartMTU) // the next probe should be sent after 5 RTTs require.False(t, d.ShouldSendProbe(now.Add(5*rtt).Add(-time.Microsecond))) require.True(t, d.ShouldSendProbe(now.Add(5*rtt))) } quic-go-0.59.0/oss-fuzz.sh000066400000000000000000000027271513066070600153210ustar00rootroot00000000000000#!/bin/bash # Install Go manually, since oss-fuzz ships with an outdated Go version. # See https://github.com/google/oss-fuzz/pull/10643. 
export CXX="${CXX} -lresolv" # required by Go 1.20

wget https://go.dev/dl/go1.25.0.linux-amd64.tar.gz \
  && mkdir temp-go \
  && rm -rf /root/.go/* \
  && tar -C temp-go/ -xzf go1.25.0.linux-amd64.tar.gz \
  && mv temp-go/go/* /root/.go/ \
  && rm -rf temp-go go1.25.0.linux-amd64.tar.gz

(
# fuzz qpack
compile_go_fuzzer github.com/quic-go/qpack/fuzzing Fuzz qpack_fuzzer
)

(
# fuzz quic-go
compile_go_fuzzer github.com/quic-go/quic-go/fuzzing/frames Fuzz frame_fuzzer
compile_go_fuzzer github.com/quic-go/quic-go/fuzzing/header Fuzz header_fuzzer
compile_go_fuzzer github.com/quic-go/quic-go/fuzzing/transportparameters Fuzz transportparameter_fuzzer
compile_go_fuzzer github.com/quic-go/quic-go/fuzzing/tokens Fuzz token_fuzzer
compile_go_fuzzer github.com/quic-go/quic-go/fuzzing/handshake Fuzz handshake_fuzzer

if [ $SANITIZER == "coverage" ]; then
  # no need for corpora if coverage
  exit 0
fi

# generate seed corpora
cd $GOPATH/src/github.com/quic-go/quic-go/
go generate -x ./fuzzing/...

zip --quiet -r $OUT/header_fuzzer_seed_corpus.zip fuzzing/header/corpus
zip --quiet -r $OUT/frame_fuzzer_seed_corpus.zip fuzzing/frames/corpus
zip --quiet -r $OUT/transportparameter_fuzzer_seed_corpus.zip fuzzing/transportparameters/corpus
zip --quiet -r $OUT/handshake_fuzzer_seed_corpus.zip fuzzing/handshake/corpus
)

# for debugging
ls -al $OUT
quic-go-0.59.0/packet_packer.go000066400000000000000000001036511513066070600163060ustar00rootroot00000000000000package quic

import (
	crand "crypto/rand"
	"encoding/binary"
	"errors"
	"fmt"
	"math/rand/v2"

	"github.com/quic-go/quic-go/internal/ackhandler"
	"github.com/quic-go/quic-go/internal/handshake"
	"github.com/quic-go/quic-go/internal/monotime"
	"github.com/quic-go/quic-go/internal/protocol"
	"github.com/quic-go/quic-go/internal/qerr"
	"github.com/quic-go/quic-go/internal/wire"
)

// errNothingToPack is returned when there is no payload to put into a
// 1-RTT packet (e.g. by AppendPacket and PackAckOnlyPacket).
var errNothingToPack = errors.New("nothing to pack")

// packer assembles and seals outgoing QUIC packets.
// packetPacker is the implementation in this file.
type packer interface {
	PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, now monotime.Time, v protocol.Version) (*coalescedPacket, error)
	PackAckOnlyPacket(maxPacketSize protocol.ByteCount, now monotime.Time, v protocol.Version) (shortHeaderPacket, *packetBuffer, error)
	AppendPacket(_ *packetBuffer, maxPacketSize protocol.ByteCount, now monotime.Time, v protocol.Version) (shortHeaderPacket, error)
	PackPTOProbePacket(_ protocol.EncryptionLevel, _ protocol.ByteCount, addPingIfEmpty bool, now monotime.Time, v protocol.Version) (*coalescedPacket, error)
	PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
	PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
	PackPathProbePacket(protocol.ConnectionID, []ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error)
	PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error)

	SetToken([]byte)
}

// sealer is the sealing interface common to all encryption levels.
// A handshake.ShortHeaderSealer (1-RTT) also satisfies it.
type sealer interface {
	handshake.LongHeaderSealer
}

// payload collects the frames that go into a single packet.
type payload struct {
	streamFrames []ackhandler.StreamFrame
	frames       []ackhandler.Frame
	ack          *wire.AckFrame
	// length is the summed encoded length of all frames (including the ACK frame)
	length protocol.ByteCount
}

// longHeaderPacket describes a packed Initial, Handshake or 0-RTT packet.
type longHeaderPacket struct {
	header       *wire.ExtendedHeader
	ack          *wire.AckFrame
	frames       []ackhandler.Frame
	streamFrames []ackhandler.StreamFrame // only used for 0-RTT packets

	// length is the number of bytes this packet occupies in the packet buffer
	length protocol.ByteCount
}

// shortHeaderPacket describes a packed 1-RTT packet.
type shortHeaderPacket struct {
	PacketNumber         protocol.PacketNumber
	Frames               []ackhandler.Frame
	StreamFrames         []ackhandler.StreamFrame
	Ack                  *wire.AckFrame
	Length               protocol.ByteCount
	IsPathMTUProbePacket bool
	IsPathProbePacket    bool

	// used for logging
	DestConnID      protocol.ConnectionID
	PacketNumberLen protocol.PacketNumberLen
	KeyPhase        protocol.KeyPhaseBit
}

// IsAckEliciting says if the packet contains at least one ack-eliciting frame.
func (p *shortHeaderPacket) IsAckEliciting() bool { return ackhandler.HasAckElicitingFrames(p.Frames) }

// coalescedPacket is a datagram made of zero or more long header packets,
// optionally followed by one short header packet, all serialized into buffer.
type coalescedPacket struct {
	buffer         *packetBuffer
	longHdrPackets []*longHeaderPacket
	shortHdrPacket *shortHeaderPacket
}

// IsOnlyShortHeaderPacket says if this packet only contains a short header packet (and no long header packets).
func (p *coalescedPacket) IsOnlyShortHeaderPacket() bool {
	return len(p.longHdrPackets) == 0 && p.shortHdrPacket != nil
}

// EncryptionLevel returns the encryption level that corresponds to the packet type.
// It panics for packet types that don't carry encrypted data.
func (p *longHeaderPacket) EncryptionLevel() protocol.EncryptionLevel {
	//nolint:exhaustive // Will never be called for Retry packets (and they don't have encrypted data).
	switch p.header.Type {
	case protocol.PacketTypeInitial:
		return protocol.EncryptionInitial
	case protocol.PacketTypeHandshake:
		return protocol.EncryptionHandshake
	case protocol.PacketType0RTT:
		return protocol.Encryption0RTT
	default:
		panic("can't determine encryption level")
	}
}

// IsAckEliciting says if the packet contains at least one ack-eliciting frame.
func (p *longHeaderPacket) IsAckEliciting() bool { return ackhandler.HasAckElicitingFrames(p.frames) }

// packetNumberManager hands out packet numbers per encryption level.
// PeekPacketNumber returns the number that the next PopPacketNumber call is
// expected to return; the packer checks this when sealing a packet.
type packetNumberManager interface {
	PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen)
	PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber
}

// sealingManager provides the sealer for each encryption level.
// The packer treats handshake.ErrKeysNotYetAvailable and handshake.ErrKeysDropped
// as "no packet at this level" rather than as fatal errors.
type sealingManager interface {
	GetInitialSealer() (handshake.LongHeaderSealer, error)
	GetHandshakeSealer() (handshake.LongHeaderSealer, error)
	Get0RTTSealer() (handshake.LongHeaderSealer, error)
	Get1RTTSealer() (handshake.ShortHeaderSealer, error)
}

// frameSource supplies application data frames.
// Append appends frames that fit into the given number of bytes and returns the
// extended slices together with the number of bytes that were added.
type frameSource interface {
	HasData() bool
	Append([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount, monotime.Time, protocol.Version) ([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount)
}

// ackFrameSource returns the ACK frame to send at an encryption level, or nil.
// NOTE(review): onlyIfQueued presumably limits this to ACKs that are already due — confirm.
type ackFrameSource interface {
	GetAckFrame(_ protocol.EncryptionLevel, now monotime.Time, onlyIfQueued bool) *wire.AckFrame
}

// packetPacker assembles and seals outgoing QUIC packets.
type packetPacker struct {
	srcConnID     protocol.ConnectionID
	getDestConnID func() protocol.ConnectionID

	perspective protocol.Perspective
	cryptoSetup sealingManager

	initialStream   *initialCryptoStream
	handshakeStream *cryptoStream

	// token is sent in the header of Initial packets
	token []byte

	pnManager           packetNumberManager
	framer              frameSource
	acks                ackFrameSource
	datagramQueue       *datagramQueue
	retransmissionQueue *retransmissionQueue
	// rand is seeded from crypto/rand in newPacketPacker
	rand rand.Rand

	// numNonAckElicitingAcks counts consecutive ACK-only 1-RTT packets; once it
	// reaches protocol.MaxNonAckElicitingAcks a PING is added to make the packet ack-eliciting
	numNonAckElicitingAcks int
}

// compile-time check that packetPacker implements packer
var _ packer = &packetPacker{}

// newPacketPacker creates a packetPacker.
func newPacketPacker(
	srcConnID protocol.ConnectionID,
	getDestConnID func() protocol.ConnectionID,
	initialStream *initialCryptoStream,
	handshakeStream *cryptoStream,
	packetNumberManager packetNumberManager,
	retransmissionQueue *retransmissionQueue,
	cryptoSetup sealingManager,
	framer frameSource,
	acks ackFrameSource,
	datagramQueue *datagramQueue,
	perspective protocol.Perspective,
) *packetPacker {
	// 16 random bytes seed the two 64-bit PCG state words below.
	// NOTE(review): the error is ignored; since Go 1.24 crypto/rand.Read is
	// documented to never return an error — confirm the minimum Go version.
	var b [16]byte
	_, _ = crand.Read(b[:])

	return &packetPacker{
		cryptoSetup:         cryptoSetup,
		getDestConnID:       getDestConnID,
		srcConnID:           srcConnID,
		initialStream:       initialStream,
		handshakeStream:     handshakeStream,
		retransmissionQueue: retransmissionQueue,
		datagramQueue:       datagramQueue,
		perspective:         perspective,
		framer:              framer,
		acks:                acks,
		rand:                *rand.New(rand.NewPCG(binary.BigEndian.Uint64(b[:8]), binary.BigEndian.Uint64(b[8:]))),
		pnManager:           packetNumberManager,
	}
}

// PackConnectionClose packs a packet that closes the connection with a transport error.
// For crypto errors, the error message is not sent as the reason phrase.
func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
	var reason string
	// don't send details of crypto errors
	if !e.ErrorCode.IsCryptoError() {
		reason = e.ErrorMessage
	}
	return p.packConnectionClose(false, uint64(e.ErrorCode), e.FrameType, reason, maxPacketSize, v)
}

// PackApplicationClose packs a packet that closes the connection with an application error.
func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) { return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage, maxPacketSize, v) } func (p *packetPacker) packConnectionClose( isApplicationError bool, errorCode uint64, frameType uint64, reason string, maxPacketSize protocol.ByteCount, v protocol.Version, ) (*coalescedPacket, error) { var sealers [4]sealer var hdrs [3]*wire.ExtendedHeader var payloads [4]payload var size protocol.ByteCount var connID protocol.ConnectionID var oneRTTPacketNumber protocol.PacketNumber var oneRTTPacketNumberLen protocol.PacketNumberLen var keyPhase protocol.KeyPhaseBit // only set for 1-RTT var numLongHdrPackets uint8 encLevels := [4]protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption0RTT, protocol.Encryption1RTT} for i, encLevel := range encLevels { if p.perspective == protocol.PerspectiveServer && encLevel == protocol.Encryption0RTT { continue } ccf := &wire.ConnectionCloseFrame{ IsApplicationError: isApplicationError, ErrorCode: errorCode, FrameType: frameType, ReasonPhrase: reason, } // don't send application errors in Initial or Handshake packets if isApplicationError && (encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake) { ccf.IsApplicationError = false ccf.ErrorCode = uint64(qerr.ApplicationErrorErrorCode) ccf.ReasonPhrase = "" } pl := payload{ frames: []ackhandler.Frame{{Frame: ccf}}, length: ccf.Length(v), } var sealer sealer var err error switch encLevel { case protocol.EncryptionInitial: sealer, err = p.cryptoSetup.GetInitialSealer() case protocol.EncryptionHandshake: sealer, err = p.cryptoSetup.GetHandshakeSealer() case protocol.Encryption0RTT: sealer, err = p.cryptoSetup.Get0RTTSealer() case protocol.Encryption1RTT: var s handshake.ShortHeaderSealer s, err = p.cryptoSetup.Get1RTTSealer() if err == nil { keyPhase = s.KeyPhase() 
} sealer = s } if err == handshake.ErrKeysNotYetAvailable || err == handshake.ErrKeysDropped { continue } if err != nil { return nil, err } sealers[i] = sealer var hdr *wire.ExtendedHeader if encLevel == protocol.Encryption1RTT { connID = p.getDestConnID() oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, pl) } else { hdr = p.getLongHeader(encLevel, v) hdrs[i] = hdr size += p.longHeaderPacketLength(hdr, pl, v) + protocol.ByteCount(sealer.Overhead()) numLongHdrPackets++ } payloads[i] = pl } buffer := getPacketBuffer() packet := &coalescedPacket{ buffer: buffer, longHdrPackets: make([]*longHeaderPacket, 0, numLongHdrPackets), } for i, encLevel := range encLevels { if sealers[i] == nil { continue } if encLevel == protocol.Encryption1RTT { shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], 0, maxPacketSize, sealers[i], false, v) if err != nil { return nil, err } packet.shortHdrPacket = &shp } else { var paddingLen protocol.ByteCount if encLevel == protocol.EncryptionInitial { paddingLen = p.initialPaddingLen(payloads[i].frames, size, maxPacketSize) } longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], v) if err != nil { return nil, err } packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket) } } return packet, nil } // longHeaderPacketLength calculates the length of a serialized long header packet. 
// It takes into account that packets that have a tiny payload need to be padded, // such that len(payload) + packet number len >= 4 + AEAD overhead func (p *packetPacker) longHeaderPacketLength(hdr *wire.ExtendedHeader, pl payload, v protocol.Version) protocol.ByteCount { var paddingLen protocol.ByteCount pnLen := protocol.ByteCount(hdr.PacketNumberLen) if pl.length < 4-pnLen { paddingLen = 4 - pnLen - pl.length } return hdr.GetLength(v) + pl.length + paddingLen } // shortHeaderPacketLength calculates the length of a serialized short header packet. // It takes into account that packets that have a tiny payload need to be padded, // such that len(payload) + packet number len >= 4 + AEAD overhead func (p *packetPacker) shortHeaderPacketLength(connID protocol.ConnectionID, pnLen protocol.PacketNumberLen, pl payload) protocol.ByteCount { var paddingLen protocol.ByteCount if pl.length < 4-protocol.ByteCount(pnLen) { paddingLen = 4 - protocol.ByteCount(pnLen) - pl.length } return wire.ShortHeaderLen(connID, pnLen) + pl.length + paddingLen } // size is the expected size of the packet, if no padding was applied. func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, currentSize, maxPacketSize protocol.ByteCount) protocol.ByteCount { // For the server, only ack-eliciting Initial packets need to be padded. if p.perspective == protocol.PerspectiveServer && !ackhandler.HasAckElicitingFrames(frames) { return 0 } if currentSize >= maxPacketSize { return 0 } return maxPacketSize - currentSize } // PackCoalescedPacket packs a new packet. // It packs an Initial / Handshake if there is data to send in these packet number spaces. // It should only be called before the handshake is confirmed. 
func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxSize protocol.ByteCount, now monotime.Time, v protocol.Version) (*coalescedPacket, error) {
	var (
		initialHdr, handshakeHdr, zeroRTTHdr                            *wire.ExtendedHeader
		initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload
		oneRTTPacketNumber                                              protocol.PacketNumber
		oneRTTPacketNumberLen                                           protocol.PacketNumberLen
	)
	// Try packing an Initial packet.
	// Dropped Initial keys are not an error: there's just no Initial packet to send.
	initialSealer, err := p.cryptoSetup.GetInitialSealer()
	if err != nil && err != handshake.ErrKeysDropped {
		return nil, err
	}
	// size is the running total of all packets packed so far, including their AEAD overhead.
	var size protocol.ByteCount
	if initialSealer != nil {
		initialHdr, initialPayload = p.maybeGetCryptoPacket(
			maxSize-protocol.ByteCount(initialSealer.Overhead()),
			protocol.EncryptionInitial,
			now,
			false,
			onlyAck,
			v,
		)
		if initialPayload.length > 0 {
			size += p.longHeaderPacketLength(initialHdr, initialPayload, v) + protocol.ByteCount(initialSealer.Overhead())
		}
	}

	// Add a Handshake packet.
	// When only packing ACKs, this is only attempted if no Initial packet was packed.
	// Otherwise, it requires at least MinCoalescedPacketSize bytes of remaining space.
	var handshakeSealer sealer
	if (onlyAck && size == 0) || (!onlyAck && size < maxSize-protocol.MinCoalescedPacketSize) {
		var err error
		handshakeSealer, err = p.cryptoSetup.GetHandshakeSealer()
		if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
			return nil, err
		}
		if handshakeSealer != nil {
			handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(
				maxSize-size-protocol.ByteCount(handshakeSealer.Overhead()),
				protocol.EncryptionHandshake,
				now,
				false,
				onlyAck,
				v,
			)
			if handshakePayload.length > 0 {
				s := p.longHeaderPacketLength(handshakeHdr, handshakePayload, v) + protocol.ByteCount(handshakeSealer.Overhead())
				size += s
			}
		}
	}

	// Add a 0-RTT / 1-RTT packet.
	// 1-RTT takes precedence: 0-RTT is only tried (by the client) when the 1-RTT sealer isn't available.
	var zeroRTTSealer sealer
	var oneRTTSealer handshake.ShortHeaderSealer
	var connID protocol.ConnectionID
	var kp protocol.KeyPhaseBit
	if (onlyAck && size == 0) || (!onlyAck && size < maxSize-protocol.MinCoalescedPacketSize) {
		var err error
		oneRTTSealer, err = p.cryptoSetup.Get1RTTSealer()
		if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
			return nil, err
		}
		if err == nil { // 1-RTT
			kp = oneRTTSealer.KeyPhase()
			connID = p.getDestConnID()
			// Only peek here: the packet number is consumed once the packet is actually appended.
			oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
			hdrLen := wire.ShortHeaderLen(connID, oneRTTPacketNumberLen)
			oneRTTPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, hdrLen, maxSize-size, onlyAck, now, v)
			if oneRTTPayload.length > 0 {
				size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, oneRTTPayload) + protocol.ByteCount(oneRTTSealer.Overhead())
			}
		} else if p.perspective == protocol.PerspectiveClient && !onlyAck { // 0-RTT packets can't contain ACK frames
			var err error
			zeroRTTSealer, err = p.cryptoSetup.Get0RTTSealer()
			if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
				return nil, err
			}
			if zeroRTTSealer != nil {
				zeroRTTHdr, zeroRTTPayload = p.maybeGetAppDataPacketFor0RTT(zeroRTTSealer, maxSize-size, now, v)
				if zeroRTTPayload.length > 0 {
					size += p.longHeaderPacketLength(zeroRTTHdr, zeroRTTPayload, v) + protocol.ByteCount(zeroRTTSealer.Overhead())
				}
			}
		}
	}

	// nothing to send
	if initialPayload.length == 0 && handshakePayload.length == 0 && zeroRTTPayload.length == 0 && oneRTTPayload.length == 0 {
		return nil, nil
	}

	// Serialize in order Initial, Handshake, 0-RTT / 1-RTT. The short header
	// (1-RTT) packet has no length field, so it has to be the last one in the datagram.
	buffer := getPacketBuffer()
	packet := &coalescedPacket{
		buffer:         buffer,
		longHdrPackets: make([]*longHeaderPacket, 0, 3),
	}
	if initialPayload.length > 0 {
		// The padding is based on the size of the whole coalesced packet,
		// so that the Initial packet's padding fills up the datagram.
		padding := p.initialPaddingLen(initialPayload.frames, size, maxSize)
		cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v)
		if err != nil {
			return nil, err
		}
		packet.longHdrPackets = append(packet.longHdrPackets, cont)
	}
	if handshakePayload.length > 0 {
		cont, err := p.appendLongHeaderPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, v)
		if err != nil {
			return nil, err
		}
		packet.longHdrPackets = append(packet.longHdrPackets, cont)
	}
	if zeroRTTPayload.length > 0 {
		longHdrPacket, err := p.appendLongHeaderPacket(buffer, zeroRTTHdr, zeroRTTPayload, 0, protocol.Encryption0RTT, zeroRTTSealer, v)
		if err != nil {
			return nil, err
		}
		packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket)
	} else if oneRTTPayload.length > 0 {
		shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, maxSize, oneRTTSealer, false, v)
		if err != nil {
			return nil, err
		}
		packet.shortHdrPacket = &shp
	}
	return packet, nil
}

// PackAckOnlyPacket packs a packet containing only an ACK in the application data packet number space.
// It should be called after the handshake is confirmed.
// The buffer is returned even if err is non-nil.
// NOTE(review): presumably the caller releases it in that case — confirm.
func (p *packetPacker) PackAckOnlyPacket(maxSize protocol.ByteCount, now monotime.Time, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
	buf := getPacketBuffer()
	packet, err := p.appendPacket(buf, true, maxSize, now, v)
	return packet, buf, err
}

// AppendPacket packs a packet in the application data packet number space.
// It should be called after the handshake is confirmed.
func (p *packetPacker) AppendPacket(buf *packetBuffer, maxSize protocol.ByteCount, now monotime.Time, v protocol.Version) (shortHeaderPacket, error) { return p.appendPacket(buf, false, maxSize, now, v) } func (p *packetPacker) appendPacket( buf *packetBuffer, onlyAck bool, maxPacketSize protocol.ByteCount, now monotime.Time, v protocol.Version, ) (shortHeaderPacket, error) { sealer, err := p.cryptoSetup.Get1RTTSealer() if err != nil { return shortHeaderPacket{}, err } pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) connID := p.getDestConnID() hdrLen := wire.ShortHeaderLen(connID, pnLen) pl := p.maybeGetShortHeaderPacket(sealer, hdrLen, maxPacketSize, onlyAck, now, v) if pl.length == 0 { return shortHeaderPacket{}, errNothingToPack } kp := sealer.KeyPhase() return p.appendShortHeaderPacket(buf, connID, pn, pnLen, kp, pl, 0, maxPacketSize, sealer, false, v) } func (p *packetPacker) maybeGetCryptoPacket( maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, now monotime.Time, addPingIfEmpty bool, onlyAck bool, v protocol.Version, ) (*wire.ExtendedHeader, payload) { if onlyAck { if ack := p.acks.GetAckFrame(encLevel, now, true); ack != nil { hdr := p.getLongHeader(encLevel, v) maxPacketSize -= hdr.GetLength(v) ack.Truncate(maxPacketSize, v) return hdr, payload{ack: ack, length: ack.Length(v)} } return nil, payload{length: 0} } var hasCryptoData func() bool var popCryptoFrame func(maxLen protocol.ByteCount) *wire.CryptoFrame //nolint:exhaustive // Initial and Handshake are the only two encryption levels here. 
switch encLevel { case protocol.EncryptionInitial: hasCryptoData = p.initialStream.HasData popCryptoFrame = p.initialStream.PopCryptoFrame case protocol.EncryptionHandshake: hasCryptoData = p.handshakeStream.HasData popCryptoFrame = p.handshakeStream.PopCryptoFrame } handler := p.retransmissionQueue.AckHandler(encLevel) hasRetransmission := p.retransmissionQueue.HasData(encLevel) ack := p.acks.GetAckFrame(encLevel, now, !hasRetransmission && !hasCryptoData()) var pl payload if !hasCryptoData() && !hasRetransmission && ack == nil { if !addPingIfEmpty { // nothing to send return nil, payload{} } ping := &wire.PingFrame{} pl.frames = append(pl.frames, ackhandler.Frame{Frame: ping, Handler: emptyHandler{}}) pl.length += ping.Length(v) } hdr := p.getLongHeader(encLevel, v) maxPacketSize -= hdr.GetLength(v) if ack != nil { ack.Truncate(maxPacketSize, v) pl.ack = ack pl.length = ack.Length(v) maxPacketSize -= pl.length } if hasRetransmission { for { frame := p.retransmissionQueue.GetFrame(encLevel, maxPacketSize, v) if frame == nil { break } pl.frames = append(pl.frames, ackhandler.Frame{ Frame: frame, Handler: p.retransmissionQueue.AckHandler(encLevel), }) frameLen := frame.Length(v) pl.length += frameLen maxPacketSize -= frameLen } return hdr, pl } else { for hasCryptoData() { cf := popCryptoFrame(maxPacketSize) if cf == nil { break } pl.frames = append(pl.frames, ackhandler.Frame{Frame: cf, Handler: handler}) pl.length += cf.Length(v) maxPacketSize -= cf.Length(v) } } return hdr, pl } func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxSize protocol.ByteCount, now monotime.Time, v protocol.Version) (*wire.ExtendedHeader, payload) { if p.perspective != protocol.PerspectiveClient { return nil, payload{} } hdr := p.getLongHeader(protocol.Encryption0RTT, v) maxPayloadSize := maxSize - hdr.GetLength(v) - protocol.ByteCount(sealer.Overhead()) return hdr, p.maybeGetAppDataPacket(maxPayloadSize, false, false, now, v) } func (p *packetPacker) 
maybeGetShortHeaderPacket( sealer handshake.ShortHeaderSealer, hdrLen, maxPacketSize protocol.ByteCount, onlyAck bool, now monotime.Time, v protocol.Version, ) payload { maxPayloadSize := maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) return p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, true, now, v) } func (p *packetPacker) maybeGetAppDataPacket( maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool, now monotime.Time, v protocol.Version, ) payload { pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, now, v) // check if we have anything to send if len(pl.frames) == 0 && len(pl.streamFrames) == 0 { if pl.ack == nil { return payload{} } // the packet only contains an ACK if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks { ping := &wire.PingFrame{} pl.frames = append(pl.frames, ackhandler.Frame{Frame: ping}) pl.length += ping.Length(v) p.numNonAckElicitingAcks = 0 } else { p.numNonAckElicitingAcks++ } } else { p.numNonAckElicitingAcks = 0 } return pl } func (p *packetPacker) composeNextPacket( maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool, now monotime.Time, v protocol.Version, ) payload { if onlyAck { if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, now, true); ack != nil { ack.Truncate(maxPayloadSize, v) return payload{ack: ack, length: ack.Length(v)} } return payload{} } hasData := p.framer.HasData() hasRetransmission := p.retransmissionQueue.HasData(protocol.Encryption1RTT) var pl payload if ackAllowed { if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, now, !hasRetransmission && !hasData); ack != nil { ack.Truncate(maxPayloadSize, v) pl.ack = ack pl.length += ack.Length(v) } } if p.datagramQueue != nil { if f := p.datagramQueue.Peek(); f != nil { size := f.Length(v) if size <= maxPayloadSize-pl.length { // DATAGRAM frame fits pl.frames = append(pl.frames, ackhandler.Frame{Frame: f}) pl.length += size p.datagramQueue.Pop() } else if pl.ack == nil { // The DATAGRAM frame doesn't fit, and the 
packet doesn't contain an ACK. // Discard this frame. There's no point in retrying this in the next packet, // as it's unlikely that the available packet size will increase. p.datagramQueue.Pop() } // If the DATAGRAM frame was too large and the packet contained an ACK, we'll try to send it out later. } } if pl.ack != nil && !hasData && !hasRetransmission { return pl } if hasRetransmission { for { remainingLen := maxPayloadSize - pl.length if remainingLen < protocol.MinStreamFrameSize { break } f := p.retransmissionQueue.GetFrame(protocol.Encryption1RTT, remainingLen, v) if f == nil { break } pl.frames = append(pl.frames, ackhandler.Frame{Frame: f, Handler: p.retransmissionQueue.AckHandler(protocol.Encryption1RTT)}) pl.length += f.Length(v) } } if hasData { var lengthAdded protocol.ByteCount startLen := len(pl.frames) pl.frames, pl.streamFrames, lengthAdded = p.framer.Append(pl.frames, pl.streamFrames, maxPayloadSize-pl.length, now, v) pl.length += lengthAdded // add handlers for the control frames that were added for i := startLen; i < len(pl.frames); i++ { if pl.frames[i].Handler != nil { continue } switch pl.frames[i].Frame.(type) { case *wire.PathChallengeFrame, *wire.PathResponseFrame: // Path probing is currently not supported, therefore we don't need to set the OnAcked callback yet. // PATH_CHALLENGE and PATH_RESPONSE are never retransmitted. default: // we might be packing a 0-RTT packet, but we need to use the 1-RTT ack handler anyway pl.frames[i].Handler = p.retransmissionQueue.AckHandler(protocol.Encryption1RTT) } } } return pl } func (p *packetPacker) PackPTOProbePacket( encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, addPingIfEmpty bool, now monotime.Time, v protocol.Version, ) (*coalescedPacket, error) { if encLevel == protocol.Encryption1RTT { return p.packPTOProbePacket1RTT(maxPacketSize, addPingIfEmpty, now, v) } var sealer handshake.LongHeaderSealer //nolint:exhaustive // Probe packets are never sent for 0-RTT. 
switch encLevel { case protocol.EncryptionInitial: var err error sealer, err = p.cryptoSetup.GetInitialSealer() if err != nil { return nil, err } case protocol.EncryptionHandshake: var err error sealer, err = p.cryptoSetup.GetHandshakeSealer() if err != nil { return nil, err } default: panic("unknown encryption level") } hdr, pl := p.maybeGetCryptoPacket( maxPacketSize-protocol.ByteCount(sealer.Overhead()), encLevel, now, addPingIfEmpty, false, v, ) if pl.length == 0 { return nil, nil } buffer := getPacketBuffer() packet := &coalescedPacket{buffer: buffer} size := p.longHeaderPacketLength(hdr, pl, v) + protocol.ByteCount(sealer.Overhead()) var padding protocol.ByteCount if encLevel == protocol.EncryptionInitial { padding = p.initialPaddingLen(pl.frames, size, maxPacketSize) } longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdr, pl, padding, encLevel, sealer, v) if err != nil { return nil, err } packet.longHdrPackets = []*longHeaderPacket{longHdrPacket} return packet, nil } func (p *packetPacker) packPTOProbePacket1RTT(maxPacketSize protocol.ByteCount, addPingIfEmpty bool, now monotime.Time, v protocol.Version) (*coalescedPacket, error) { s, err := p.cryptoSetup.Get1RTTSealer() if err != nil { return nil, err } kp := s.KeyPhase() connID := p.getDestConnID() pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) hdrLen := wire.ShortHeaderLen(connID, pnLen) pl := p.maybeGetAppDataPacket(maxPacketSize-protocol.ByteCount(s.Overhead())-hdrLen, false, true, now, v) if pl.length == 0 { if !addPingIfEmpty { return nil, nil } ping := &wire.PingFrame{} pl.frames = append(pl.frames, ackhandler.Frame{Frame: ping, Handler: emptyHandler{}}) pl.length += ping.Length(v) } buffer := getPacketBuffer() packet := &coalescedPacket{buffer: buffer} shp, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, maxPacketSize, s, false, v) if err != nil { return nil, err } packet.shortHdrPacket = &shp return packet, nil } func (p *packetPacker) 
PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) { pl := payload{ frames: []ackhandler.Frame{ping}, length: ping.Frame.Length(v), } buffer := getPacketBuffer() s, err := p.cryptoSetup.Get1RTTSealer() if err != nil { return shortHeaderPacket{}, nil, err } connID := p.getDestConnID() pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) padding := size - p.shortHeaderPacketLength(connID, pnLen, pl) - protocol.ByteCount(s.Overhead()) kp := s.KeyPhase() packet, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, padding, size, s, true, v) return packet, buffer, err } func (p *packetPacker) PackPathProbePacket(connID protocol.ConnectionID, frames []ackhandler.Frame, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) { pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) buf := getPacketBuffer() s, err := p.cryptoSetup.Get1RTTSealer() if err != nil { return shortHeaderPacket{}, nil, err } var l protocol.ByteCount for _, f := range frames { l += f.Frame.Length(v) } payload := payload{ frames: frames, length: l, } padding := protocol.MinInitialPacketSize - p.shortHeaderPacketLength(connID, pnLen, payload) - protocol.ByteCount(s.Overhead()) packet, err := p.appendShortHeaderPacket(buf, connID, pn, pnLen, s.KeyPhase(), payload, padding, protocol.MinInitialPacketSize, s, false, v) if err != nil { return shortHeaderPacket{}, nil, err } packet.IsPathProbePacket = true return packet, buf, err } func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.Version) *wire.ExtendedHeader { pn, pnLen := p.pnManager.PeekPacketNumber(encLevel) hdr := &wire.ExtendedHeader{ PacketNumber: pn, PacketNumberLen: pnLen, } hdr.Version = v hdr.SrcConnectionID = p.srcConnID hdr.DestConnectionID = p.getDestConnID() //nolint:exhaustive // 1-RTT packets are not long header packets. 
switch encLevel { case protocol.EncryptionInitial: hdr.Type = protocol.PacketTypeInitial hdr.Token = p.token case protocol.EncryptionHandshake: hdr.Type = protocol.PacketTypeHandshake case protocol.Encryption0RTT: hdr.Type = protocol.PacketType0RTT } return hdr } func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.Version) (*longHeaderPacket, error) { var paddingLen protocol.ByteCount pnLen := protocol.ByteCount(header.PacketNumberLen) if pl.length < 4-pnLen { paddingLen = 4 - pnLen - pl.length } paddingLen += padding header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + pl.length + paddingLen startLen := len(buffer.Data) raw := buffer.Data[startLen:] raw, err := header.Append(raw, v) if err != nil { return nil, err } payloadOffset := protocol.ByteCount(len(raw)) raw, err = p.appendPacketPayload(raw, pl, paddingLen, v) if err != nil { return nil, err } raw = p.encryptPacket(raw, sealer, header.PacketNumber, payloadOffset, pnLen) buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)] if pn := p.pnManager.PopPacketNumber(encLevel); pn != header.PacketNumber { return nil, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, header.PacketNumber) } return &longHeaderPacket{ header: header, ack: pl.ack, frames: pl.frames, streamFrames: pl.streamFrames, length: protocol.ByteCount(len(raw)), }, nil } func (p *packetPacker) appendShortHeaderPacket( buffer *packetBuffer, connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit, pl payload, padding, maxPacketSize protocol.ByteCount, sealer sealer, isMTUProbePacket bool, v protocol.Version, ) (shortHeaderPacket, error) { var paddingLen protocol.ByteCount if pl.length < 4-protocol.ByteCount(pnLen) { paddingLen = 4 - protocol.ByteCount(pnLen) - pl.length } paddingLen += 
padding startLen := len(buffer.Data) raw := buffer.Data[startLen:] raw, err := wire.AppendShortHeader(raw, connID, pn, pnLen, kp) if err != nil { return shortHeaderPacket{}, err } payloadOffset := protocol.ByteCount(len(raw)) raw, err = p.appendPacketPayload(raw, pl, paddingLen, v) if err != nil { return shortHeaderPacket{}, err } if !isMTUProbePacket { if size := protocol.ByteCount(len(raw) + sealer.Overhead()); size > maxPacketSize { return shortHeaderPacket{}, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, maxPacketSize) } } raw = p.encryptPacket(raw, sealer, pn, payloadOffset, protocol.ByteCount(pnLen)) buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)] if newPN := p.pnManager.PopPacketNumber(protocol.Encryption1RTT); newPN != pn { return shortHeaderPacket{}, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, newPN) } return shortHeaderPacket{ PacketNumber: pn, PacketNumberLen: pnLen, KeyPhase: kp, StreamFrames: pl.streamFrames, Frames: pl.frames, Ack: pl.ack, Length: protocol.ByteCount(len(raw)), DestConnID: connID, IsPathMTUProbePacket: isMTUProbePacket, }, nil } // appendPacketPayload serializes the payload of a packet into the raw byte slice. // It modifies the order of payload.frames. func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.Version) ([]byte, error) { payloadOffset := len(raw) if pl.ack != nil { var err error raw, err = pl.ack.Append(raw, v) if err != nil { return nil, err } } if paddingLen > 0 { raw = append(raw, make([]byte, paddingLen)...) } // Randomize the order of the control frames. // This makes sure that the receiver doesn't rely on the order in which frames are packed. 
if len(pl.frames) > 1 { p.rand.Shuffle(len(pl.frames), func(i, j int) { pl.frames[i], pl.frames[j] = pl.frames[j], pl.frames[i] }) } for _, f := range pl.frames { var err error raw, err = f.Frame.Append(raw, v) if err != nil { return nil, err } } for _, f := range pl.streamFrames { var err error raw, err = f.Frame.Append(raw, v) if err != nil { return nil, err } } if payloadSize := protocol.ByteCount(len(raw)-payloadOffset) - paddingLen; payloadSize != pl.length { return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", pl.length, payloadSize) } return raw, nil } func (p *packetPacker) encryptPacket(raw []byte, sealer sealer, pn protocol.PacketNumber, payloadOffset, pnLen protocol.ByteCount) []byte { _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], pn, raw[:payloadOffset]) raw = raw[:len(raw)+sealer.Overhead()] // apply header protection pnOffset := payloadOffset - pnLen sealer.EncryptHeader(raw[pnOffset+4:pnOffset+4+16], &raw[0], raw[pnOffset:payloadOffset]) return raw } func (p *packetPacker) SetToken(token []byte) { p.token = token } type emptyHandler struct{} var _ ackhandler.FrameHandler = emptyHandler{} func (emptyHandler) OnAcked(wire.Frame) {} func (emptyHandler) OnLost(wire.Frame) {} quic-go-0.59.0/packet_packer_test.go000066400000000000000000001442731513066070600173520ustar00rootroot00000000000000package quic import ( "bytes" "crypto/rand" "errors" "testing" "time" "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/mocks" mockackhandler "github.com/quic-go/quic-go/internal/mocks/ackhandler" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) const testPackerConnIDLen = 4 type 
testPacketPacker struct { packer *packetPacker initialStream *initialCryptoStream handshakeStream *cryptoStream datagramQueue *datagramQueue pnManager *mockackhandler.MockSentPacketHandler sealingManager *MockSealingManager framer *MockFrameSource ackFramer *MockAckFrameSource retransmissionQueue *retransmissionQueue } func newTestPacketPacker(t *testing.T, mockCtrl *gomock.Controller, pers protocol.Perspective) *testPacketPacker { destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) require.Equal(t, testPackerConnIDLen, destConnID.Len()) initialStream := newInitialCryptoStream(pers == protocol.PerspectiveClient) handshakeStream := newCryptoStream() pnManager := mockackhandler.NewMockSentPacketHandler(mockCtrl) framer := NewMockFrameSource(mockCtrl) ackFramer := NewMockAckFrameSource(mockCtrl) sealingManager := NewMockSealingManager(mockCtrl) datagramQueue := newDatagramQueue(func() {}, utils.DefaultLogger) retransmissionQueue := newRetransmissionQueue() return &testPacketPacker{ pnManager: pnManager, initialStream: initialStream, handshakeStream: handshakeStream, sealingManager: sealingManager, framer: framer, ackFramer: ackFramer, datagramQueue: datagramQueue, retransmissionQueue: retransmissionQueue, packer: newPacketPacker( protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), func() protocol.ConnectionID { return destConnID }, initialStream, handshakeStream, pnManager, retransmissionQueue, sealingManager, framer, ackFramer, datagramQueue, pers, ), } } // newMockShortHeaderSealer returns a mock short header sealer that seals a short header packet func newMockShortHeaderSealer(mockCtrl *gomock.Controller) *mocks.MockShortHeaderSealer { sealer := mocks.NewMockShortHeaderSealer(mockCtrl) sealer.EXPECT().KeyPhase().Return(protocol.KeyPhaseOne).AnyTimes() sealer.EXPECT().Overhead().Return(7).AnyTimes() sealer.EXPECT().EncryptHeader(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() sealer.EXPECT().Seal(gomock.Any(), gomock.Any(), gomock.Any(), 
gomock.Any()).DoAndReturn(func(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) []byte { return append(src, bytes.Repeat([]byte{'s'}, sealer.Overhead())...) }).AnyTimes() return sealer } func parsePacket(t *testing.T, data []byte) (hdrs []*wire.ExtendedHeader, more []byte) { t.Helper() for len(data) > 0 { if !wire.IsLongHeaderPacket(data[0]) { break } hdr, _, more, err := wire.ParsePacket(data) require.NoError(t, err) extHdr, err := hdr.ParseExtended(data) require.NoError(t, err) require.GreaterOrEqual(t, extHdr.Length+protocol.ByteCount(extHdr.PacketNumberLen), protocol.ByteCount(4)) data = more hdrs = append(hdrs, extHdr) } return hdrs, data } func parseShortHeaderPacket(t *testing.T, data []byte, connIDLen int) { t.Helper() l, _, pnLen, _, err := wire.ParseShortHeader(data, connIDLen) require.NoError(t, err) require.GreaterOrEqual(t, len(data)-l+int(pnLen), 4) } func expectAppendFrames(framer *MockFrameSource, controlFrames []ackhandler.Frame, streamFrames []ackhandler.StreamFrame) { framer.EXPECT().Append(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(cf []ackhandler.Frame, sf []ackhandler.StreamFrame, maxSize protocol.ByteCount, _ monotime.Time, v protocol.Version) ([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount) { var length protocol.ByteCount for _, f := range controlFrames { if length+f.Frame.Length(v) > maxSize { break } length += f.Frame.Length(v) cf = append(cf, f) } for _, f := range streamFrames { if length+f.Frame.Length(v) > maxSize { break } length += f.Frame.Length(v) sf = append(sf, f) } return cf, sf, length }, ) } func generateLargeACKFrame(t *testing.T, minSize protocol.ByteCount) *wire.AckFrame { t.Helper() ack := &wire.AckFrame{ AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}, DelayTime: 42 * time.Millisecond, } var counter int for ack.Length(protocol.Version1) < minSize { counter++ if counter > protocol.MaxNumAckRanges { t.Fatalf("max number of ACK ranges 
reached, size: %d", ack.Length(protocol.Version1)) } pn := protocol.PacketNumber(1000 * counter) ack.AckRanges = append([]wire.AckRange{{Smallest: pn, Largest: pn + 100}}, ack.AckRanges...) } return ack } func TestPackLongHeaders(t *testing.T) { skipIfDisableScramblingEnvSet(t) t.Run("with Handshake ACK", func(t *testing.T) { testPackLongHeaders(t, true) }) t.Run("without Handshake ACK", func(t *testing.T) { testPackLongHeaders(t, false) }) } func testPackLongHeaders(t *testing.T, includeACK bool) { const maxPacketSize protocol.ByteCount = 1234 mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveClient) token := make([]byte, 20) rand.Read(token) tp.packer.SetToken(token) now := monotime.Now() tp.pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen3) tp.pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24)) tp.pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen4) tp.pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) tp.sealingManager.EXPECT().GetInitialSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.sealingManager.EXPECT().GetHandshakeSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, now, false) var numRanges int if includeACK { ack := generateLargeACKFrame(t, maxPacketSize-1000) numRanges = len(ack.AckRanges) tp.ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, now, false).Return(ack) } else { tp.ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, now, false) tp.sealingManager.EXPECT().Get0RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) tp.sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) } clientHello := getClientHello(t, "quic-go.net") 
tp.initialStream.Write(clientHello) tp.initialStream.Write(make([]byte, 900-len(clientHello))) // add some more data tp.packer.retransmissionQueue.addHandshake(&wire.PingFrame{}) p, err := tp.packer.PackCoalescedPacket(false, maxPacketSize, now, protocol.Version1) require.NoError(t, err) require.Equal(t, maxPacketSize, p.buffer.Len()) require.Len(t, p.longHdrPackets, 2) require.Nil(t, p.shortHdrPacket) require.Equal(t, protocol.EncryptionInitial, p.longHdrPackets[0].EncryptionLevel()) // the ClientHello is split into multiple frames require.GreaterOrEqual(t, len(p.longHdrPackets[0].frames), 3) for _, f := range p.longHdrPackets[0].frames { require.IsType(t, &wire.CryptoFrame{}, f.Frame) } require.Equal(t, protocol.EncryptionHandshake, p.longHdrPackets[1].EncryptionLevel()) require.Len(t, p.longHdrPackets[1].frames, 1) require.IsType(t, &wire.PingFrame{}, p.longHdrPackets[1].frames[0].Frame) if includeACK { require.NotNil(t, p.longHdrPackets[1].ack) // the ACK frame was truncated require.Less(t, len(p.longHdrPackets[1].ack.AckRanges), numRanges) } else { require.Nil(t, p.longHdrPackets[1].ack) } hdrs, more := parsePacket(t, p.buffer.Data) require.Len(t, hdrs, 2) require.Equal(t, protocol.PacketTypeInitial, hdrs[0].Type) require.Equal(t, token, hdrs[0].Token) require.Equal(t, protocol.PacketNumber(0x24), hdrs[0].PacketNumber) require.Equal(t, protocol.PacketNumberLen3, hdrs[0].PacketNumberLen) require.Equal(t, protocol.PacketTypeHandshake, hdrs[1].Type) require.Nil(t, hdrs[1].Token) require.Equal(t, protocol.PacketNumber(0x42), hdrs[1].PacketNumber) require.Equal(t, protocol.PacketNumberLen4, hdrs[1].PacketNumberLen) require.Empty(t, more) } func TestPackCoalescedAckOnlyPacketNothingToSend(t *testing.T) { mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveClient) tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) // the packet number is not popped 
tp.sealingManager.EXPECT().GetInitialSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.sealingManager.EXPECT().GetHandshakeSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, gomock.Any(), true) tp.ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, gomock.Any(), true) tp.ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, gomock.Any(), true) p, err := tp.packer.PackCoalescedPacket(true, 1234, monotime.Now(), protocol.Version1) require.NoError(t, err) require.Nil(t, p) } func TestPackInitialAckOnlyPacket(t *testing.T) { t.Run("client", func(t *testing.T) { testPackInitialAckOnlyPacket(t, protocol.PerspectiveClient) }) t.Run("server", func(t *testing.T) { testPackInitialAckOnlyPacket(t, protocol.PerspectiveServer) }) } func testPackInitialAckOnlyPacket(t *testing.T, pers protocol.Perspective) { const maxPacketSize protocol.ByteCount = 1234 mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, pers) tp.pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) tp.sealingManager.EXPECT().GetInitialSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} tp.ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, gomock.Any(), true).Return(ack) p, err := tp.packer.PackCoalescedPacket(true, maxPacketSize, monotime.Now(), protocol.Version1) require.NoError(t, err) require.NotNil(t, p) require.Len(t, p.longHdrPackets, 1) require.Equal(t, protocol.EncryptionInitial, p.longHdrPackets[0].EncryptionLevel()) require.Equal(t, ack, p.longHdrPackets[0].ack) require.Empty(t, p.longHdrPackets[0].frames) // only the client needs to pad 
Initial packets switch pers { case protocol.PerspectiveClient: require.Equal(t, maxPacketSize, p.buffer.Len()) case protocol.PerspectiveServer: require.Less(t, p.buffer.Len(), protocol.ByteCount(100)) } hdrs, more := parsePacket(t, p.buffer.Data) require.Empty(t, more) require.Len(t, hdrs, 1) require.Equal(t, protocol.PacketTypeInitial, hdrs[0].Type) } func TestPack1RTTAckOnlyPacket(t *testing.T) { const maxPacketSize protocol.ByteCount = 1300 mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveClient) tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} tp.ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, gomock.Any(), true).Return(ack) p, buffer, err := tp.packer.PackAckOnlyPacket(maxPacketSize, monotime.Now(), protocol.Version1) require.NoError(t, err) require.Equal(t, ack, p.Ack) require.Empty(t, p.Frames) parsePacket(t, buffer.Data) } func TestPack0RTTPacket(t *testing.T) { mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveClient) tp.sealingManager.EXPECT().GetInitialSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.sealingManager.EXPECT().Get0RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) tp.sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) tp.ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, gomock.Any(), true) tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) 
tp.pnManager.EXPECT().PopPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(0x42)) cf := ackhandler.Frame{Frame: &wire.MaxDataFrame{MaximumData: 0x1337}} tp.framer.EXPECT().HasData().Return(true) // TODO: check sizes tp.framer.EXPECT().Append(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(fs []ackhandler.Frame, sf []ackhandler.StreamFrame, _ protocol.ByteCount, _ monotime.Time, _ protocol.Version) ([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount) { return append(fs, cf), sf, cf.Frame.Length(protocol.Version1) }, ) p, err := tp.packer.PackCoalescedPacket(false, protocol.MaxByteCount, monotime.Now(), protocol.Version1) require.NoError(t, err) require.NotNil(t, p) require.Len(t, p.longHdrPackets, 1) require.Equal(t, protocol.PacketType0RTT, p.longHdrPackets[0].header.Type) require.Equal(t, protocol.Encryption0RTT, p.longHdrPackets[0].EncryptionLevel()) require.Len(t, p.longHdrPackets[0].frames, 1) require.Equal(t, cf.Frame, p.longHdrPackets[0].frames[0].Frame) require.NotNil(t, p.longHdrPackets[0].frames[0].Handler) } // ACK frames can't be sent in 0-RTT packets func TestPack0RTTPacketNoACK(t *testing.T) { mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveClient) tp.sealingManager.EXPECT().GetInitialSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) tp.sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) tp.ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, gomock.Any(), true) // no further calls to get an ACK frame p, err := tp.packer.PackCoalescedPacket(true, protocol.MaxByteCount, monotime.Now(), protocol.Version1) require.NoError(t, err) require.Nil(t, p) } func TestPackCoalescedAppData(t *testing.T) { t.Run("with large ACK", func(t *testing.T) { testPackCoalescedAppData(t, true) }) t.Run("without ACK", 
func(t *testing.T) { testPackCoalescedAppData(t, false) }) } func testPackCoalescedAppData(t *testing.T, withAck bool) { const maxPacketSize protocol.ByteCount = 1234 mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveServer) tp.pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x24)) tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) tp.sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) tp.sealingManager.EXPECT().GetHandshakeSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.framer.EXPECT().HasData().Return(true) tp.ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, gomock.Any(), false) var numRanges int if withAck { // The ACK is too large and needs to be truncated ack := generateLargeACKFrame(t, maxPacketSize-1000) numRanges = len(ack.AckRanges) tp.ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, gomock.Any(), false).Return(ack) } else { tp.ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, gomock.Any(), false) } handshakeData := make([]byte, 1000) rand.Read(handshakeData) tp.handshakeStream.Write(handshakeData) expectAppendFrames(tp.framer, nil, []ackhandler.StreamFrame{{Frame: &wire.StreamFrame{Data: []byte("foobar")}}}) p, err := tp.packer.PackCoalescedPacket(false, maxPacketSize, monotime.Now(), protocol.Version1) require.NoError(t, err) require.Len(t, p.longHdrPackets, 1) require.Equal(t, protocol.EncryptionHandshake, p.longHdrPackets[0].EncryptionLevel()) require.Len(t, p.longHdrPackets[0].frames, 1) require.Equal(t, 
handshakeData, p.longHdrPackets[0].frames[0].Frame.(*wire.CryptoFrame).Data) require.NotNil(t, p.shortHdrPacket) require.Empty(t, p.shortHdrPacket.Frames) if withAck { require.NotNil(t, p.shortHdrPacket.Ack) require.Less(t, len(p.shortHdrPacket.Ack.AckRanges), numRanges) require.LessOrEqual(t, len(p.buffer.Data), int(maxPacketSize)) require.Empty(t, p.shortHdrPacket.StreamFrames) } else { require.Nil(t, p.shortHdrPacket.Ack) require.Less(t, len(p.buffer.Data), int(maxPacketSize)) require.Len(t, p.shortHdrPacket.StreamFrames, 1) require.Equal(t, []byte("foobar"), p.shortHdrPacket.StreamFrames[0].Frame.Data) } hdrs, more := parsePacket(t, p.buffer.Data) require.Len(t, hdrs, 1) require.Equal(t, protocol.PacketTypeHandshake, hdrs[0].Type) require.NotEmpty(t, more) parseShortHeaderPacket(t, more, testPackerConnIDLen) } func TestPackConnectionCloseCoalesced(t *testing.T) { t.Run("client", func(t *testing.T) { testPackConnectionCloseCoalesced(t, protocol.PerspectiveClient) }) t.Run("server", func(t *testing.T) { testPackConnectionCloseCoalesced(t, protocol.PerspectiveServer) }) } func testPackConnectionCloseCoalesced(t *testing.T, pers protocol.Perspective) { const maxPacketSize protocol.ByteCount = 1234 mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, pers) tp.pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(1), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(1)) tp.pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(2), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(2)) tp.sealingManager.EXPECT().GetInitialSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.sealingManager.EXPECT().GetHandshakeSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) switch pers { case protocol.PerspectiveClient: 
tp.sealingManager.EXPECT().Get0RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(3), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(3)) case protocol.PerspectiveServer: tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(3), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(3)) } p, err := tp.packer.PackApplicationClose(&qerr.ApplicationError{ ErrorCode: 0x1337, ErrorMessage: "foobar", }, maxPacketSize, protocol.Version1) require.NoError(t, err) switch pers { case protocol.PerspectiveClient: require.Len(t, p.longHdrPackets, 3) require.Nil(t, p.shortHdrPacket) case protocol.PerspectiveServer: require.Len(t, p.longHdrPackets, 2) require.NotNil(t, p.shortHdrPacket) } // for Initial packets, the error code is replace with a transport error of type APPLICATION_ERROR require.Equal(t, protocol.PacketTypeInitial, p.longHdrPackets[0].header.Type) require.Equal(t, protocol.PacketNumber(1), p.longHdrPackets[0].header.PacketNumber) require.Len(t, p.longHdrPackets[0].frames, 1) require.IsType(t, &wire.ConnectionCloseFrame{}, p.longHdrPackets[0].frames[0].Frame) ccf := p.longHdrPackets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) require.False(t, ccf.IsApplicationError) require.Equal(t, uint64(qerr.ApplicationErrorErrorCode), ccf.ErrorCode) require.Empty(t, ccf.ReasonPhrase) // for Handshake packets, the error code is replace with a transport error of type APPLICATION_ERROR require.Equal(t, protocol.PacketTypeHandshake, p.longHdrPackets[1].header.Type) require.Equal(t, protocol.PacketNumber(2), 
p.longHdrPackets[1].header.PacketNumber) require.Len(t, p.longHdrPackets[1].frames, 1) require.IsType(t, &wire.ConnectionCloseFrame{}, p.longHdrPackets[1].frames[0].Frame) ccf = p.longHdrPackets[1].frames[0].Frame.(*wire.ConnectionCloseFrame) require.False(t, ccf.IsApplicationError) require.Equal(t, uint64(qerr.ApplicationErrorErrorCode), ccf.ErrorCode) require.Empty(t, ccf.ReasonPhrase) // for application-data packet number space (1-RTT for the server, 0-RTT for the client), // the application-level error code is sent switch pers { case protocol.PerspectiveClient: require.Equal(t, protocol.PacketNumber(3), p.longHdrPackets[2].header.PacketNumber) require.Len(t, p.longHdrPackets[2].frames, 1) require.IsType(t, &wire.ConnectionCloseFrame{}, p.longHdrPackets[2].frames[0].Frame) ccf = p.longHdrPackets[2].frames[0].Frame.(*wire.ConnectionCloseFrame) case protocol.PerspectiveServer: require.Equal(t, protocol.PacketNumber(3), p.shortHdrPacket.PacketNumber) require.Len(t, p.shortHdrPacket.Frames, 1) require.IsType(t, &wire.ConnectionCloseFrame{}, p.shortHdrPacket.Frames[0].Frame) ccf = p.shortHdrPacket.Frames[0].Frame.(*wire.ConnectionCloseFrame) } require.True(t, ccf.IsApplicationError) require.Equal(t, uint64(0x1337), ccf.ErrorCode) require.Equal(t, "foobar", ccf.ReasonPhrase) // the client needs to pad this packet to the max packet size switch pers { case protocol.PerspectiveClient: require.Equal(t, maxPacketSize, p.buffer.Len()) case protocol.PerspectiveServer: require.Less(t, p.buffer.Len(), protocol.ByteCount(100)) } } func TestPackConnectionCloseCryptoError(t *testing.T) { mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveServer) tp.pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) tp.sealingManager.EXPECT().GetInitialSealer().Return(nil, 
handshake.ErrKeysDropped) tp.sealingManager.EXPECT().GetHandshakeSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) quicErr := qerr.NewLocalCryptoError(0x42, errors.New("crypto error")) quicErr.FrameType = 0x1234 p, err := tp.packer.PackConnectionClose(quicErr, protocol.MaxByteCount, protocol.Version1) require.NoError(t, err) require.Len(t, p.longHdrPackets, 1) require.Equal(t, protocol.PacketTypeHandshake, p.longHdrPackets[0].header.Type) require.Len(t, p.longHdrPackets[0].frames, 1) require.IsType(t, &wire.ConnectionCloseFrame{}, p.longHdrPackets[0].frames[0].Frame) ccf := p.longHdrPackets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) require.False(t, ccf.IsApplicationError) require.Equal(t, uint64(0x100+0x42), ccf.ErrorCode) require.Equal(t, uint64(0x1234), ccf.FrameType) // for crypto errors, the reason phrase is cleared require.Empty(t, ccf.ReasonPhrase) } func TestPackConnectionClose1RTT(t *testing.T) { mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveServer) tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) tp.sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) tp.sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysDropped) tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) // expect no framer.PopStreamFrames p, err := tp.packer.PackConnectionClose(&qerr.TransportError{ ErrorCode: qerr.CryptoBufferExceeded, ErrorMessage: "foo", }, protocol.MaxByteCount, protocol.Version1) require.NoError(t, err) require.Empty(t, p.longHdrPackets) require.Len(t, p.shortHdrPacket.Frames, 1) require.IsType(t, &wire.ConnectionCloseFrame{}, p.shortHdrPacket.Frames[0].Frame) ccf 
:= p.shortHdrPacket.Frames[0].Frame.(*wire.ConnectionCloseFrame) require.False(t, ccf.IsApplicationError) require.Equal(t, uint64(qerr.CryptoBufferExceeded), ccf.ErrorCode) require.Equal(t, "foo", ccf.ReasonPhrase) } func TestPack1RTTPacketNothingToSend(t *testing.T) { mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveServer) tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) // don't expect any calls to PopPacketNumber tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, gomock.Any(), true) tp.framer.EXPECT().HasData() _, err := tp.packer.AppendPacket(getPacketBuffer(), protocol.MaxByteCount, monotime.Now(), protocol.Version1) require.ErrorIs(t, err, errNothingToPack) } func TestPack1RTTPacketWithData(t *testing.T) { mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveServer) tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.framer.EXPECT().HasData().Return(true) tp.ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, gomock.Any(), false) f := &wire.StreamFrame{ StreamID: 5, Data: []byte{0xde, 0xca, 0xfb, 0xad}, } expectAppendFrames( tp.framer, []ackhandler.Frame{ {Frame: &wire.ResetStreamFrame{}, Handler: &mtuFinderAckHandler{}}, // set any non-nil ackhandler.FrameHandler {Frame: &wire.MaxDataFrame{}}, }, []ackhandler.StreamFrame{{Frame: f}}, ) buffer := getPacketBuffer() buffer.Data = append(buffer.Data, []byte("foobar")...) 
p, err := tp.packer.AppendPacket(buffer, protocol.MaxByteCount, monotime.Now(), protocol.Version1) require.NoError(t, err) b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) require.Len(t, p.StreamFrames, 1) var sawResetStream, sawMaxData bool for _, frame := range p.Frames { switch frame.Frame.(type) { case *wire.ResetStreamFrame: sawResetStream = true require.Equal(t, frame.Handler, &mtuFinderAckHandler{}) case *wire.MaxDataFrame: sawMaxData = true require.NotNil(t, frame.Handler) require.NotEqual(t, frame.Handler, &mtuFinderAckHandler{}) } } require.True(t, sawResetStream) require.True(t, sawMaxData) require.Equal(t, f.StreamID, p.StreamFrames[0].Frame.StreamID) require.Equal(t, buffer.Data[:6], []byte("foobar")) // make sure the packet was actually appended require.Contains(t, string(buffer.Data), string(b)) } func TestPack1RTTPacketWithACK(t *testing.T) { mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveServer) tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 42, Smallest: 1}}} tp.framer.EXPECT().HasData() tp.ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, gomock.Any(), true).Return(ack) tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) p, err := tp.packer.AppendPacket(getPacketBuffer(), protocol.MaxByteCount, monotime.Now(), protocol.Version1) require.NoError(t, err) require.Equal(t, ack, p.Ack) } func TestPackPathChallengeAndPathResponse(t *testing.T) { mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveServer) tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) 
tp.pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.framer.EXPECT().HasData().Return(true) tp.ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, gomock.Any(), false) frames := []ackhandler.Frame{ {Frame: &wire.PathChallengeFrame{}}, {Frame: &wire.PathResponseFrame{}}, {Frame: &wire.DataBlockedFrame{}}, } expectAppendFrames(tp.framer, frames, nil) buffer := getPacketBuffer() p, err := tp.packer.AppendPacket(buffer, protocol.MaxByteCount, monotime.Now(), protocol.Version1) require.NoError(t, err) require.Len(t, p.Frames, 3) var sawPathChallenge, sawPathResponse bool for _, f := range p.Frames { switch f.Frame.(type) { case *wire.PathChallengeFrame: sawPathChallenge = true // this means that the frame won't be retransmitted. require.Nil(t, f.Handler) case *wire.PathResponseFrame: sawPathResponse = true // this means that the frame won't be retransmitted. 
require.Nil(t, f.Handler) default: require.NotNil(t, f.Handler) } } require.True(t, sawPathChallenge) require.True(t, sawPathResponse) require.NotZero(t, buffer.Len()) } func TestPackDatagramFrames(t *testing.T) { mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveServer) tp.ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, gomock.Any(), true) tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.datagramQueue.Add(&wire.DatagramFrame{ DataLenPresent: true, Data: []byte("foobar"), }) tp.framer.EXPECT().HasData() buffer := getPacketBuffer() p, err := tp.packer.AppendPacket(buffer, protocol.MaxByteCount, monotime.Now(), protocol.Version1) require.NoError(t, err) require.Len(t, p.Frames, 1) require.IsType(t, &wire.DatagramFrame{}, p.Frames[0].Frame) require.Equal(t, []byte("foobar"), p.Frames[0].Frame.(*wire.DatagramFrame).Data) require.NotEmpty(t, buffer.Data) } func TestPackLargeDatagramFrame(t *testing.T) { // If a packet contains an ACK, and doesn't have enough space for the DATAGRAM frame, // it should be skipped. It will be packed in the next packet. 
const maxPacketSize = 1000 mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveServer) tp.ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, gomock.Any(), true).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100}}}) tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) f := &wire.DatagramFrame{DataLenPresent: true, Data: make([]byte, maxPacketSize-10)} tp.datagramQueue.Add(f) tp.framer.EXPECT().HasData() buffer := getPacketBuffer() p, err := tp.packer.AppendPacket(buffer, maxPacketSize, monotime.Now(), protocol.Version1) require.NoError(t, err) require.NotNil(t, p.Ack) require.Empty(t, p.Frames) require.NotEmpty(t, buffer.Data) require.Equal(t, f, tp.datagramQueue.Peek()) // make sure the frame is still there // Now try packing again, but with a smaller packet size. // The DATAGRAM frame should now be dropped, as we can't expect to ever be able tosend it out. 
const newMaxPacketSize = maxPacketSize - 10 tp.ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, gomock.Any(), true) tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43), protocol.PacketNumberLen2) tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.framer.EXPECT().HasData() buffer = getPacketBuffer() p, err = tp.packer.AppendPacket(buffer, newMaxPacketSize, monotime.Now(), protocol.Version1) require.ErrorIs(t, err, errNothingToPack) require.Nil(t, tp.datagramQueue.Peek()) // make sure the frame is gone } func TestPackRetransmissions(t *testing.T) { mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveServer) f := &wire.CryptoFrame{Data: []byte("Initial")} tp.retransmissionQueue.addInitial(f) tp.retransmissionQueue.addHandshake(&wire.CryptoFrame{Data: []byte("Handshake")}) tp.pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) tp.sealingManager.EXPECT().GetInitialSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) tp.sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) tp.ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, gomock.Any(), false) p, err := tp.packer.PackCoalescedPacket(false, 1000, monotime.Now(), protocol.Version1) require.NoError(t, err) require.Len(t, p.longHdrPackets, 1) require.Equal(t, protocol.EncryptionInitial, p.longHdrPackets[0].EncryptionLevel()) require.Len(t, p.longHdrPackets[0].frames, 1) require.Equal(t, f, p.longHdrPackets[0].frames[0].Frame) require.NotNil(t, p.longHdrPackets[0].frames[0].Handler) } func packMaxNumNonAckElicitingAcks(t *testing.T, tp *testPacketPacker, mockCtrl 
*gomock.Controller, maxPacketSize protocol.ByteCount) { t.Helper() for i := 0; i < protocol.MaxNonAckElicitingAcks; i++ { tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.framer.EXPECT().HasData().Return(true) tp.ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, gomock.Any(), false).Return( &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, ) expectAppendFrames(tp.framer, nil, nil) p, err := tp.packer.AppendPacket(getPacketBuffer(), maxPacketSize, monotime.Now(), protocol.Version1) require.NoError(t, err) require.NotNil(t, p.Ack) require.Empty(t, p.Frames) } } func TestPackEvery20thPacketAckEliciting(t *testing.T) { const maxPacketSize = 1000 mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveServer) // send the maximum number of non-ACK-eliciting packets packMaxNumNonAckElicitingAcks(t, tp, mockCtrl, maxPacketSize) // Now there's nothing to send, so we shouldn't generate a packet just to send a PING tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, gomock.Any(), false) tp.framer.EXPECT().HasData().Return(true) expectAppendFrames(tp.framer, nil, nil) _, err := tp.packer.AppendPacket(getPacketBuffer(), maxPacketSize, monotime.Now(), protocol.Version1) require.ErrorIs(t, err, errNothingToPack) // Now we have an ACK to send. We should bundle a PING to make the packet ack-eliciting. 
tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.framer.EXPECT().HasData().Return(true) tp.ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, gomock.Any(), false).Return( &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, ) expectAppendFrames(tp.framer, nil, nil) p, err := tp.packer.AppendPacket(getPacketBuffer(), maxPacketSize, monotime.Now(), protocol.Version1) require.NoError(t, err) require.Len(t, p.Frames, 1) require.Equal(t, &wire.PingFrame{}, p.Frames[0].Frame) require.Nil(t, p.Frames[0].Handler) // make sure the PING is not retransmitted if lost // make sure the next packet doesn't contain another PING tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.framer.EXPECT().HasData().Return(true) tp.ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, gomock.Any(), false).Return( &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, ) expectAppendFrames(tp.framer, nil, nil) p, err = tp.packer.AppendPacket(getPacketBuffer(), maxPacketSize, monotime.Now(), protocol.Version1) require.NoError(t, err) require.NotNil(t, p.Ack) require.Empty(t, p.Frames) } func TestPackLongHeaderPadToAtLeast4Bytes(t *testing.T) { mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveServer) tp.pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) 
tp.pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) sealer := newMockShortHeaderSealer(mockCtrl) tp.sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) tp.sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil) tp.sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) tp.retransmissionQueue.addHandshake(&wire.PingFrame{}) tp.ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, gomock.Any(), false) packet, err := tp.packer.PackCoalescedPacket(false, protocol.MaxByteCount, monotime.Now(), protocol.Version1) require.NoError(t, err) require.NotNil(t, packet) require.Len(t, packet.longHdrPackets, 1) require.Nil(t, packet.shortHdrPacket) hdr, _, _, err := wire.ParsePacket(packet.buffer.Data) require.NoError(t, err) data := packet.buffer.Data extHdr, err := hdr.ParseExtended(data) require.NoError(t, err) require.Equal(t, protocol.PacketNumberLen1, extHdr.PacketNumberLen) data = data[extHdr.ParsedLen():] require.Len(t, data, 4-1 /* packet number length */ +sealer.Overhead()) // first bytes should be 2 PADDING frames... 
require.Equal(t, []byte{0, 0}, data[:2]) // ...followed by the PING frame frameParser := wire.NewFrameParser(false, false, false) frameType, lt, err := frameParser.ParseType(data[2:], protocol.EncryptionHandshake) require.NoError(t, err) require.Equal(t, 1, lt) frame, l, err := frameParser.ParseLessCommonFrame(frameType, data[2+lt:], protocol.Version1) require.NoError(t, err) require.IsType(t, &wire.PingFrame{}, frame) require.Zero(t, l) require.Equal(t, sealer.Overhead(), len(data)-2-lt) } func TestPackShortHeaderPadToAtLeast4Bytes(t *testing.T) { // small stream ID, such that only a single byte is consumed f := &wire.StreamFrame{StreamID: 0x10, Fin: true} require.Equal(t, protocol.ByteCount(2), f.Length(protocol.Version1)) mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveServer) tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) tp.pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) sealer := newMockShortHeaderSealer(mockCtrl) tp.sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) tp.framer.EXPECT().HasData().Return(true) tp.ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, gomock.Any(), false) expectAppendFrames(tp.framer, nil, []ackhandler.StreamFrame{{Frame: f}}) buffer := getPacketBuffer() _, err := tp.packer.AppendPacket(buffer, protocol.MaxByteCount, monotime.Now(), protocol.Version1) require.NoError(t, err) // cut off the tag that the mock sealer added buffer.Data = buffer.Data[:buffer.Len()-protocol.ByteCount(sealer.Overhead())] data := buffer.Data l, _, pnLen, _, err := wire.ParseShortHeader(data, testPackerConnIDLen) require.NoError(t, err) payload := data[l:] require.Equal(t, protocol.PacketNumberLen1, pnLen) require.Equal(t, 4-1 /* packet number length */, len(payload)) // the first byte of the payload should be a PADDING frame... require.Equal(t, byte(0), payload[0]) // ... 
followed by the STREAM frame frameParser := wire.NewFrameParser(false, false, false) frameType, l, err := frameParser.ParseType(payload[1:], protocol.Encryption1RTT) require.NoError(t, err) require.Equal(t, 1, l) require.True(t, frameType.IsStreamFrameType()) frame, frameLen, err := wire.ParseStreamFrame(payload[1+l:], frameType, protocol.Version1) require.NoError(t, err) require.Equal(t, f, frame) require.Equal(t, len(payload)-2, frameLen) } func TestPackInitialProbePacket(t *testing.T) { t.Run("client", func(t *testing.T) { t.Setenv(disableClientHelloScramblingEnv, "true") testPackProbePacket(t, protocol.EncryptionInitial, protocol.PerspectiveClient) }) t.Run("server", func(t *testing.T) { testPackProbePacket(t, protocol.EncryptionInitial, protocol.PerspectiveServer) }) } func TestPackHandshakeProbePacket(t *testing.T) { t.Run("client", func(t *testing.T) { testPackProbePacket(t, protocol.EncryptionHandshake, protocol.PerspectiveClient) }) t.Run("server", func(t *testing.T) { testPackProbePacket(t, protocol.EncryptionHandshake, protocol.PerspectiveServer) }) } func testPackProbePacket(t *testing.T, encLevel protocol.EncryptionLevel, perspective protocol.Perspective) { const maxPacketSize protocol.ByteCount = 1234 mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, perspective) var cryptoData []byte switch encLevel { case protocol.EncryptionInitial: tp.sealingManager.EXPECT().GetInitialSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) cryptoData = getClientHello(t, "") tp.packer.initialStream.Write(cryptoData) case protocol.EncryptionHandshake: tp.sealingManager.EXPECT().GetHandshakeSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) cryptoData = []byte("foobar") tp.packer.handshakeStream.Write(cryptoData) } tp.ackFramer.EXPECT().GetAckFrame(encLevel, gomock.Any(), false) tp.pnManager.EXPECT().PeekPacketNumber(encLevel).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) 
tp.pnManager.EXPECT().PopPacketNumber(encLevel).Return(protocol.PacketNumber(0x42)) p, err := tp.packer.PackPTOProbePacket(encLevel, maxPacketSize, false, monotime.Now(), protocol.Version1) require.NoError(t, err) require.NotNil(t, p) require.Len(t, p.longHdrPackets, 1) packet := p.longHdrPackets[0] require.Equal(t, encLevel, packet.EncryptionLevel()) if encLevel == protocol.EncryptionInitial { require.GreaterOrEqual(t, p.buffer.Len(), protocol.ByteCount(protocol.MinInitialPacketSize)) require.Equal(t, maxPacketSize, p.buffer.Len()) } require.Len(t, packet.frames, 1) require.Equal(t, cryptoData, packet.frames[0].Frame.(*wire.CryptoFrame).Data) hdrs, more := parsePacket(t, p.buffer.Data) require.Len(t, hdrs, 1) switch encLevel { case protocol.EncryptionInitial: require.Equal(t, protocol.PacketTypeInitial, hdrs[0].Type) case protocol.EncryptionHandshake: require.Equal(t, protocol.PacketTypeHandshake, hdrs[0].Type) } require.Empty(t, more) } func TestPack1RTTProbePacket(t *testing.T) { const maxPacketSize protocol.ByteCount = 999 mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveServer) tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, gomock.Any(), false) tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) tp.framer.EXPECT().HasData().Return(true) tp.framer.EXPECT().Append(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn( func(cf []ackhandler.Frame, sf []ackhandler.StreamFrame, size protocol.ByteCount, _ monotime.Time, v protocol.Version) ([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount) { f, split := (&wire.StreamFrame{Data: make([]byte, 2*maxPacketSize)}).MaybeSplitOffFrame(size, v) require.True(t, 
split) return cf, append(sf, ackhandler.StreamFrame{Frame: f}), f.Length(v) }, ) p, err := tp.packer.PackPTOProbePacket(protocol.Encryption1RTT, maxPacketSize, false, monotime.Now(), protocol.Version1) require.NoError(t, err) require.NotNil(t, p) require.True(t, p.IsOnlyShortHeaderPacket()) require.Empty(t, p.longHdrPackets) require.NotNil(t, p.shortHdrPacket) packet := p.shortHdrPacket require.Empty(t, packet.Frames) require.Len(t, packet.StreamFrames, 1) require.Equal(t, maxPacketSize, packet.Length) } func TestPackPTOProbePacketNothingToPack(t *testing.T) { t.Run("Initial", func(t *testing.T) { testPackPTOProbePacketNothingToPack(t, protocol.EncryptionInitial) }) t.Run("Handshake", func(t *testing.T) { testPackPTOProbePacketNothingToPack(t, protocol.EncryptionHandshake) }) t.Run("1-RTT", func(t *testing.T) { testPackPTOProbePacketNothingToPack(t, protocol.Encryption1RTT) }) } func testPackPTOProbePacketNothingToPack(t *testing.T, encLevel protocol.EncryptionLevel) { const maxPacketSize protocol.ByteCount = 1234 mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveServer) switch encLevel { case protocol.EncryptionInitial: tp.sealingManager.EXPECT().GetInitialSealer().Return(newMockShortHeaderSealer(mockCtrl), nil).Times(2) case protocol.EncryptionHandshake: tp.sealingManager.EXPECT().GetHandshakeSealer().Return(newMockShortHeaderSealer(mockCtrl), nil).Times(2) case protocol.Encryption1RTT: tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil).Times(2) tp.framer.EXPECT().HasData().Times(2) } tp.pnManager.EXPECT().PeekPacketNumber(encLevel).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).MaxTimes(2) tp.ackFramer.EXPECT().GetAckFrame(encLevel, gomock.Any(), true).Times(2) // don't force a PING to be sent packet, err := tp.packer.PackPTOProbePacket(encLevel, maxPacketSize, false, monotime.Now(), protocol.Version1) require.NoError(t, err) require.Nil(t, packet) // now force a 
PING to be sent tp.pnManager.EXPECT().PopPacketNumber(encLevel).Return(protocol.PacketNumber(0x42)) packet, err = tp.packer.PackPTOProbePacket(encLevel, maxPacketSize, true, monotime.Now(), protocol.Version1) require.NoError(t, err) require.NotNil(t, packet) var frames []ackhandler.Frame switch encLevel { case protocol.EncryptionInitial, protocol.EncryptionHandshake: require.Len(t, packet.longHdrPackets, 1) require.Nil(t, packet.shortHdrPacket) require.Equal(t, encLevel, packet.longHdrPackets[0].EncryptionLevel()) frames = packet.longHdrPackets[0].frames case protocol.Encryption1RTT: require.Empty(t, packet.longHdrPackets) require.NotNil(t, packet.shortHdrPacket) frames = packet.shortHdrPacket.Frames } require.Len(t, frames, 1) require.Equal(t, &wire.PingFrame{}, frames[0].Frame) require.Equal(t, emptyHandler{}, frames[0].Handler) } func TestPackMTUProbePacket(t *testing.T) { const ( maxPacketSize protocol.ByteCount = 1000 probePacketSize = maxPacketSize + 42 ) mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveClient) tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43)) ping := ackhandler.Frame{Frame: &wire.PingFrame{}} p, buffer, err := tp.packer.PackMTUProbePacket(ping, probePacketSize, protocol.Version1) require.NoError(t, err) require.Equal(t, probePacketSize, p.Length) require.Equal(t, protocol.PacketNumber(0x43), p.PacketNumber) require.Len(t, buffer.Data, int(probePacketSize)) require.True(t, p.IsPathMTUProbePacket) require.False(t, p.IsPathProbePacket) } func TestPackPathProbePacket(t *testing.T) { mockCtrl := gomock.NewController(t) tp := newTestPacketPacker(t, mockCtrl, protocol.PerspectiveServer) 
tp.sealingManager.EXPECT().Get1RTTSealer().Return(newMockShortHeaderSealer(mockCtrl), nil) tp.pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43), protocol.PacketNumberLen2) tp.pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x43)) p, buf, err := tp.packer.PackPathProbePacket( protocol.ParseConnectionID([]byte{1, 2, 3, 4}), []ackhandler.Frame{ {Frame: &wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}}, {Frame: &wire.PathResponseFrame{Data: [8]byte{8, 7, 6, 5, 4, 3, 2, 1}}}, }, protocol.Version1, ) require.NoError(t, err) require.Equal(t, protocol.PacketNumber(0x43), p.PacketNumber) require.Nil(t, p.Ack) require.Empty(t, p.StreamFrames) require.Len(t, p.Frames, 2) // the frame order is randomized frames := []wire.Frame{p.Frames[0].Frame, p.Frames[1].Frame} require.Contains(t, frames, &wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}) require.Contains(t, frames, &wire.PathResponseFrame{Data: [8]byte{8, 7, 6, 5, 4, 3, 2, 1}}) require.Len(t, buf.Data, protocol.MinInitialPacketSize) require.True(t, p.IsPathProbePacket) require.False(t, p.IsPathMTUProbePacket) } quic-go-0.59.0/packet_unpacker.go000066400000000000000000000170571513066070600166550ustar00rootroot00000000000000package quic import ( "fmt" "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/wire" ) type headerDecryptor interface { DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) } type headerParseError struct { err error } func (e *headerParseError) Unwrap() error { return e.err } func (e *headerParseError) Error() string { return e.err.Error() } type unpackedPacket struct { hdr *wire.ExtendedHeader encryptionLevel protocol.EncryptionLevel data []byte } // The packetUnpacker unpacks QUIC packets. 
type packetUnpacker struct { cs handshake.CryptoSetup shortHdrConnIDLen int } var _ unpacker = &packetUnpacker{} func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int) *packetUnpacker { return &packetUnpacker{ cs: cs, shortHdrConnIDLen: shortHdrConnIDLen, } } // UnpackLongHeader unpacks a Long Header packet. // If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits. // If any other error occurred when parsing the header, the error is of type headerParseError. // If decrypting the payload fails for any reason, the error is the error returned by the AEAD. func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, data []byte) (*unpackedPacket, error) { var encLevel protocol.EncryptionLevel var extHdr *wire.ExtendedHeader var decrypted []byte //nolint:exhaustive // Retry packets can't be unpacked. switch hdr.Type { case protocol.PacketTypeInitial: encLevel = protocol.EncryptionInitial opener, err := u.cs.GetInitialOpener() if err != nil { return nil, err } extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data) if err != nil { return nil, err } case protocol.PacketTypeHandshake: encLevel = protocol.EncryptionHandshake opener, err := u.cs.GetHandshakeOpener() if err != nil { return nil, err } extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data) if err != nil { return nil, err } case protocol.PacketType0RTT: encLevel = protocol.Encryption0RTT opener, err := u.cs.Get0RTTOpener() if err != nil { return nil, err } extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data) if err != nil { return nil, err } default: return nil, fmt.Errorf("unknown packet type: %s", hdr.Type) } if len(decrypted) == 0 { return nil, &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: "empty packet", } } return &unpackedPacket{ hdr: extHdr, encryptionLevel: encLevel, data: decrypted, }, nil } func (u *packetUnpacker) UnpackShortHeader(rcvTime monotime.Time, data []byte) (protocol.PacketNumber, 
protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) { opener, err := u.cs.Get1RTTOpener() if err != nil { return 0, 0, 0, nil, err } pn, pnLen, kp, decrypted, err := u.unpackShortHeaderPacket(opener, rcvTime, data) if err != nil { return 0, 0, 0, nil, err } if len(decrypted) == 0 { return 0, 0, 0, nil, &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: "empty packet", } } return pn, pnLen, kp, decrypted, nil } func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) { extHdr, parseErr := u.unpackLongHeader(opener, hdr, data) // If the reserved bits are set incorrectly, we still need to continue unpacking. // This avoids a timing side-channel, which otherwise might allow an attacker // to gain information about the header encryption. if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { return nil, nil, parseErr } extHdrLen := extHdr.ParsedLen() extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen) decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen]) if err != nil { return nil, nil, err } if parseErr != nil { return nil, nil, parseErr } return extHdr, decrypted, nil } func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, rcvTime monotime.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) { l, pn, pnLen, kp, parseErr := u.unpackShortHeader(opener, data) // If the reserved bits are set incorrectly, we still need to continue unpacking. // This avoids a timing side-channel, which otherwise might allow an attacker // to gain information about the header encryption. 
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { return 0, 0, 0, nil, &headerParseError{parseErr} } pn = opener.DecodePacketNumber(pn, pnLen) decrypted, err := opener.Open(data[l:l], data[l:], rcvTime, pn, kp, data[:l]) if err != nil { return 0, 0, 0, nil, err } return pn, pnLen, kp, decrypted, parseErr } func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int, protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, error) { hdrLen := 1 /* first header byte */ + u.shortHdrConnIDLen if len(data) < hdrLen+4+16 { return 0, 0, 0, 0, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", len(data)-hdrLen) } origPNBytes := make([]byte, 4) copy(origPNBytes, data[hdrLen:hdrLen+4]) // 2. decrypt the header, assuming a 4 byte packet number hd.DecryptHeader( data[hdrLen+4:hdrLen+4+16], &data[0], data[hdrLen:hdrLen+4], ) // 3. parse the header (and learn the actual length of the packet number) l, pn, pnLen, kp, parseErr := wire.ParseShortHeader(data, u.shortHdrConnIDLen) if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { return l, pn, pnLen, kp, parseErr } // 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier if pnLen != protocol.PacketNumberLen4 { copy(data[hdrLen+int(pnLen):hdrLen+4], origPNBytes[int(pnLen):]) } return l, pn, pnLen, kp, parseErr } // The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError. 
func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) { extHdr, err := unpackLongHeader(hd, hdr, data) if err != nil && err != wire.ErrInvalidReservedBits { return nil, &headerParseError{err: err} } return extHdr, err } func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) { hdrLen := hdr.ParsedLen() if protocol.ByteCount(len(data)) < hdrLen+4+16 { return nil, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", protocol.ByteCount(len(data))-hdrLen) } // The packet number can be up to 4 bytes long, but we won't know the length until we decrypt it. // 1. save a copy of the 4 bytes origPNBytes := make([]byte, 4) copy(origPNBytes, data[hdrLen:hdrLen+4]) // 2. decrypt the header, assuming a 4 byte packet number hd.DecryptHeader( data[hdrLen+4:hdrLen+4+16], &data[0], data[hdrLen:hdrLen+4], ) // 3. parse the header (and learn the actual length of the packet number) extHdr, parseErr := hdr.ParseExtended(data) if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { return nil, parseErr } // 4. 
if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier if extHdr.PacketNumberLen != protocol.PacketNumberLen4 { copy(data[extHdr.ParsedLen():hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):]) } return extHdr, parseErr } quic-go-0.59.0/packet_unpacker_test.go000066400000000000000000000336251513066070600177130ustar00rootroot00000000000000package quic import ( "crypto/rand" "testing" "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/mocks" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) type decryptResult struct { decrypted []byte err error } func TestUnpackLongHeaderPacket(t *testing.T) { b := []byte("decrypted") t.Run("Initial", func(t *testing.T) { testUnpackLongHeaderPacket(t, protocol.EncryptionInitial, false, decryptResult{decrypted: b}, nil) }) t.Run("Handshake", func(t *testing.T) { testUnpackLongHeaderPacket(t, protocol.EncryptionHandshake, false, decryptResult{decrypted: b}, nil) }) t.Run("0-RTT", func(t *testing.T) { testUnpackLongHeaderPacket(t, protocol.Encryption0RTT, false, decryptResult{decrypted: b}, nil) }) } func TestUnpackLongHeaderIncorrectReservedBits(t *testing.T) { t.Run("decryption fails", func(t *testing.T) { testUnpackLongHeaderIncorrectReservedBits(t, true) }) t.Run("decryption succeeds", func(t *testing.T) { testUnpackLongHeaderIncorrectReservedBits(t, false) }) } // Even if the reserved bits are wrong, we still need to continue processing the header. // This helps prevent a timing side-channel attack, see section 9.5 of RFC 9001. // We should only return a ErrInvalidReservedBits error if the decryption succeeds, // as this shows that the peer actually sent an invalid packet. 
// However, if decryption fails, this packet is likely injected by an attacker, // and we should treat it as any other undecryptable packet. func testUnpackLongHeaderIncorrectReservedBits(t *testing.T, decryptionSucceeds bool) { decrypted := []byte("decrypted") expectedErr := wire.ErrInvalidReservedBits decryptResult := decryptResult{decrypted: decrypted} if !decryptionSucceeds { decryptResult.err = handshake.ErrDecryptionFailed expectedErr = handshake.ErrDecryptionFailed } t.Run("Initial", func(t *testing.T) { testUnpackLongHeaderPacket(t, protocol.EncryptionInitial, true, decryptResult, expectedErr) }) t.Run("Handshake", func(t *testing.T) { testUnpackLongHeaderPacket(t, protocol.EncryptionHandshake, true, decryptResult, expectedErr) }) t.Run("0-RTT", func(t *testing.T) { testUnpackLongHeaderPacket(t, protocol.Encryption0RTT, true, decryptResult, expectedErr) }) } func TestUnpackLongHeaderEmptyPayload(t *testing.T) { expectedErr := &qerr.TransportError{ErrorCode: qerr.ProtocolViolation} t.Run("Initial", func(t *testing.T) { testUnpackLongHeaderPacket(t, protocol.EncryptionInitial, false, decryptResult{}, expectedErr) }) t.Run("Handshake", func(t *testing.T) { testUnpackLongHeaderPacket(t, protocol.EncryptionHandshake, false, decryptResult{}, expectedErr) }) t.Run("0-RTT", func(t *testing.T) { testUnpackLongHeaderPacket(t, protocol.Encryption0RTT, false, decryptResult{}, expectedErr) }) } func testUnpackLongHeaderPacket(t *testing.T, encLevel protocol.EncryptionLevel, incorrectReservedBits bool, decryptResult decryptResult, expectedErr error, ) { mockCtrl := gomock.NewController(t) cs := mocks.NewMockCryptoSetup(mockCtrl) unpacker := newPacketUnpacker(cs, 4) var packetType protocol.PacketType switch encLevel { case protocol.EncryptionInitial: packetType = protocol.PacketTypeInitial case protocol.EncryptionHandshake: packetType = protocol.PacketTypeHandshake case protocol.Encryption0RTT: packetType = protocol.PacketType0RTT } payload := []byte("Lorem ipsum dolor 
sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") extHdr := &wire.ExtendedHeader{ Header: wire.Header{ Type: packetType, Length: protocol.ByteCount(3 + len(payload)), // packet number len + payload DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), Version: protocol.Version1, }, PacketNumber: 2, PacketNumberLen: 3, } hdrRaw, err := extHdr.Append(nil, protocol.Version1) require.NoError(t, err) if incorrectReservedBits { hdrRaw[0] |= 0xc } data := append(hdrRaw, payload...) hdr, _, _, err := wire.ParsePacket(data) require.NoError(t, err) opener := mocks.NewMockLongHeaderOpener(mockCtrl) var calls []any switch encLevel { case protocol.EncryptionInitial: calls = append(calls, cs.EXPECT().GetInitialOpener().Return(opener, nil)) case protocol.EncryptionHandshake: calls = append(calls, cs.EXPECT().GetHandshakeOpener().Return(opener, nil)) case protocol.Encryption0RTT: calls = append(calls, cs.EXPECT().Get0RTTOpener().Return(opener, nil)) } calls = append(calls, []any{ opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()), opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(2), protocol.PacketNumberLen3).Return(protocol.PacketNumber(1234)), opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(1234), hdrRaw).Return( decryptResult.decrypted, decryptResult.err, ), }...) gomock.InOrder(calls...) 
packet, err := unpacker.UnpackLongHeader(hdr, data) if expectedErr != nil { require.ErrorIs(t, err, expectedErr) return } require.NoError(t, err) require.Equal(t, encLevel, packet.encryptionLevel) require.Equal(t, decryptResult.decrypted, packet.data) } func TestUnpackShortHeaderPacket(t *testing.T) { testUnpackShortHeaderPacket(t, false, decryptResult{decrypted: []byte("decrypted")}, nil) } func TestUnpackShortHeaderEmptyPayload(t *testing.T) { testUnpackShortHeaderPacket(t, false, decryptResult{}, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation}) } // Even if the reserved bits are wrong, we still need to continue processing the header. // This helps prevent a timing side-channel attack, see section 9.5 of RFC 9001. // We should only return a ErrInvalidReservedBits error if the decryption succeeds, // as this shows that the peer actually sent an invalid packet. // However, if decryption fails, this packet is likely injected by an attacker, // and we should treat it as any other undecryptable packet. 
func TestUnpackShortHeaderIncorrectReservedBits(t *testing.T) { t.Run("decryption fails", func(t *testing.T) { testUnpackShortHeaderPacket(t, true, decryptResult{err: handshake.ErrDecryptionFailed}, handshake.ErrDecryptionFailed, ) }) t.Run("decryption succeeds", func(t *testing.T) { testUnpackShortHeaderPacket(t, true, decryptResult{decrypted: []byte("decrypted")}, wire.ErrInvalidReservedBits, ) }) } func testUnpackShortHeaderPacket(t *testing.T, incorrectReservedBits bool, decryptResult decryptResult, expectedErr error) { mockCtrl := gomock.NewController(t) connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) cs := mocks.NewMockCryptoSetup(mockCtrl) unpacker := newPacketUnpacker(cs, connID.Len()) payload := []byte("Lorem ipsum dolor sit amet") hdrRaw, err := wire.AppendShortHeader( nil, connID, 0x1337, protocol.PacketNumberLen3, protocol.KeyPhaseOne, ) require.NoError(t, err) if incorrectReservedBits { hdrRaw[0] |= 0x18 } opener := mocks.NewMockShortHeaderOpener(mockCtrl) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) cs.EXPECT().Get1RTTOpener().Return(opener, nil) opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(1234)) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return( decryptResult.decrypted, decryptResult.err, ) pn, pnLen, kp, data, err := unpacker.UnpackShortHeader(monotime.Now(), append(hdrRaw, payload...)) if expectedErr != nil { require.ErrorIs(t, err, expectedErr) return } require.NoError(t, err) require.Equal(t, decryptResult.decrypted, data) require.Equal(t, protocol.PacketNumber(1234), pn) require.Equal(t, protocol.PacketNumberLen3, pnLen) require.Equal(t, protocol.KeyPhaseOne, kp) } func TestUnpackHeaderSampleLongHeader(t *testing.T) { mockCtrl := gomock.NewController(t) cs := mocks.NewMockCryptoSetup(mockCtrl) unpacker := newPacketUnpacker(cs, 4) extHdr := &wire.ExtendedHeader{ Header: wire.Header{ Type: 
protocol.PacketTypeHandshake, DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), Version: protocol.Version1, }, PacketNumber: 1337, PacketNumberLen: protocol.PacketNumberLen2, } data, err := extHdr.Append(nil, protocol.Version1) require.NoError(t, err) b := make([]byte, 2+16) // 2 bytes to fill up the packet number, 16 bytes for the sample rand.Read(b) data = append(data, b...) hdr, _, _, err := wire.ParsePacket(data) require.NoError(t, err) t.Run("too short", func(t *testing.T) { cs.EXPECT().GetHandshakeOpener().Return(mocks.NewMockLongHeaderOpener(mockCtrl), nil) _, err = unpacker.UnpackLongHeader(hdr, data[:len(data)-1]) require.IsType(t, &headerParseError{}, err) require.ErrorContains(t, err, "packet too small, expected at least 20 bytes after the header, got 19") }) t.Run("minimal size", func(t *testing.T) { opener := mocks.NewMockLongHeaderOpener(mockCtrl) cs.EXPECT().GetHandshakeOpener().Return(opener, nil) opener.EXPECT().DecryptHeader(b[len(b)-16:], gomock.Any(), gomock.Any()) opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(1337)) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("decrypted"), nil) _, err = unpacker.UnpackLongHeader(hdr, data) require.NoError(t, err) }) } func TestUnpackHeaderSampleShortHeader(t *testing.T) { mockCtrl := gomock.NewController(t) cs := mocks.NewMockCryptoSetup(mockCtrl) unpacker := newPacketUnpacker(cs, 4) data, err := wire.AppendShortHeader( nil, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne, ) require.NoError(t, err) b := make([]byte, 2+16) // 2 bytes to fill up the packet number, 16 bytes for the sample rand.Read(b) data = append(data, b...) 
t.Run("too short", func(t *testing.T) { cs.EXPECT().Get1RTTOpener().Return(mocks.NewMockShortHeaderOpener(mockCtrl), nil) _, _, _, _, err = unpacker.UnpackShortHeader(monotime.Now(), data[:len(data)-1]) require.IsType(t, &headerParseError{}, err) require.ErrorContains(t, err, "packet too small, expected at least 20 bytes after the header, got 19") }) t.Run("minimal size", func(t *testing.T) { opener := mocks.NewMockShortHeaderOpener(mockCtrl) cs.EXPECT().Get1RTTOpener().Return(opener, nil) opener.EXPECT().DecryptHeader(data[len(data)-16:], gomock.Any(), gomock.Any()) opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(1337)) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("decrypted"), nil) _, _, _, _, err = unpacker.UnpackShortHeader(monotime.Now(), data) require.NoError(t, err) }) } func TestUnpackErrors(t *testing.T) { mockCtrl := gomock.NewController(t) cs := mocks.NewMockCryptoSetup(mockCtrl) unpacker := newPacketUnpacker(cs, 4) // opener not available cs.EXPECT().GetHandshakeOpener().Return(nil, handshake.ErrKeysNotYetAvailable) _, err := unpacker.UnpackLongHeader(&wire.Header{Type: protocol.PacketTypeHandshake}, []byte("foobar")) require.ErrorIs(t, err, handshake.ErrKeysNotYetAvailable) // opener returns error opener := mocks.NewMockLongHeaderOpener(mockCtrl) cs.EXPECT().GetHandshakeOpener().Return(opener, nil) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(1234)) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, &qerr.TransportError{ErrorCode: qerr.CryptoBufferExceeded}) _, err = unpacker.UnpackLongHeader(&wire.Header{Type: protocol.PacketTypeHandshake}, make([]byte, 100)) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.CryptoBufferExceeded}) } func TestUnpackHeaderDecryption(t 
*testing.T) { mockCtrl := gomock.NewController(t) cs := mocks.NewMockCryptoSetup(mockCtrl) unpacker := newPacketUnpacker(cs, 4) connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) extHdr := &wire.ExtendedHeader{ Header: wire.Header{ Type: protocol.PacketTypeHandshake, Length: 2, // packet number len DestConnectionID: connID, Version: protocol.Version1, }, PacketNumber: 0x1337, PacketNumberLen: protocol.PacketNumberLen2, } hdrRaw, err := extHdr.Append(nil, protocol.Version1) require.NoError(t, err) hdr, _, _, err := wire.ParsePacket(hdrRaw) require.NoError(t, err) origHdrRaw := append([]byte{}, hdrRaw...) // save a copy of the header firstHdrByte := hdrRaw[0] hdrRaw[0] ^= 0xff // invert the first byte hdrRaw[len(hdrRaw)-2] ^= 0xff // invert the packet number hdrRaw[len(hdrRaw)-1] ^= 0xff // invert the packet number require.NotEqual(t, hdrRaw[0], firstHdrByte) opener := mocks.NewMockLongHeaderOpener(mockCtrl) cs.EXPECT().GetHandshakeOpener().Return(opener, nil) gomock.InOrder( // we're using a 2 byte packet number, so the sample starts at the 3rd payload byte opener.EXPECT().DecryptHeader( []byte{3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, &hdrRaw[0], append(hdrRaw[len(hdrRaw)-2:], []byte{1, 2}...)).Do(func(_ []byte, firstByte *byte, pnBytes []byte) { *firstByte ^= 0xff // invert the first byte back for i := range pnBytes { pnBytes[i] ^= 0xff // invert the packet number bytes } }), opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2).Return(protocol.PacketNumber(0x7331)), opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x7331), origHdrRaw).Return([]byte{0}, nil), ) data := hdrRaw for i := 1; i <= 100; i++ { data = append(data, uint8(i)) } packet, err := unpacker.UnpackLongHeader(hdr, data) require.NoError(t, err) require.Equal(t, protocol.PacketNumber(0x7331), packet.hdr.PacketNumber) } 
quic-go-0.59.0/path_manager.go000066400000000000000000000130241513066070600161320ustar00rootroot00000000000000package quic import ( "crypto/rand" "net" "slices" "time" "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" ) type pathID int64 const invalidPathID pathID = -1 // Maximum number of paths to keep track of. // If the peer probes another path (before the pathTimeout of an existing path expires), // this probing attempt is ignored. const maxPaths = 3 // If no packet is received for a path for pathTimeout, // the path can be evicted when the peer probes another path. // This prevents an attacker from churning through paths by duplicating packets and // sending them with spoofed source addresses. const pathTimeout = 5 * time.Second type path struct { id pathID addr net.Addr lastPacketTime monotime.Time pathChallenge [8]byte validated bool rcvdNonProbing bool } type pathManager struct { nextPathID pathID // ordered by lastPacketTime, with the most recently used path at the end paths []*path getConnID func(pathID) (_ protocol.ConnectionID, ok bool) retireConnID func(pathID) logger utils.Logger } func newPathManager( getConnID func(pathID) (_ protocol.ConnectionID, ok bool), retireConnID func(pathID), logger utils.Logger, ) *pathManager { return &pathManager{ paths: make([]*path, 0, maxPaths+1), getConnID: getConnID, retireConnID: retireConnID, logger: logger, } } // Returns a path challenge frame if one should be sent. // May return nil. 
func (pm *pathManager) HandlePacket( remoteAddr net.Addr, t monotime.Time, pathChallenge *wire.PathChallengeFrame, // may be nil if the packet didn't contain a PATH_CHALLENGE isNonProbing bool, ) (_ protocol.ConnectionID, _ []ackhandler.Frame, shouldSwitch bool) { var p *path for i, path := range pm.paths { if addrsEqual(path.addr, remoteAddr) { p = path p.lastPacketTime = t // already sent a PATH_CHALLENGE for this path if isNonProbing { path.rcvdNonProbing = true } if pm.logger.Debug() { pm.logger.Debugf("received packet for path %s that was already probed, validated: %t", remoteAddr, path.validated) } shouldSwitch = path.validated && path.rcvdNonProbing if i != len(pm.paths)-1 { // move the path to the end of the list pm.paths = slices.Delete(pm.paths, i, i+1) pm.paths = append(pm.paths, p) } if pathChallenge == nil { return protocol.ConnectionID{}, nil, shouldSwitch } } } if len(pm.paths) >= maxPaths { if pm.paths[0].lastPacketTime.Add(pathTimeout).After(t) { if pm.logger.Debug() { pm.logger.Debugf("received packet for previously unseen path %s, but already have %d paths", remoteAddr, len(pm.paths)) } return protocol.ConnectionID{}, nil, shouldSwitch } // evict the oldest path, if the last packet was received more than pathTimeout ago pm.retireConnID(pm.paths[0].id) pm.paths = pm.paths[1:] } var pathID pathID if p != nil { pathID = p.id } else { pathID = pm.nextPathID } // previously unseen path, initiate path validation by sending a PATH_CHALLENGE connID, ok := pm.getConnID(pathID) if !ok { pm.logger.Debugf("skipping validation of new path %s since no connection ID is available", remoteAddr) return protocol.ConnectionID{}, nil, shouldSwitch } frames := make([]ackhandler.Frame, 0, 2) if p == nil { var pathChallengeData [8]byte rand.Read(pathChallengeData[:]) p = &path{ id: pm.nextPathID, addr: remoteAddr, lastPacketTime: t, rcvdNonProbing: isNonProbing, pathChallenge: pathChallengeData, } pm.nextPathID++ pm.paths = append(pm.paths, p) frames = append(frames, 
ackhandler.Frame{ Frame: &wire.PathChallengeFrame{Data: p.pathChallenge}, Handler: (*pathManagerAckHandler)(pm), }) pm.logger.Debugf("enqueueing PATH_CHALLENGE for new path %s", remoteAddr) } if pathChallenge != nil { frames = append(frames, ackhandler.Frame{ Frame: &wire.PathResponseFrame{Data: pathChallenge.Data}, Handler: (*pathManagerAckHandler)(pm), }) } return connID, frames, shouldSwitch } func (pm *pathManager) HandlePathResponseFrame(f *wire.PathResponseFrame) { for _, p := range pm.paths { if f.Data == p.pathChallenge { // path validated p.validated = true pm.logger.Debugf("path %s validated", p.addr) break } } } // SwitchToPath is called when the connection switches to a new path func (pm *pathManager) SwitchToPath(addr net.Addr) { // retire all other paths for _, path := range pm.paths { if addrsEqual(path.addr, addr) { pm.logger.Debugf("switching to path %d (%s)", path.id, addr) continue } pm.retireConnID(path.id) } clear(pm.paths) pm.paths = pm.paths[:0] } type pathManagerAckHandler pathManager var _ ackhandler.FrameHandler = &pathManagerAckHandler{} // Acknowledging the frame doesn't validate the path, only receiving the PATH_RESPONSE does. 
func (pm *pathManagerAckHandler) OnAcked(f wire.Frame) {} func (pm *pathManagerAckHandler) OnLost(f wire.Frame) { pc, ok := f.(*wire.PathChallengeFrame) if !ok { return } for i, path := range pm.paths { if path.pathChallenge == pc.Data { pm.paths = slices.Delete(pm.paths, i, i+1) pm.retireConnID(path.id) break } } } func addrsEqual(addr1, addr2 net.Addr) bool { if addr1 == nil || addr2 == nil { return false } a1, ok1 := addr1.(*net.UDPAddr) a2, ok2 := addr2.(*net.UDPAddr) if ok1 && ok2 { return a1.IP.Equal(a2.IP) && a1.Port == a2.Port } return addr1.String() == addr2.String() } quic-go-0.59.0/path_manager_outgoing.go000066400000000000000000000166071513066070600200570ustar00rootroot00000000000000package quic import ( "context" "crypto/rand" "errors" "slices" "sync" "sync/atomic" "time" "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" ) var ( // ErrPathClosed is returned when trying to switch to a path that has been closed. ErrPathClosed = errors.New("path closed") // ErrPathNotValidated is returned when trying to use a path before path probing has completed. ErrPathNotValidated = errors.New("path not yet validated") ) var errPathDoesNotExist = errors.New("path does not exist") // Path is a network path. 
type Path struct { id pathID pathManager *pathManagerOutgoing tr *Transport initialRTT time.Duration enablePath func() validated atomic.Bool abandon chan struct{} } func (p *Path) Probe(ctx context.Context) error { path := p.pathManager.addPath(p, p.enablePath) p.pathManager.enqueueProbe(p) nextProbeDur := p.initialRTT var timer *time.Timer var timerChan <-chan time.Time for { select { case <-ctx.Done(): return context.Cause(ctx) case <-path.Validated(): p.validated.Store(true) return nil case <-timerChan: nextProbeDur *= 2 // exponential backoff p.pathManager.enqueueProbe(p) case <-path.ProbeSent(): case <-p.abandon: return ErrPathClosed } if timer != nil { timer.Stop() } timer = time.NewTimer(nextProbeDur) timerChan = timer.C } } // Switch switches the QUIC connection to this path. // It immediately stops sending on the old path, and sends on this new path. func (p *Path) Switch() error { if err := p.pathManager.switchToPath(p.id); err != nil { switch { case errors.Is(err, ErrPathNotValidated): return err case errors.Is(err, errPathDoesNotExist) && !p.validated.Load(): select { case <-p.abandon: return ErrPathClosed default: return ErrPathNotValidated } default: return ErrPathClosed } } return nil } // Close abandons a path. // It is not possible to close the path that’s currently active. // After closing, it is not possible to probe this path again. 
func (p *Path) Close() error { select { case <-p.abandon: return nil default: } if err := p.pathManager.removePath(p.id); err != nil { return err } close(p.abandon) return nil } type pathOutgoing struct { pathChallenges [][8]byte // length is implicitly limited by exponential backoff tr *Transport isValidated bool probeSent chan struct{} // receives when a PATH_CHALLENGE is sent validated chan struct{} // closed when the path the corresponding PATH_RESPONSE is received enablePath func() } func (p *pathOutgoing) ProbeSent() <-chan struct{} { return p.probeSent } func (p *pathOutgoing) Validated() <-chan struct{} { return p.validated } type pathManagerOutgoing struct { getConnID func(pathID) (_ protocol.ConnectionID, ok bool) retireConnID func(pathID) scheduleSending func() mx sync.Mutex activePath pathID pathsToProbe []pathID paths map[pathID]*pathOutgoing nextPathID pathID pathToSwitchTo *pathOutgoing } // newPathManagerOutgoing creates a new pathManagerOutgoing object. This // function must be side-effect free as it may be called multiple times for a // single connection. 
func newPathManagerOutgoing( getConnID func(pathID) (_ protocol.ConnectionID, ok bool), retireConnID func(pathID), scheduleSending func(), ) *pathManagerOutgoing { return &pathManagerOutgoing{ activePath: 0, // at initialization time, we're guaranteed to be using the handshake path nextPathID: 1, getConnID: getConnID, retireConnID: retireConnID, scheduleSending: scheduleSending, paths: make(map[pathID]*pathOutgoing, 4), } } func (pm *pathManagerOutgoing) addPath(p *Path, enablePath func()) *pathOutgoing { pm.mx.Lock() defer pm.mx.Unlock() // path might already exist, and just being re-probed if existingPath, ok := pm.paths[p.id]; ok { existingPath.validated = make(chan struct{}) return existingPath } path := &pathOutgoing{ tr: p.tr, probeSent: make(chan struct{}, 1), validated: make(chan struct{}), enablePath: enablePath, } pm.paths[p.id] = path return path } func (pm *pathManagerOutgoing) enqueueProbe(p *Path) { pm.mx.Lock() pm.pathsToProbe = append(pm.pathsToProbe, p.id) pm.mx.Unlock() pm.scheduleSending() } func (pm *pathManagerOutgoing) removePath(id pathID) error { if err := pm.removePathImpl(id); err != nil { return err } pm.scheduleSending() return nil } func (pm *pathManagerOutgoing) removePathImpl(id pathID) error { pm.mx.Lock() defer pm.mx.Unlock() if id == pm.activePath { return errors.New("cannot close active path") } p, ok := pm.paths[id] if !ok { return nil } if len(p.pathChallenges) > 0 { pm.retireConnID(id) } delete(pm.paths, id) return nil } func (pm *pathManagerOutgoing) switchToPath(id pathID) error { pm.mx.Lock() defer pm.mx.Unlock() p, ok := pm.paths[id] if !ok { return errPathDoesNotExist } if !p.isValidated { return ErrPathNotValidated } pm.pathToSwitchTo = p pm.activePath = id return nil } func (pm *pathManagerOutgoing) NewPath(t *Transport, initialRTT time.Duration, enablePath func()) *Path { pm.mx.Lock() defer pm.mx.Unlock() id := pm.nextPathID pm.nextPathID++ return &Path{ pathManager: pm, id: id, tr: t, enablePath: enablePath, 
initialRTT: initialRTT, abandon: make(chan struct{}), } } func (pm *pathManagerOutgoing) NextPathToProbe() (_ protocol.ConnectionID, _ ackhandler.Frame, _ *Transport, hasPath bool) { pm.mx.Lock() defer pm.mx.Unlock() var p *pathOutgoing id := invalidPathID for _, pID := range pm.pathsToProbe { var ok bool p, ok = pm.paths[pID] if ok { id = pID break } // if the path doesn't exist in the map, it might have been abandoned pm.pathsToProbe = pm.pathsToProbe[1:] } if id == invalidPathID { return protocol.ConnectionID{}, ackhandler.Frame{}, nil, false } connID, ok := pm.getConnID(id) if !ok { return protocol.ConnectionID{}, ackhandler.Frame{}, nil, false } var b [8]byte _, _ = rand.Read(b[:]) p.pathChallenges = append(p.pathChallenges, b) pm.pathsToProbe = pm.pathsToProbe[1:] p.enablePath() select { case p.probeSent <- struct{}{}: default: } frame := ackhandler.Frame{ Frame: &wire.PathChallengeFrame{Data: b}, Handler: (*pathManagerOutgoingAckHandler)(pm), } return connID, frame, p.tr, true } func (pm *pathManagerOutgoing) HandlePathResponseFrame(f *wire.PathResponseFrame) { pm.mx.Lock() defer pm.mx.Unlock() for _, p := range pm.paths { if slices.Contains(p.pathChallenges, f.Data) { // path validated if !p.isValidated { // make sure that duplicate PATH_RESPONSE frames are ignored p.isValidated = true p.pathChallenges = nil close(p.validated) } break } } } func (pm *pathManagerOutgoing) ShouldSwitchPath() (*Transport, bool) { pm.mx.Lock() defer pm.mx.Unlock() if pm.pathToSwitchTo == nil { return nil, false } p := pm.pathToSwitchTo pm.pathToSwitchTo = nil return p.tr, true } type pathManagerOutgoingAckHandler pathManagerOutgoing var _ ackhandler.FrameHandler = &pathManagerOutgoingAckHandler{} // OnAcked is called when the PATH_CHALLENGE is acked. // This doesn't validate the path, only receiving the PATH_RESPONSE does. 
func (pm *pathManagerOutgoingAckHandler) OnAcked(wire.Frame) {} func (pm *pathManagerOutgoingAckHandler) OnLost(wire.Frame) {} quic-go-0.59.0/path_manager_outgoing_test.go000066400000000000000000000170231513066070600211070ustar00rootroot00000000000000package quic import ( "context" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/require" ) func TestPathManagerOutgoingPathProbing(t *testing.T) { synctest.Test(t, func(t *testing.T) { connIDs := []protocol.ConnectionID{ protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), } pm := newPathManagerOutgoing( func(id pathID) (protocol.ConnectionID, bool) { connID := connIDs[0] connIDs = connIDs[1:] return connID, true }, func(id pathID) { t.Fatal("didn't expect any connection ID to be retired") }, func() {}, ) _, _, _, ok := pm.NextPathToProbe() require.False(t, ok) tr1 := &Transport{} var enabled bool p := pm.NewPath(tr1, time.Second, func() { enabled = true }) require.ErrorIs(t, p.Switch(), ErrPathNotValidated) errChan := make(chan error, 1) go func() { errChan <- p.Probe(context.Background()) }() // wait for the path to be queued for probing synctest.Wait() require.False(t, enabled) connID, f, tr, ok := pm.NextPathToProbe() require.True(t, ok) require.Equal(t, tr1, tr) require.Equal(t, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), connID) require.IsType(t, &wire.PathChallengeFrame{}, f.Frame) pc := f.Frame.(*wire.PathChallengeFrame) require.True(t, enabled) _, _, _, ok = pm.NextPathToProbe() require.False(t, ok) select { case <-errChan: t.Fatal("should still be probing") default: } // acking the frame doesn't complete path validation... f.Handler.OnAcked(f.Frame) select { case <-errChan: t.Fatal("should still be probing") default: } require.ErrorIs(t, p.Switch(), ErrPathNotValidated) _, ok = pm.ShouldSwitchPath() require.False(t, ok) // ... 
neither does receiving a random PATH_RESPONSE... pm.HandlePathResponseFrame(&wire.PathResponseFrame{Data: [8]byte{'f', 'o', 'o', 'f', 'o', 'o'}}) f.Handler.OnAcked(f.Frame) // doesn't do anything f.Handler.OnLost(f.Frame) // doesn't do anything select { case <-errChan: t.Fatal("should still be probing") default: } // ... only receiving the corresponding PATH_RESPONSE does pm.HandlePathResponseFrame(&wire.PathResponseFrame{Data: pc.Data}) synctest.Wait() select { case err := <-errChan: require.NoError(t, err) default: t.Fatal("timeout") } // receiving it multiple times is ok pm.HandlePathResponseFrame(&wire.PathResponseFrame{Data: pc.Data}) // now switch to the other path _, ok = pm.ShouldSwitchPath() require.False(t, ok) require.NoError(t, p.Switch()) // the active path can't be closed require.EqualError(t, p.Close(), "cannot close active path") switchToTransport, ok := pm.ShouldSwitchPath() require.True(t, ok) require.Equal(t, tr1, switchToTransport) }) } func TestPathManagerOutgoingRetransmissions(t *testing.T) { synctest.Test(t, func(t *testing.T) { connIDs := []protocol.ConnectionID{ protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), protocol.ParseConnectionID([]byte{2, 3, 4, 5, 6, 7, 8, 9}), } var retiredConnIDs []protocol.ConnectionID scheduledSending := make(chan struct{}, 20) pm := newPathManagerOutgoing( func(id pathID) (protocol.ConnectionID, bool) { return connIDs[id], true }, func(id pathID) { retiredConnIDs = append(retiredConnIDs, connIDs[id]) }, func() { scheduledSending <- struct{}{} }, ) _, _, _, ok := pm.NextPathToProbe() require.False(t, ok) tr1 := &Transport{} const initialRTT = 5 * time.Millisecond p := pm.NewPath(tr1, initialRTT, func() {}) pathChallengeChan := make(chan [8]byte) done := make(chan struct{}) defer close(done) go func() { for { select { case <-scheduledSending: case <-done: return } _, f, _, ok := pm.NextPathToProbe() if !ok { // should never happen pathChallengeChan <- [8]byte{} continue } pathChallengeChan <- 
f.Frame.(*wire.PathChallengeFrame).Data } }() errChan := make(chan error, 1) go func() { errChan <- p.Probe(context.Background()) }() start := time.Now() type result struct { pc *[8]byte took time.Duration } var results []result for range 4 { select { case <-errChan: t.Fatal("probing should not have completed") case pc := <-pathChallengeChan: results = append(results, result{pc: &pc, took: time.Since(start)}) case <-time.After(time.Second): t.Fatal("timeout") } } for i, r1 := range results { require.NotNil(t, r1.pc) if i > 0 { took := r1.took - results[i-1].took t.Log("took", took) require.Equal(t, took, initialRTT<<(i-1)) } for j, r2 := range results { if i == j { continue } require.NotEqual(t, r1.pc, r2.pc) } } // receiving a PATH_RESPONSE for any of the PATH_CHALLENGES completes path validation pm.HandlePathResponseFrame(&wire.PathResponseFrame{Data: *results[2].pc}) synctest.Wait() select { case err := <-errChan: require.NoError(t, err) default: t.Fatal("probing should have completed") } // It is valid to probe again results = results[:0] ctx, cancel := context.WithCancel(context.Background()) go func() { errChan <- p.Probe(ctx) }() synctest.Wait() for range 2 { select { case err := <-errChan: require.NoError(t, err) case pc := <-pathChallengeChan: results = append(results, result{pc: &pc, took: time.Since(start)}) case <-time.After(time.Second): t.Fatal("should have received a path challenge") } } // this time, don't receive a PATH_RESPONSE cancel() synctest.Wait() select { case err := <-errChan: require.ErrorIs(t, err, context.Canceled) default: t.Fatal("should have received a context canceled error") } }) } func TestPathManagerOutgoingAbandonPath(t *testing.T) { synctest.Test(t, func(t *testing.T) { connIDs := []protocol.ConnectionID{ protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), } var retiredPaths []pathID pm := newPathManagerOutgoing( func(id pathID) (protocol.ConnectionID, bool) { connID := connIDs[0] connIDs = connIDs[1:] return connID, 
true }, func(id pathID) { retiredPaths = append(retiredPaths, id) }, func() {}, ) // path abandoned before the PATH_CHALLENGE is sent out p1 := pm.NewPath(&Transport{}, time.Second, func() {}) errChan := make(chan error, 1) go func() { errChan <- p1.Probe(context.Background()) }() // wait for the path to be queued for probing synctest.Wait() require.NoError(t, p1.Close()) // closing the path multiple times is ok require.NoError(t, p1.Close()) require.NoError(t, p1.Close()) _, _, _, ok := pm.NextPathToProbe() require.False(t, ok) synctest.Wait() select { case err := <-errChan: require.ErrorIs(t, err, ErrPathClosed) default: t.Fatal("should have received a path closed error") } require.Empty(t, retiredPaths) p2 := pm.NewPath(&Transport{}, time.Second, func() {}) go func() { errChan <- p2.Probe(context.Background()) }() // wait for the path to be queued for probing synctest.Wait() connID, f, _, ok := pm.NextPathToProbe() require.True(t, ok) require.Equal(t, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), connID) require.NoError(t, p2.Close()) require.Equal(t, []pathID{p2.id}, retiredPaths) pm.HandlePathResponseFrame(&wire.PathResponseFrame{Data: f.Frame.(*wire.PathChallengeFrame).Data}) _, _, _, ok = pm.NextPathToProbe() require.False(t, ok) // it's not possible to switch to an abandoned path require.ErrorIs(t, p2.Switch(), ErrPathClosed) }) } quic-go-0.59.0/path_manager_test.go000066400000000000000000000275341513066070600172040ustar00rootroot00000000000000package quic import ( "crypto/rand" "net" "testing" "time" "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/require" ) // The path is established by receiving a non-probing packet. // The first non-probing packet is received after path validation has completed. 
// This is the typical scenario when the client initiates connection migration. func TestPathManagerIntentionalMigration(t *testing.T) { connIDs := []protocol.ConnectionID{ protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), protocol.ParseConnectionID([]byte{2, 3, 4, 5, 6, 7, 8, 9}), protocol.ParseConnectionID([]byte{3, 4, 5, 6, 7, 8, 9, 0}), } var retiredConnIDs []protocol.ConnectionID pm := newPathManager( func(id pathID) (protocol.ConnectionID, bool) { return connIDs[id], true }, func(id pathID) { retiredConnIDs = append(retiredConnIDs, connIDs[id]) }, utils.DefaultLogger, ) now := monotime.Now() connID, frames, shouldSwitch := pm.HandlePacket( &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}, now, &wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, false, ) require.Equal(t, connIDs[0], connID) require.Len(t, frames, 2) require.IsType(t, &wire.PathChallengeFrame{}, frames[0].Frame) pc1 := frames[0].Frame.(*wire.PathChallengeFrame) require.NotZero(t, pc1.Data) require.NotEqual(t, [8]byte{1, 2, 3, 4, 5, 6, 7, 8}, pc1.Data) require.IsType(t, &wire.PathResponseFrame{}, frames[1].Frame) require.Equal(t, [8]byte{1, 2, 3, 4, 5, 6, 7, 8}, frames[1].Frame.(*wire.PathResponseFrame).Data) require.False(t, shouldSwitch) // receiving another packet for the same path doesn't trigger another PATH_CHALLENGE connID, frames, shouldSwitch = pm.HandlePacket( &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}, now, nil, false, ) require.Zero(t, connID) require.Len(t, frames, 0) require.False(t, shouldSwitch) // receiving a packet for a different path triggers another PATH_CHALLENGE addr2 := &net.UDPAddr{IP: net.IPv4(5, 6, 7, 8), Port: 1000} connID, frames, shouldSwitch = pm.HandlePacket(addr2, now, nil, false) require.Equal(t, connIDs[1], connID) require.Len(t, frames, 1) require.IsType(t, &wire.PathChallengeFrame{}, frames[0].Frame) pc2 := frames[0].Frame.(*wire.PathChallengeFrame) require.NotEqual(t, pc1.Data, pc2.Data) require.False(t, shouldSwitch) // 
acknowledging the PATH_CHALLENGE doesn't confirm the path for _, f := range frames { f.Handler.OnAcked(f.Frame) } connID, frames, shouldSwitch = pm.HandlePacket( &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}, now, nil, false, ) require.Zero(t, connID) require.Empty(t, frames) require.False(t, shouldSwitch) // receiving a PATH_RESPONSE for the second path confirms the path pm.HandlePathResponseFrame(&wire.PathResponseFrame{Data: pc2.Data}) connID, frames, shouldSwitch = pm.HandlePacket(addr2, now, nil, false) require.Zero(t, connID) require.Empty(t, frames) require.False(t, shouldSwitch) // no non-probing packet received yet require.Empty(t, retiredConnIDs) // confirming the path doesn't remove other paths connID, frames, shouldSwitch = pm.HandlePacket( &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}, now, nil, false, ) require.Zero(t, connID) require.Empty(t, frames) require.False(t, shouldSwitch) // now receive a non-probing packet for the new path connID, frames, shouldSwitch = pm.HandlePacket( &net.UDPAddr{IP: net.IPv4(5, 6, 7, 8), Port: 1000}, now, nil, true, ) require.Zero(t, connID) require.Empty(t, frames) require.True(t, shouldSwitch) // now switch to the new path pm.SwitchToPath(&net.UDPAddr{IP: net.IPv4(5, 6, 7, 8), Port: 1000}) // switching to the path removes other paths connID, frames, shouldSwitch = pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}, now, nil, false) require.Equal(t, connIDs[2], connID) require.NotEmpty(t, frames) require.NotEqual(t, frames[0].Frame.(*wire.PathChallengeFrame).Data, pc1.Data) require.False(t, shouldSwitch) require.Equal(t, []protocol.ConnectionID{connIDs[0]}, retiredConnIDs) } func TestPathManagerMultipleProbes(t *testing.T) { connIDs := []protocol.ConnectionID{ protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), } pm := newPathManager( func(id pathID) (protocol.ConnectionID, bool) { return connIDs[id], true }, func(id pathID) {}, utils.DefaultLogger, ) now := monotime.Now() // first 
receive a packet without a PATH_CHALLENGE connID, frames, shouldSwitch := pm.HandlePacket( &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}, now, nil, false, ) require.Equal(t, connIDs[0], connID) require.Len(t, frames, 1) require.IsType(t, &wire.PathChallengeFrame{}, frames[0].Frame) require.False(t, shouldSwitch) // now receive a packet on the same path with a PATH_CHALLENGE connID, frames, shouldSwitch = pm.HandlePacket( &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}, now, &wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, false, ) require.Equal(t, connIDs[0], connID) require.Len(t, frames, 1) require.Equal(t, &wire.PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, frames[0].Frame) require.False(t, shouldSwitch) // now receive another packet on the same path with a PATH_RESPONSE connID, frames, shouldSwitch = pm.HandlePacket( &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}, now, &wire.PathChallengeFrame{Data: [8]byte{8, 7, 6, 5, 4, 3, 2, 1}}, false, ) require.Equal(t, connIDs[0], connID) require.Len(t, frames, 1) require.Equal(t, &wire.PathResponseFrame{Data: [8]byte{8, 7, 6, 5, 4, 3, 2, 1}}, frames[0].Frame) require.False(t, shouldSwitch) // lose the response packet frames[0].Handler.OnLost(frames[0].Frame) } // The first packet received on the new path is already a non-probing packet. // We still need to validate the new path, but we can then switch over immediately. // This is the typical scenario when a NAT rebinding happens. 
func TestPathManagerNATRebinding(t *testing.T) { connIDs := []protocol.ConnectionID{ protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), } var retiredConnIDs []protocol.ConnectionID pm := newPathManager( func(id pathID) (protocol.ConnectionID, bool) { return connIDs[id], true }, func(id pathID) { retiredConnIDs = append(retiredConnIDs, connIDs[id]) }, utils.DefaultLogger, ) now := monotime.Now() connID, frames, shouldSwitch := pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}, now, nil, true) require.Equal(t, connIDs[0], connID) require.Len(t, frames, 1) require.IsType(t, &wire.PathChallengeFrame{}, frames[0].Frame) pc1 := frames[0].Frame.(*wire.PathChallengeFrame) require.NotZero(t, pc1.Data) require.False(t, shouldSwitch) // receiving a PATH_RESPONSE for the second path confirms the path pm.HandlePathResponseFrame(&wire.PathResponseFrame{Data: pc1.Data}) // we now switch to the new path, as soon as the next packet on that path is received connID, frames, shouldSwitch = pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}, now, nil, false) require.Zero(t, connID) require.Empty(t, frames) require.True(t, shouldSwitch) } func TestPathManagerLimits(t *testing.T) { var connIDs []protocol.ConnectionID for range 2*maxPaths + 2 { b := make([]byte, 8) rand.Read(b) connIDs = append(connIDs, protocol.ParseConnectionID(b)) } var retiredConnIDs []protocol.ConnectionID pm := newPathManager( func(id pathID) (protocol.ConnectionID, bool) { return connIDs[id], true }, func(id pathID) { retiredConnIDs = append(retiredConnIDs, connIDs[id]) }, utils.DefaultLogger, ) now := monotime.Now() firstPathTime := now var firstPathConnID protocol.ConnectionID require.Greater(t, pathTimeout, maxPaths*time.Second) for i := range maxPaths { connID, frames, _ := pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000 + i}, now, nil, true) require.NotEmpty(t, frames) require.Equal(t, connIDs[i], connID) if i == 0 { firstPathConnID = connID } now = 
now.Add(time.Second) } // the maximum number of paths is already being probed now = firstPathTime.Add(pathTimeout).Add(-time.Nanosecond) connID, frames, _ := pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 2000}, now, nil, true) require.Zero(t, connID) require.Empty(t, frames) // receiving another packet after the pathTimeout of the first path evicts the first path now = firstPathTime.Add(pathTimeout) connIDIndex := maxPaths connID, frames, _ = pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000 + maxPaths}, now, nil, true) require.NotEmpty(t, frames) require.Equal(t, connIDs[connIDIndex], connID) require.Equal(t, []protocol.ConnectionID{firstPathConnID}, retiredConnIDs) connIDIndex++ // switching to a new path frees is up all paths var f1 []ackhandler.Frame pm.SwitchToPath(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1000}) for i := range maxPaths { connID, frames, _ := pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 3000 + i}, now, nil, true) if i == 0 { f1 = frames } require.NotEmpty(t, frames) require.Equal(t, connIDs[connIDIndex], connID) connIDIndex++ } // again, the maximum number of paths is already being probed connID, frames, _ = pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 2000}, now, nil, true) require.Zero(t, connID) require.Empty(t, frames) // losing the frame removes this path f1[0].Handler.OnLost(f1[0].Frame) // we can open exactly one more path connID, frames, _ = pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4000}, now, nil, true) require.NotEmpty(t, frames) require.Equal(t, connIDs[connIDIndex], connID) connID, frames, _ = pm.HandlePacket(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4001}, now, nil, true) require.Zero(t, connID) require.Empty(t, frames) } type mockAddr struct { str string } func (a *mockAddr) Network() string { return "mock" } func (a *mockAddr) String() string { return a.str } func TestAddrsEqual(t *testing.T) { tests := []struct { name string addr1 net.Addr 
addr2 net.Addr expected bool }{ { name: "nil addresses", addr1: nil, addr2: nil, expected: false, }, { name: "one nil address", addr1: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}, addr2: nil, expected: false, }, { name: "same IPv4 addresses", addr1: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}, addr2: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}, expected: true, }, { name: "different IPv4 addresses", addr1: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}, addr2: &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1), Port: 1234}, expected: false, }, { name: "different ports", addr1: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}, addr2: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}, expected: false, }, { name: "same IPv6 addresses", addr1: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234}, addr2: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234}, expected: true, }, { name: "different IPv6 addresses", addr1: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234}, addr2: &net.UDPAddr{IP: net.ParseIP("2001:db8::2"), Port: 1234}, expected: false, }, { name: "non-UDP addresses with same string representation", addr1: &mockAddr{str: "192.0.2.1:1234"}, addr2: &mockAddr{str: "192.0.2.1:1234"}, expected: true, }, { name: "non-UDP addresses with different string representation", addr1: &mockAddr{str: "192.0.2.1:1234"}, addr2: &mockAddr{str: "192.0.2.2:1234"}, expected: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := addrsEqual(tt.addr1, tt.addr2) require.Equal(t, tt.expected, result) }) } } quic-go-0.59.0/qlog/000077500000000000000000000000001513066070600141175ustar00rootroot00000000000000quic-go-0.59.0/qlog/benchmark_test.go000066400000000000000000000047471513066070600174530ustar00rootroot00000000000000package qlog import ( "io" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" 
"github.com/quic-go/quic-go/qlogwriter" ) type nopWriteCloserImpl struct{ io.Writer } func (nopWriteCloserImpl) Close() error { return nil } func nopWriteCloser(w io.Writer) io.WriteCloser { return &nopWriteCloserImpl{Writer: w} } // BenchmarkConnectionTracing aims to benchmark a somewhat realistic connection that sends and receives packets. func BenchmarkConnectionTracing(b *testing.B) { b.ReportAllocs() srcConnID := protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}) trace := qlogwriter.NewConnectionFileSeq( nopWriteCloser(io.Discard), false, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), []string{EventSchema}, ) go trace.Run() tracer := trace.AddProducer() b.Cleanup(func() { tracer.Close() }) rttStats := utils.NewRTTStats() rttStats.UpdateRTT(1337*time.Millisecond, 0) rttStats.UpdateRTT(1000*time.Millisecond, 10*time.Millisecond) rttStats.UpdateRTT(800*time.Millisecond, 100*time.Millisecond) var i int for b.Loop() { i++ tracer.RecordEvent(&PacketSent{ Header: PacketHeader{ PacketType: PacketType1RTT, PacketNumber: 1234 + protocol.PacketNumber(i), KeyPhaseBit: KeyPhaseZero, DestConnectionID: srcConnID, }, Raw: RawInfo{Length: 1337}, ECN: ECT0, Frames: []Frame{ {Frame: &AckFrame{AckRanges: []wire.AckRange{{Largest: 12345 + protocol.PacketNumber(2*i), Smallest: 1234 + protocol.PacketNumber(i)}}}}, {Frame: &MaxStreamDataFrame{StreamID: 42, MaximumStreamData: 987 + protocol.ByteCount(i)}}, }, }) tracer.RecordEvent(&MetricsUpdated{ MinRTT: rttStats.MinRTT(), SmoothedRTT: rttStats.SmoothedRTT(), LatestRTT: rttStats.LatestRTT(), RTTVariance: rttStats.MeanDeviation(), CongestionWindow: int(12345 + protocol.ByteCount(2*i)), BytesInFlight: int(12345 + protocol.ByteCount(i)), PacketsInFlight: i, }) if i%2 == 0 { tracer.RecordEvent(&PacketReceived{ Header: PacketHeader{ PacketType: PacketType1RTT, PacketNumber: 1337 + protocol.PacketNumber(i), KeyPhaseBit: KeyPhaseOne, DestConnectionID: srcConnID, }, Raw: RawInfo{Length: 1337}, ECN: ECT0, Frames: []Frame{ 
{Frame: &StreamFrame{StreamID: 123, Offset: int64(1234 + protocol.ByteCount(100*i)), Length: 100, Fin: true}}, }, }) } } } quic-go-0.59.0/qlog/event.go000066400000000000000000000631641513066070600156010ustar00rootroot00000000000000package qlog import ( "fmt" "net/netip" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/qlogwriter/jsontext" ) func milliseconds(dur time.Duration) float64 { return float64(dur.Nanoseconds()) / 1e6 } type encoderHelper struct { enc *jsontext.Encoder err error } func (h *encoderHelper) WriteToken(t jsontext.Token) { if h.err != nil { return } h.err = h.enc.WriteToken(t) } type versions []Version func (v versions) encode(enc *jsontext.Encoder) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginArray) for _, e := range v { h.WriteToken(jsontext.String(fmt.Sprintf("%x", uint32(e)))) } h.WriteToken(jsontext.EndArray) return h.err } type RawInfo struct { Length int // full packet length, including header and AEAD authentication tag PayloadLength int // length of the packet payload, excluding AEAD tag } func (i RawInfo) encode(enc *jsontext.Encoder) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("length")) h.WriteToken(jsontext.Uint(uint64(i.Length))) if i.PayloadLength != 0 { h.WriteToken(jsontext.String("payload_length")) h.WriteToken(jsontext.Uint(uint64(i.PayloadLength))) } h.WriteToken(jsontext.EndObject) return h.err } type PathEndpointInfo struct { IPv4 netip.AddrPort IPv6 netip.AddrPort } func (p PathEndpointInfo) encode(enc *jsontext.Encoder) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) if p.IPv4.IsValid() { h.WriteToken(jsontext.String("ip_v4")) h.WriteToken(jsontext.String(p.IPv4.Addr().String())) h.WriteToken(jsontext.String("port_v4")) h.WriteToken(jsontext.Int(int64(p.IPv4.Port()))) } if p.IPv6.IsValid() { h.WriteToken(jsontext.String("ip_v6")) 
h.WriteToken(jsontext.String(p.IPv6.Addr().String())) h.WriteToken(jsontext.String("port_v6")) h.WriteToken(jsontext.Int(int64(p.IPv6.Port()))) } h.WriteToken(jsontext.EndObject) return h.err } type StartedConnection struct { Local PathEndpointInfo Remote PathEndpointInfo } func (e StartedConnection) Name() string { return "transport:connection_started" } func (e StartedConnection) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("local")) if err := e.Local.encode(enc); err != nil { return err } h.WriteToken(jsontext.String("remote")) if err := e.Remote.encode(enc); err != nil { return err } h.WriteToken(jsontext.EndObject) return h.err } type VersionInformation struct { ClientVersions, ServerVersions []Version ChosenVersion Version } func (e VersionInformation) Name() string { return "transport:version_information" } func (e VersionInformation) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) if len(e.ClientVersions) > 0 { h.WriteToken(jsontext.String("client_versions")) if err := versions(e.ClientVersions).encode(enc); err != nil { return err } } if len(e.ServerVersions) > 0 { h.WriteToken(jsontext.String("server_versions")) if err := versions(e.ServerVersions).encode(enc); err != nil { return err } } h.WriteToken(jsontext.String("chosen_version")) h.WriteToken(jsontext.String(fmt.Sprintf("%x", uint32(e.ChosenVersion)))) h.WriteToken(jsontext.EndObject) return h.err } type ConnectionClosed struct { Initiator Initiator ConnectionError *TransportErrorCode ApplicationError *ApplicationErrorCode Reason string Trigger ConnectionCloseTrigger } func (e ConnectionClosed) Name() string { return "transport:connection_closed" } func (e ConnectionClosed) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("initiator")) 
h.WriteToken(jsontext.String(string(e.Initiator))) if e.ConnectionError != nil { h.WriteToken(jsontext.String("connection_error")) if e.ConnectionError.IsCryptoError() { h.WriteToken(jsontext.String(fmt.Sprintf("crypto_error_%#x", uint16(*e.ConnectionError)))) } else { switch *e.ConnectionError { case qerr.NoError: h.WriteToken(jsontext.String("no_error")) case qerr.InternalError: h.WriteToken(jsontext.String("internal_error")) case qerr.ConnectionRefused: h.WriteToken(jsontext.String("connection_refused")) case qerr.FlowControlError: h.WriteToken(jsontext.String("flow_control_error")) case qerr.StreamLimitError: h.WriteToken(jsontext.String("stream_limit_error")) case qerr.StreamStateError: h.WriteToken(jsontext.String("stream_state_error")) case qerr.FinalSizeError: h.WriteToken(jsontext.String("final_size_error")) case qerr.FrameEncodingError: h.WriteToken(jsontext.String("frame_encoding_error")) case qerr.TransportParameterError: h.WriteToken(jsontext.String("transport_parameter_error")) case qerr.ConnectionIDLimitError: h.WriteToken(jsontext.String("connection_id_limit_error")) case qerr.ProtocolViolation: h.WriteToken(jsontext.String("protocol_violation")) case qerr.InvalidToken: h.WriteToken(jsontext.String("invalid_token")) case qerr.ApplicationErrorErrorCode: h.WriteToken(jsontext.String("application_error")) case qerr.CryptoBufferExceeded: h.WriteToken(jsontext.String("crypto_buffer_exceeded")) case qerr.KeyUpdateError: h.WriteToken(jsontext.String("key_update_error")) case qerr.AEADLimitReached: h.WriteToken(jsontext.String("aead_limit_reached")) case qerr.NoViablePathError: h.WriteToken(jsontext.String("no_viable_path")) default: h.WriteToken(jsontext.String("unknown")) h.WriteToken(jsontext.String("error_code")) h.WriteToken(jsontext.Uint(uint64(*e.ConnectionError))) } } } if e.ApplicationError != nil { h.WriteToken(jsontext.String("application_error")) h.WriteToken(jsontext.String("unknown")) h.WriteToken(jsontext.String("error_code")) 
h.WriteToken(jsontext.Uint(uint64(*e.ApplicationError))) } if e.ConnectionError != nil || e.ApplicationError != nil { h.WriteToken(jsontext.String("reason")) h.WriteToken(jsontext.String(e.Reason)) } if e.Trigger != "" { h.WriteToken(jsontext.String("trigger")) h.WriteToken(jsontext.String(string(e.Trigger))) } h.WriteToken(jsontext.EndObject) return h.err } type PacketSent struct { Header PacketHeader Raw RawInfo DatagramID DatagramID Frames []Frame ECN ECN IsCoalesced bool Trigger string SupportedVersions []Version } func (e PacketSent) Name() string { return "transport:packet_sent" } func (e PacketSent) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("header")) if err := e.Header.encode(enc); err != nil { return err } h.WriteToken(jsontext.String("raw")) if err := e.Raw.encode(enc); err != nil { return err } if e.DatagramID != 0 { h.WriteToken(jsontext.String("datagram_id")) h.WriteToken(jsontext.Uint(uint64(e.DatagramID))) } if len(e.Frames) > 0 { h.WriteToken(jsontext.String("frames")) if err := frames(e.Frames).encode(enc); err != nil { return err } } if e.IsCoalesced { h.WriteToken(jsontext.String("is_coalesced")) h.WriteToken(jsontext.True) } if e.ECN != ECNUnsupported { h.WriteToken(jsontext.String("ecn")) h.WriteToken(jsontext.String(string(e.ECN))) } if e.Trigger != "" { h.WriteToken(jsontext.String("trigger")) h.WriteToken(jsontext.String(e.Trigger)) } h.WriteToken(jsontext.EndObject) return h.err } type PacketReceived struct { Header PacketHeader Raw RawInfo DatagramID DatagramID Frames []Frame ECN ECN IsCoalesced bool Trigger string } func (e PacketReceived) Name() string { return "transport:packet_received" } func (e PacketReceived) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("header")) if err := e.Header.encode(enc); err != nil { return err } 
h.WriteToken(jsontext.String("raw")) if err := e.Raw.encode(enc); err != nil { return err } if e.DatagramID != 0 { h.WriteToken(jsontext.String("datagram_id")) h.WriteToken(jsontext.Uint(uint64(e.DatagramID))) } if len(e.Frames) > 0 { h.WriteToken(jsontext.String("frames")) if err := frames(e.Frames).encode(enc); err != nil { return err } } if e.IsCoalesced { h.WriteToken(jsontext.String("is_coalesced")) h.WriteToken(jsontext.True) } if e.ECN != ECNUnsupported { h.WriteToken(jsontext.String("ecn")) h.WriteToken(jsontext.String(string(e.ECN))) } if e.Trigger != "" { h.WriteToken(jsontext.String("trigger")) h.WriteToken(jsontext.String(e.Trigger)) } h.WriteToken(jsontext.EndObject) return h.err } type VersionNegotiationReceived struct { Header PacketHeaderVersionNegotiation SupportedVersions []Version } func (e VersionNegotiationReceived) Name() string { return "transport:packet_received" } func (e VersionNegotiationReceived) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("header")) if err := e.Header.encode(enc); err != nil { return err } h.WriteToken(jsontext.String("supported_versions")) if err := versions(e.SupportedVersions).encode(enc); err != nil { return err } h.WriteToken(jsontext.EndObject) return h.err } type VersionNegotiationSent struct { Header PacketHeaderVersionNegotiation SupportedVersions []Version } func (e VersionNegotiationSent) Name() string { return "transport:packet_sent" } func (e VersionNegotiationSent) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("header")) if err := e.Header.encode(enc); err != nil { return err } h.WriteToken(jsontext.String("supported_versions")) if err := versions(e.SupportedVersions).encode(enc); err != nil { return err } h.WriteToken(jsontext.EndObject) return h.err } type PacketBuffered struct { Header PacketHeader Raw 
RawInfo DatagramID DatagramID } func (e PacketBuffered) Name() string { return "transport:packet_buffered" } func (e PacketBuffered) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("header")) if err := e.Header.encode(enc); err != nil { return err } h.WriteToken(jsontext.String("raw")) if err := e.Raw.encode(enc); err != nil { return err } if e.DatagramID != 0 { h.WriteToken(jsontext.String("datagram_id")) h.WriteToken(jsontext.Uint(uint64(e.DatagramID))) } h.WriteToken(jsontext.String("trigger")) h.WriteToken(jsontext.String("keys_unavailable")) h.WriteToken(jsontext.EndObject) return h.err } // PacketDropped is the transport:packet_dropped event. type PacketDropped struct { Header PacketHeader Raw RawInfo DatagramID DatagramID Trigger PacketDropReason } func (e PacketDropped) Name() string { return "transport:packet_dropped" } func (e PacketDropped) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("header")) if err := e.Header.encode(enc); err != nil { return err } h.WriteToken(jsontext.String("raw")) if err := e.Raw.encode(enc); err != nil { return err } if e.DatagramID != 0 { h.WriteToken(jsontext.String("datagram_id")) h.WriteToken(jsontext.Uint(uint64(e.DatagramID))) } h.WriteToken(jsontext.String("trigger")) h.WriteToken(jsontext.String(string(e.Trigger))) h.WriteToken(jsontext.EndObject) return h.err } type MTUUpdated struct { Value int Done bool } func (e MTUUpdated) Name() string { return "recovery:mtu_updated" } func (e MTUUpdated) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("mtu")) h.WriteToken(jsontext.Uint(uint64(e.Value))) h.WriteToken(jsontext.String("done")) h.WriteToken(jsontext.Bool(e.Done)) h.WriteToken(jsontext.EndObject) return h.err } // MetricsUpdated logs 
RTT and congestion metrics as defined in the // recovery:metrics_updated event. // The PTO count is logged via PTOCountUpdated. type MetricsUpdated struct { MinRTT time.Duration SmoothedRTT time.Duration LatestRTT time.Duration RTTVariance time.Duration CongestionWindow int BytesInFlight int PacketsInFlight int } func (e MetricsUpdated) Name() string { return "recovery:metrics_updated" } func (e MetricsUpdated) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) if e.MinRTT != 0 { h.WriteToken(jsontext.String("min_rtt")) h.WriteToken(jsontext.Float(milliseconds(e.MinRTT))) } if e.SmoothedRTT != 0 { h.WriteToken(jsontext.String("smoothed_rtt")) h.WriteToken(jsontext.Float(milliseconds(e.SmoothedRTT))) } if e.LatestRTT != 0 { h.WriteToken(jsontext.String("latest_rtt")) h.WriteToken(jsontext.Float(milliseconds(e.LatestRTT))) } if e.RTTVariance != 0 { h.WriteToken(jsontext.String("rtt_variance")) h.WriteToken(jsontext.Float(milliseconds(e.RTTVariance))) } if e.CongestionWindow != 0 { h.WriteToken(jsontext.String("congestion_window")) h.WriteToken(jsontext.Uint(uint64(e.CongestionWindow))) } if e.BytesInFlight != 0 { h.WriteToken(jsontext.String("bytes_in_flight")) h.WriteToken(jsontext.Uint(uint64(e.BytesInFlight))) } if e.PacketsInFlight != 0 { h.WriteToken(jsontext.String("packets_in_flight")) h.WriteToken(jsontext.Uint(uint64(e.PacketsInFlight))) } h.WriteToken(jsontext.EndObject) return h.err } // PTOCountUpdated logs the pto_count value of the // recovery:metrics_updated event. 
type PTOCountUpdated struct { PTOCount uint32 } func (e PTOCountUpdated) Name() string { return "recovery:metrics_updated" } func (e PTOCountUpdated) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("pto_count")) h.WriteToken(jsontext.Uint(uint64(e.PTOCount))) h.WriteToken(jsontext.EndObject) return h.err } type PacketLost struct { Header PacketHeader Trigger PacketLossReason } func (e PacketLost) Name() string { return "recovery:packet_lost" } func (e PacketLost) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("header")) if err := e.Header.encode(enc); err != nil { return err } h.WriteToken(jsontext.String("trigger")) h.WriteToken(jsontext.String(string(e.Trigger))) h.WriteToken(jsontext.EndObject) return h.err } type SpuriousLoss struct { EncryptionLevel protocol.EncryptionLevel PacketNumber protocol.PacketNumber PacketReordering uint64 TimeReordering time.Duration } func (e SpuriousLoss) Name() string { return "recovery:spurious_loss" } func (e SpuriousLoss) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("packet_number_space")) h.WriteToken(jsontext.String(encLevelToPacketNumberSpace(e.EncryptionLevel))) h.WriteToken(jsontext.String("packet_number")) h.WriteToken(jsontext.Uint(uint64(e.PacketNumber))) h.WriteToken(jsontext.String("reordering_packets")) h.WriteToken(jsontext.Uint(e.PacketReordering)) h.WriteToken(jsontext.String("reordering_time")) h.WriteToken(jsontext.Float(milliseconds(e.TimeReordering))) h.WriteToken(jsontext.EndObject) return h.err } type KeyUpdated struct { Trigger KeyUpdateTrigger KeyType KeyType KeyPhase KeyPhase // only set for 1-RTT keys // we don't log the keys here, so we don't need `old` and `new`. 
} func (e KeyUpdated) Name() string { return "security:key_updated" } func (e KeyUpdated) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("trigger")) h.WriteToken(jsontext.String(string(e.Trigger))) h.WriteToken(jsontext.String("key_type")) h.WriteToken(jsontext.String(string(e.KeyType))) if e.KeyType == KeyTypeClient1RTT || e.KeyType == KeyTypeServer1RTT { h.WriteToken(jsontext.String("key_phase")) h.WriteToken(jsontext.Uint(uint64(e.KeyPhase))) } h.WriteToken(jsontext.EndObject) return h.err } type KeyDiscarded struct { KeyType KeyType KeyPhase KeyPhase // only set for 1-RTT keys } func (e KeyDiscarded) Name() string { return "security:key_discarded" } func (e KeyDiscarded) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) if e.KeyType != KeyTypeClient1RTT && e.KeyType != KeyTypeServer1RTT { h.WriteToken(jsontext.String("trigger")) h.WriteToken(jsontext.String("tls")) } h.WriteToken(jsontext.String("key_type")) h.WriteToken(jsontext.String(string(e.KeyType))) if e.KeyType == KeyTypeClient1RTT || e.KeyType == KeyTypeServer1RTT { h.WriteToken(jsontext.String("key_phase")) h.WriteToken(jsontext.Uint(uint64(e.KeyPhase))) } h.WriteToken(jsontext.EndObject) return h.err } type ParametersSet struct { Restore bool Initiator Initiator SentBy protocol.Perspective OriginalDestinationConnectionID protocol.ConnectionID InitialSourceConnectionID protocol.ConnectionID RetrySourceConnectionID *protocol.ConnectionID StatelessResetToken *protocol.StatelessResetToken DisableActiveMigration bool MaxIdleTimeout time.Duration MaxUDPPayloadSize protocol.ByteCount AckDelayExponent uint8 MaxAckDelay time.Duration ActiveConnectionIDLimit uint64 InitialMaxData protocol.ByteCount InitialMaxStreamDataBidiLocal protocol.ByteCount InitialMaxStreamDataBidiRemote protocol.ByteCount InitialMaxStreamDataUni protocol.ByteCount 
InitialMaxStreamsBidi int64 InitialMaxStreamsUni int64 PreferredAddress *PreferredAddress MaxDatagramFrameSize protocol.ByteCount EnableResetStreamAt bool } func (e ParametersSet) Name() string { if e.Restore { return "transport:parameters_restored" } return "transport:parameters_set" } func (e ParametersSet) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) if !e.Restore { h.WriteToken(jsontext.String("initiator")) h.WriteToken(jsontext.String(string(e.Initiator))) if e.SentBy == protocol.PerspectiveServer { h.WriteToken(jsontext.String("original_destination_connection_id")) h.WriteToken(jsontext.String(e.OriginalDestinationConnectionID.String())) if e.StatelessResetToken != nil { h.WriteToken(jsontext.String("stateless_reset_token")) h.WriteToken(jsontext.String(fmt.Sprintf("%x", e.StatelessResetToken[:]))) } if e.RetrySourceConnectionID != nil { h.WriteToken(jsontext.String("retry_source_connection_id")) h.WriteToken(jsontext.String((*e.RetrySourceConnectionID).String())) } } h.WriteToken(jsontext.String("initial_source_connection_id")) h.WriteToken(jsontext.String(e.InitialSourceConnectionID.String())) } h.WriteToken(jsontext.String("disable_active_migration")) h.WriteToken(jsontext.Bool(e.DisableActiveMigration)) if e.MaxIdleTimeout != 0 { h.WriteToken(jsontext.String("max_idle_timeout")) h.WriteToken(jsontext.Float(milliseconds(e.MaxIdleTimeout))) } if e.MaxUDPPayloadSize != 0 { h.WriteToken(jsontext.String("max_udp_payload_size")) h.WriteToken(jsontext.Int(int64(e.MaxUDPPayloadSize))) } if e.AckDelayExponent != 0 { h.WriteToken(jsontext.String("ack_delay_exponent")) h.WriteToken(jsontext.Uint(uint64(e.AckDelayExponent))) } if e.MaxAckDelay != 0 { h.WriteToken(jsontext.String("max_ack_delay")) h.WriteToken(jsontext.Float(milliseconds(e.MaxAckDelay))) } if e.ActiveConnectionIDLimit != 0 { h.WriteToken(jsontext.String("active_connection_id_limit")) 
h.WriteToken(jsontext.Uint(e.ActiveConnectionIDLimit)) } if e.InitialMaxData != 0 { h.WriteToken(jsontext.String("initial_max_data")) h.WriteToken(jsontext.Int(int64(e.InitialMaxData))) } if e.InitialMaxStreamDataBidiLocal != 0 { h.WriteToken(jsontext.String("initial_max_stream_data_bidi_local")) h.WriteToken(jsontext.Int(int64(e.InitialMaxStreamDataBidiLocal))) } if e.InitialMaxStreamDataBidiRemote != 0 { h.WriteToken(jsontext.String("initial_max_stream_data_bidi_remote")) h.WriteToken(jsontext.Int(int64(e.InitialMaxStreamDataBidiRemote))) } if e.InitialMaxStreamDataUni != 0 { h.WriteToken(jsontext.String("initial_max_stream_data_uni")) h.WriteToken(jsontext.Int(int64(e.InitialMaxStreamDataUni))) } if e.InitialMaxStreamsBidi != 0 { h.WriteToken(jsontext.String("initial_max_streams_bidi")) h.WriteToken(jsontext.Int(e.InitialMaxStreamsBidi)) } if e.InitialMaxStreamsUni != 0 { h.WriteToken(jsontext.String("initial_max_streams_uni")) h.WriteToken(jsontext.Int(e.InitialMaxStreamsUni)) } if e.PreferredAddress != nil { h.WriteToken(jsontext.String("preferred_address")) if err := e.PreferredAddress.encode(enc); err != nil { return err } } if e.MaxDatagramFrameSize != protocol.InvalidByteCount { h.WriteToken(jsontext.String("max_datagram_frame_size")) h.WriteToken(jsontext.Int(int64(e.MaxDatagramFrameSize))) } if e.EnableResetStreamAt { h.WriteToken(jsontext.String("reset_stream_at")) h.WriteToken(jsontext.True) } h.WriteToken(jsontext.EndObject) return h.err } type PreferredAddress struct { IPv4, IPv6 netip.AddrPort ConnectionID protocol.ConnectionID StatelessResetToken protocol.StatelessResetToken } func (a PreferredAddress) encode(enc *jsontext.Encoder) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) if a.IPv4.IsValid() { h.WriteToken(jsontext.String("ip_v4")) h.WriteToken(jsontext.String(a.IPv4.Addr().String())) h.WriteToken(jsontext.String("port_v4")) h.WriteToken(jsontext.Uint(uint64(a.IPv4.Port()))) } if a.IPv6.IsValid() { 
h.WriteToken(jsontext.String("ip_v6")) h.WriteToken(jsontext.String(a.IPv6.Addr().String())) h.WriteToken(jsontext.String("port_v6")) h.WriteToken(jsontext.Uint(uint64(a.IPv6.Port()))) } h.WriteToken(jsontext.String("connection_id")) h.WriteToken(jsontext.String(a.ConnectionID.String())) h.WriteToken(jsontext.String("stateless_reset_token")) h.WriteToken(jsontext.String(fmt.Sprintf("%x", a.StatelessResetToken))) h.WriteToken(jsontext.EndObject) return h.err } type LossTimerUpdated struct { Type LossTimerUpdateType TimerType TimerType EncLevel EncryptionLevel Time time.Time } func (e LossTimerUpdated) Name() string { return "recovery:loss_timer_updated" } func (e LossTimerUpdated) Encode(enc *jsontext.Encoder, t time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("event_type")) h.WriteToken(jsontext.String(string(e.Type))) h.WriteToken(jsontext.String("timer_type")) h.WriteToken(jsontext.String(string(e.TimerType))) h.WriteToken(jsontext.String("packet_number_space")) h.WriteToken(jsontext.String(encLevelToPacketNumberSpace(e.EncLevel))) if e.Type == LossTimerUpdateTypeSet { h.WriteToken(jsontext.String("delta")) h.WriteToken(jsontext.Float(milliseconds(e.Time.Sub(t)))) } h.WriteToken(jsontext.EndObject) return h.err } type eventLossTimerCanceled struct{} func (e eventLossTimerCanceled) Name() string { return "recovery:loss_timer_updated" } func (e eventLossTimerCanceled) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("event_type")) h.WriteToken(jsontext.String("cancelled")) h.WriteToken(jsontext.EndObject) return h.err } type CongestionStateUpdated struct { State CongestionState } func (e CongestionStateUpdated) Name() string { return "recovery:congestion_state_updated" } func (e CongestionStateUpdated) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} 
h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("new")) h.WriteToken(jsontext.String(e.State.String())) h.WriteToken(jsontext.EndObject) return h.err } type ECNStateUpdated struct { State ECNState Trigger string } func (e ECNStateUpdated) Name() string { return "recovery:ecn_state_updated" } func (e ECNStateUpdated) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("new")) h.WriteToken(jsontext.String(string(e.State))) if e.Trigger != "" { h.WriteToken(jsontext.String("trigger")) h.WriteToken(jsontext.String(e.Trigger)) } h.WriteToken(jsontext.EndObject) return h.err } type ALPNInformation struct { ChosenALPN string } func (e ALPNInformation) Name() string { return "transport:alpn_information" } func (e ALPNInformation) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("chosen_alpn")) h.WriteToken(jsontext.String(e.ChosenALPN)) h.WriteToken(jsontext.EndObject) return h.err } // DebugEvent is a generic event that can be used to log arbitrary messages. 
type DebugEvent struct { EventName string Message string } func (e DebugEvent) Name() string { if e.EventName == "" { return "transport:debug" } return fmt.Sprintf("transport:%s", e.EventName) } func (e DebugEvent) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("message")) h.WriteToken(jsontext.String(e.Message)) h.WriteToken(jsontext.EndObject) return h.err } quic-go-0.59.0/qlog/event_test.go000066400000000000000000000735401513066070600166370ustar00rootroot00000000000000package qlog import ( "bytes" "encoding/json" "net/netip" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/qlogwriter" "github.com/stretchr/testify/require" ) func testEventEncoding(t *testing.T, ev qlogwriter.Event) (string, map[string]any) { t.Helper() var buf bytes.Buffer synctest.Test(t, func(t *testing.T) { tr := qlogwriter.NewConnectionFileSeq( nopWriteCloser(&buf), true, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), []string{EventSchema}, ) go tr.Run() producer := tr.AddProducer() synctest.Wait() time.Sleep(42 * time.Second) producer.RecordEvent(ev) producer.Close() }) return decode(t, buf.String()) } func decode(t *testing.T, data string) (string, map[string]any) { t.Helper() var result map[string]any lines := bytes.Split([]byte(data), []byte{'\n'}) require.Len(t, lines, 3) // the first line is the trace header, the second line is the event, the third line is empty require.Empty(t, lines[2]) require.Equal(t, qlogwriter.RecordSeparator, lines[1][0], "expected record separator at start of line") require.NoError(t, json.Unmarshal(lines[1][1:], &result)) require.Equal(t, 42*time.Second, time.Duration(result["time"].(float64)*1e6)*time.Nanosecond) return result["name"].(string), 
result["data"].(map[string]any) } func TestStartedConnection(t *testing.T) { var localInfo, remoteInfo PathEndpointInfo localInfo.IPv4 = netip.AddrPortFrom(netip.AddrFrom4([4]byte{192, 168, 13, 37}), 42) ip, err := netip.ParseAddr("2001:db8::1") require.NoError(t, err) remoteInfo.IPv6 = netip.AddrPortFrom(ip, 24) name, ev := testEventEncoding(t, &StartedConnection{ Local: localInfo, Remote: remoteInfo, }) require.Equal(t, "transport:connection_started", name) local, ok := ev["local"].(map[string]any) require.True(t, ok) require.Equal(t, "192.168.13.37", local["ip_v4"]) require.Equal(t, float64(42), local["port_v4"]) remote, ok := ev["remote"].(map[string]any) require.True(t, ok) require.Equal(t, "2001:db8::1", remote["ip_v6"]) require.Equal(t, float64(24), remote["port_v6"]) } func TestVersionInformation(t *testing.T) { name, ev := testEventEncoding(t, &VersionInformation{ChosenVersion: 0x1337}) require.Equal(t, "transport:version_information", name) require.Len(t, ev, 1) require.Equal(t, "1337", ev["chosen_version"]) } func TestVersionInformationWithNegotiation(t *testing.T) { name, ev := testEventEncoding(t, &VersionInformation{ ChosenVersion: 0x1337, ClientVersions: []Version{1, 2, 3}, ServerVersions: []Version{4, 5, 6}, }) require.Equal(t, "transport:version_information", name) require.Len(t, ev, 3) require.Equal(t, "1337", ev["chosen_version"]) require.Equal(t, []any{"1", "2", "3"}, ev["client_versions"]) require.Equal(t, []any{"4", "5", "6"}, ev["server_versions"]) } func TestIdleTimeouts(t *testing.T) { name, ev := testEventEncoding(t, &ConnectionClosed{ Initiator: InitiatorLocal, Trigger: ConnectionCloseTriggerIdleTimeout, }) require.Equal(t, "transport:connection_closed", name) require.Len(t, ev, 2) require.Equal(t, "local", ev["initiator"]) require.Equal(t, "idle_timeout", ev["trigger"]) } func TestReceivedStatelessResetPacket(t *testing.T) { name, ev := testEventEncoding(t, &ConnectionClosed{ Initiator: InitiatorRemote, Trigger: 
ConnectionCloseTriggerStatelessReset, }) require.Equal(t, "transport:connection_closed", name) require.Len(t, ev, 2) require.Equal(t, "remote", ev["initiator"]) require.Equal(t, "stateless_reset", ev["trigger"]) } func TestVersionNegotiationFailure(t *testing.T) { name, ev := testEventEncoding(t, &ConnectionClosed{ Initiator: InitiatorLocal, Trigger: ConnectionCloseTriggerVersionMismatch, }) require.Equal(t, "transport:connection_closed", name) require.Len(t, ev, 2) require.Equal(t, "local", ev["initiator"]) require.Equal(t, "version_mismatch", ev["trigger"]) } func TestApplicationErrors(t *testing.T) { code := qerr.ApplicationErrorCode(1337) name, ev := testEventEncoding(t, &ConnectionClosed{ Initiator: InitiatorRemote, ApplicationError: &code, Reason: "foobar", }) require.Equal(t, "transport:connection_closed", name) require.Len(t, ev, 4) require.Equal(t, "remote", ev["initiator"]) require.Equal(t, "unknown", ev["application_error"]) require.Equal(t, float64(1337), ev["error_code"]) require.Equal(t, "foobar", ev["reason"]) } func TestTransportErrors(t *testing.T) { tests := []struct { code qerr.TransportErrorCode want string }{ {qerr.NoError, "no_error"}, {qerr.InternalError, "internal_error"}, {qerr.ConnectionRefused, "connection_refused"}, {qerr.FlowControlError, "flow_control_error"}, {qerr.StreamLimitError, "stream_limit_error"}, {qerr.StreamStateError, "stream_state_error"}, {qerr.FinalSizeError, "final_size_error"}, {qerr.FrameEncodingError, "frame_encoding_error"}, {qerr.TransportParameterError, "transport_parameter_error"}, {qerr.ConnectionIDLimitError, "connection_id_limit_error"}, {qerr.ProtocolViolation, "protocol_violation"}, {qerr.InvalidToken, "invalid_token"}, {qerr.ApplicationErrorErrorCode, "application_error"}, {qerr.CryptoBufferExceeded, "crypto_buffer_exceeded"}, {qerr.KeyUpdateError, "key_update_error"}, {qerr.AEADLimitReached, "aead_limit_reached"}, {qerr.NoViablePathError, "no_viable_path"}, } for _, tt := range tests { t.Run(tt.want, 
func(t *testing.T) { code := tt.code name, ev := testEventEncoding(t, &ConnectionClosed{ Initiator: InitiatorLocal, ConnectionError: &code, Reason: "foobar", }) require.Equal(t, "transport:connection_closed", name) require.Equal(t, "local", ev["initiator"]) require.Equal(t, tt.want, ev["connection_error"]) require.Equal(t, "foobar", ev["reason"]) require.NotContains(t, ev, "error_code") }) } } func TestTransportCryptoError(t *testing.T) { code := qerr.TransportErrorCode(0x100 + 0x2a) name, ev := testEventEncoding(t, &ConnectionClosed{ Initiator: InitiatorLocal, ConnectionError: &code, Reason: "foobar", }) require.Equal(t, "transport:connection_closed", name) require.Equal(t, "local", ev["initiator"]) require.Equal(t, "crypto_error_0x12a", ev["connection_error"]) require.Equal(t, "foobar", ev["reason"]) } func TestSentTransportParameters(t *testing.T) { rcid := protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}) name, ev := testEventEncoding(t, &ParametersSet{ Initiator: InitiatorLocal, SentBy: protocol.PerspectiveServer, OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), InitialSourceConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), RetrySourceConnectionID: &rcid, StatelessResetToken: &protocol.StatelessResetToken{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00}, DisableActiveMigration: true, MaxIdleTimeout: 321 * time.Millisecond, MaxUDPPayloadSize: 1234, AckDelayExponent: 12, MaxAckDelay: 123 * time.Millisecond, ActiveConnectionIDLimit: 7, InitialMaxData: 4000, InitialMaxStreamDataBidiLocal: 1000, InitialMaxStreamDataBidiRemote: 2000, InitialMaxStreamDataUni: 3000, InitialMaxStreamsBidi: 10, InitialMaxStreamsUni: 20, MaxDatagramFrameSize: protocol.InvalidByteCount, EnableResetStreamAt: true, }) require.Equal(t, "transport:parameters_set", name) require.Equal(t, "local", ev["initiator"]) require.Equal(t, "deadc0de", 
ev["original_destination_connection_id"]) require.Equal(t, "deadbeef", ev["initial_source_connection_id"]) require.Equal(t, "decafbad", ev["retry_source_connection_id"]) require.Equal(t, "112233445566778899aabbccddeeff00", ev["stateless_reset_token"]) require.Equal(t, float64(321), ev["max_idle_timeout"]) require.Equal(t, float64(1234), ev["max_udp_payload_size"]) require.Equal(t, float64(12), ev["ack_delay_exponent"]) require.Equal(t, float64(7), ev["active_connection_id_limit"]) require.Equal(t, float64(4000), ev["initial_max_data"]) require.Equal(t, float64(1000), ev["initial_max_stream_data_bidi_local"]) require.Equal(t, float64(2000), ev["initial_max_stream_data_bidi_remote"]) require.Equal(t, float64(3000), ev["initial_max_stream_data_uni"]) require.Equal(t, float64(10), ev["initial_max_streams_bidi"]) require.Equal(t, float64(20), ev["initial_max_streams_uni"]) require.True(t, ev["reset_stream_at"].(bool)) require.NotContains(t, ev, "preferred_address") require.NotContains(t, ev, "max_datagram_frame_size") } func TestServerTransportParametersWithoutStatelessResetToken(t *testing.T) { name, ev := testEventEncoding(t, &ParametersSet{ Initiator: InitiatorLocal, SentBy: protocol.PerspectiveServer, OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), ActiveConnectionIDLimit: 7, }) require.Equal(t, "transport:parameters_set", name) require.NotContains(t, ev, "stateless_reset_token") } func TestTransportParametersWithoutRetrySourceConnectionID(t *testing.T) { name, ev := testEventEncoding(t, &ParametersSet{ Initiator: InitiatorLocal, SentBy: protocol.PerspectiveServer, StatelessResetToken: &protocol.StatelessResetToken{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00}, }) require.Equal(t, "transport:parameters_set", name) require.Equal(t, "local", ev["initiator"]) require.NotContains(t, ev, "retry_source_connection_id") } func TestTransportParametersWithPreferredAddress(t 
*testing.T) { t.Run("IPv4 and IPv6", func(t *testing.T) { testTransportParametersWithPreferredAddress(t, true, true) }) t.Run("IPv4 only", func(t *testing.T) { testTransportParametersWithPreferredAddress(t, true, false) }) t.Run("IPv6 only", func(t *testing.T) { testTransportParametersWithPreferredAddress(t, false, true) }) } func testTransportParametersWithPreferredAddress(t *testing.T, hasIPv4, hasIPv6 bool) { addr4 := netip.AddrPortFrom(netip.AddrFrom4([4]byte{12, 34, 56, 78}), 123) addr6 := netip.AddrPortFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}), 456) preferredAddress := &PreferredAddress{ ConnectionID: protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}), StatelessResetToken: protocol.StatelessResetToken{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, } if hasIPv4 { preferredAddress.IPv4 = addr4 } if hasIPv6 { preferredAddress.IPv6 = addr6 } name, ev := testEventEncoding(t, &ParametersSet{ Initiator: InitiatorLocal, SentBy: protocol.PerspectiveServer, PreferredAddress: preferredAddress, }) require.Equal(t, "transport:parameters_set", name) require.Equal(t, "local", ev["initiator"]) require.Contains(t, ev, "preferred_address") pa := ev["preferred_address"].(map[string]any) if hasIPv4 { require.Equal(t, "12.34.56.78", pa["ip_v4"]) require.Equal(t, float64(123), pa["port_v4"]) } else { require.NotContains(t, pa, "ip_v4") require.NotContains(t, pa, "port_v4") } if hasIPv6 { require.Equal(t, "102:304:506:708:90a:b0c:d0e:f10", pa["ip_v6"]) require.Equal(t, float64(456), pa["port_v6"]) } else { require.NotContains(t, pa, "ip_v6") require.NotContains(t, pa, "port_v6") } require.Equal(t, "0807060504030201", pa["connection_id"]) require.Equal(t, "0f0e0d0c0b0a09080706050403020100", pa["stateless_reset_token"]) } func TestTransportParametersWithDatagramExtension(t *testing.T) { name, ev := testEventEncoding(t, &ParametersSet{ Initiator: InitiatorLocal, SentBy: protocol.PerspectiveServer, MaxDatagramFrameSize: 1337, 
}) require.Equal(t, "transport:parameters_set", name) require.Equal(t, float64(1337), ev["max_datagram_frame_size"]) } func TestReceivedTransportParameters(t *testing.T) { name, ev := testEventEncoding(t, &ParametersSet{ Initiator: InitiatorRemote, SentBy: protocol.PerspectiveClient, }) require.Equal(t, "transport:parameters_set", name) require.Equal(t, "remote", ev["initiator"]) require.NotContains(t, ev, "original_destination_connection_id") } func TestRestoredTransportParameters(t *testing.T) { name, ev := testEventEncoding(t, &ParametersSet{ Restore: true, InitialMaxStreamDataBidiLocal: 100, InitialMaxStreamDataBidiRemote: 200, InitialMaxStreamDataUni: 300, InitialMaxData: 400, MaxIdleTimeout: 123 * time.Millisecond, }) require.Equal(t, "transport:parameters_restored", name) require.NotContains(t, ev, "initiator") require.NotContains(t, ev, "original_destination_connection_id") require.NotContains(t, ev, "stateless_reset_token") require.NotContains(t, ev, "retry_source_connection_id") require.NotContains(t, ev, "initial_source_connection_id") require.Equal(t, float64(123), ev["max_idle_timeout"]) require.Equal(t, float64(400), ev["initial_max_data"]) require.Equal(t, float64(100), ev["initial_max_stream_data_bidi_local"]) require.Equal(t, float64(200), ev["initial_max_stream_data_bidi_remote"]) require.Equal(t, float64(300), ev["initial_max_stream_data_uni"]) } func TestPacketSent(t *testing.T) { name, ev := testEventEncoding(t, &PacketSent{ Header: PacketHeader{ PacketType: PacketTypeHandshake, PacketNumber: 1337, Version: protocol.Version1, SrcConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}), DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), }, Raw: RawInfo{Length: 987, PayloadLength: 1337}, Frames: []Frame{ {Frame: &MaxStreamDataFrame{StreamID: 42, MaximumStreamData: 987}}, {Frame: &StreamFrame{StreamID: 123, Offset: 1234, Length: 6, Fin: true}}, }, ECN: ECNCE, }) require.Equal(t, "transport:packet_sent", name) 
require.Contains(t, ev, "raw") raw := ev["raw"].(map[string]any) require.NotContains(t, ev, "datagram_id") require.Equal(t, float64(987), raw["length"]) require.Equal(t, float64(1337), raw["payload_length"]) require.Contains(t, ev, "header") hdr := ev["header"].(map[string]any) require.Equal(t, "handshake", hdr["packet_type"]) require.Equal(t, float64(1337), hdr["packet_number"]) require.Equal(t, "04030201", hdr["scid"]) require.Contains(t, ev, "frames") require.Equal(t, "CE", ev["ecn"]) frames := ev["frames"].([]any) require.Len(t, frames, 2) require.Equal(t, "max_stream_data", frames[0].(map[string]any)["frame_type"]) require.Equal(t, "stream", frames[1].(map[string]any)["frame_type"]) } func TestPacketSent1RTT(t *testing.T) { t.Run("with datagram ID", func(t *testing.T) { testPacketSent1RTT(t, 1337) }) t.Run("without datagram ID", func(t *testing.T) { testPacketSent1RTT(t, 0) }) } func testPacketSent1RTT(t *testing.T, datagramID DatagramID) { name, ev := testEventEncoding(t, &PacketSent{ Header: PacketHeader{ PacketType: PacketType1RTT, PacketNumber: 1337, KeyPhaseBit: KeyPhaseZero, DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), }, Raw: RawInfo{Length: 123}, Frames: []Frame{ {Frame: &AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}}}, {Frame: &MaxDataFrame{MaximumData: 987}}, }, ECN: ECNUnsupported, DatagramID: datagramID, }) require.Equal(t, "transport:packet_sent", name) raw := ev["raw"].(map[string]any) require.Equal(t, float64(123), raw["length"]) require.NotContains(t, raw, "payload_length") require.Contains(t, ev, "header") require.NotContains(t, ev, "ecn") hdr := ev["header"].(map[string]any) require.Equal(t, "1RTT", hdr["packet_type"]) require.Equal(t, float64(1337), hdr["packet_number"]) require.Contains(t, ev, "frames") frames := ev["frames"].([]any) require.Len(t, frames, 2) require.Equal(t, "ack", frames[0].(map[string]any)["frame_type"]) require.Equal(t, "max_data", frames[1].(map[string]any)["frame_type"]) if 
datagramID != 0 { require.Contains(t, ev, "datagram_id") require.Equal(t, float64(datagramID), ev["datagram_id"]) } else { require.NotContains(t, ev, "datagram_id") } } func TestPacketReceived(t *testing.T) { name, ev := testEventEncoding(t, &PacketReceived{ Header: PacketHeader{ PacketType: PacketTypeInitial, PacketNumber: 1337, Version: protocol.Version1, SrcConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}), DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Token: &Token{Raw: []byte{0xde, 0xad, 0xbe, 0xef}}, }, Raw: RawInfo{ Length: 789, PayloadLength: 1234, }, Frames: []Frame{ {Frame: &MaxStreamDataFrame{StreamID: 42, MaximumStreamData: 987}}, {Frame: &StreamFrame{StreamID: 123, Offset: 1234, Length: 6, Fin: true}}, }, ECN: ECT0, DatagramID: 42, }) require.Equal(t, "transport:packet_received", name) require.Contains(t, ev, "raw") raw := ev["raw"].(map[string]any) require.Equal(t, float64(789), raw["length"]) require.Equal(t, float64(1234), raw["payload_length"]) require.Equal(t, "ECT(0)", ev["ecn"]) require.Contains(t, ev, "header") hdr := ev["header"].(map[string]any) require.Equal(t, "initial", hdr["packet_type"]) require.Equal(t, float64(1337), hdr["packet_number"]) require.Equal(t, "04030201", hdr["scid"]) require.Contains(t, hdr, "token") token := hdr["token"].(map[string]any) require.Equal(t, "deadbeef", token["data"]) require.Contains(t, ev, "frames") require.Len(t, ev["frames"].([]any), 2) require.Contains(t, ev, "datagram_id") require.Equal(t, float64(42), ev["datagram_id"]) } func TestPacketReceived1RTT(t *testing.T) { t.Run("with datagram ID", func(t *testing.T) { testPacketReceived1RTT(t, 1337) }) t.Run("without datagram ID", func(t *testing.T) { testPacketReceived1RTT(t, 0) }) } func testPacketReceived1RTT(t *testing.T, datagramID DatagramID) { name, ev := testEventEncoding(t, &PacketReceived{ Header: PacketHeader{ PacketType: PacketType1RTT, PacketNumber: 1337, KeyPhaseBit: KeyPhaseZero, DestConnectionID: 
protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), }, Raw: RawInfo{Length: 789, PayloadLength: 1234}, Frames: []Frame{ {Frame: &MaxStreamDataFrame{StreamID: 42, MaximumStreamData: 987}}, {Frame: &StreamFrame{StreamID: 123, Offset: 1234, Length: 6, Fin: true}}, }, ECN: ECT1, DatagramID: datagramID, }) require.Equal(t, "transport:packet_received", name) require.Contains(t, ev, "raw") raw := ev["raw"].(map[string]any) require.Equal(t, float64(789), raw["length"]) require.Equal(t, float64(1234), raw["payload_length"]) require.Equal(t, "ECT(1)", ev["ecn"]) require.Contains(t, ev, "header") hdr := ev["header"].(map[string]any) require.Equal(t, "1RTT", hdr["packet_type"]) require.Equal(t, float64(1337), hdr["packet_number"]) require.Contains(t, ev, "frames") require.Len(t, ev["frames"].([]any), 2) if datagramID != 0 { require.Contains(t, ev, "datagram_id") require.Equal(t, float64(datagramID), ev["datagram_id"]) } else { require.NotContains(t, ev, "datagram_id") } } func TestPacketReceivedRetry(t *testing.T) { name, ev := testEventEncoding(t, &PacketReceived{ Header: PacketHeader{ PacketType: PacketTypeRetry, Version: protocol.Version1, SrcConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}), DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Token: &Token{Raw: []byte{0xde, 0xad, 0xbe, 0xef}}, }, Raw: RawInfo{Length: 123}, }) require.Equal(t, "transport:packet_received", name) require.Contains(t, ev, "raw") raw := ev["raw"].(map[string]any) require.Len(t, raw, 1) require.Equal(t, float64(123), raw["length"]) require.Contains(t, ev, "header") header := ev["header"].(map[string]any) require.Equal(t, "retry", header["packet_type"]) require.NotContains(t, header, "packet_number") require.Contains(t, header, "version") require.Contains(t, header, "dcid") require.Contains(t, header, "scid") require.Contains(t, header, "token") token := header["token"].(map[string]any) require.Equal(t, "deadbeef", token["data"]) require.NotContains(t, 
ev, "frames") } func TestVersionNegotiationReceived(t *testing.T) { name, ev := testEventEncoding(t, &VersionNegotiationReceived{ Header: PacketHeaderVersionNegotiation{ SrcConnectionID: ArbitraryLenConnectionID{4, 3, 2, 1}, DestConnectionID: ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, }, SupportedVersions: []Version{0xdeadbeef, 0xdecafbad}, }) require.Equal(t, "transport:packet_received", name) require.Contains(t, ev, "header") require.NotContains(t, ev, "frames") require.Contains(t, ev, "supported_versions") require.Equal(t, []any{"deadbeef", "decafbad"}, ev["supported_versions"]) header := ev["header"].(map[string]any) require.Equal(t, "version_negotiation", header["packet_type"]) require.NotContains(t, header, "packet_number") require.NotContains(t, header, "version") require.Equal(t, "0102030405060708", header["dcid"]) require.Equal(t, "04030201", header["scid"]) } func TestPacketBuffered(t *testing.T) { name, ev := testEventEncoding(t, &PacketBuffered{ Header: PacketHeader{ PacketType: PacketTypeHandshake, PacketNumber: protocol.InvalidPacketNumber, DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), SrcConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}), }, Raw: RawInfo{Length: 1337}, }) require.Equal(t, "transport:packet_buffered", name) require.Contains(t, ev, "header") require.Contains(t, ev, "raw") require.Equal(t, float64(1337), ev["raw"].(map[string]any)["length"]) require.Contains(t, ev, "trigger") require.Equal(t, "keys_unavailable", ev["trigger"]) } func TestPacketDropped(t *testing.T) { name, ev := testEventEncoding(t, &PacketDropped{ Header: PacketHeader{PacketType: PacketTypeRetry}, Raw: RawInfo{Length: 1337}, Trigger: PacketDropPayloadDecryptError, }) require.Equal(t, "transport:packet_dropped", name) require.Contains(t, ev, "raw") require.Equal(t, float64(1337), ev["raw"].(map[string]any)["length"]) require.Contains(t, ev, "header") require.Equal(t, "payload_decrypt_error", ev["trigger"]) } func 
TestMetricsUpdated(t *testing.T) { rttStats := utils.NewRTTStats() rttStats.UpdateRTT(15*time.Millisecond, 0) rttStats.UpdateRTT(20*time.Millisecond, 0) rttStats.UpdateRTT(25*time.Millisecond, 0) name, ev := testEventEncoding(t, &MetricsUpdated{ MinRTT: rttStats.MinRTT(), SmoothedRTT: rttStats.SmoothedRTT(), LatestRTT: rttStats.LatestRTT(), RTTVariance: rttStats.MeanDeviation(), CongestionWindow: 4321, BytesInFlight: 1234, PacketsInFlight: 42, }) require.Equal(t, "recovery:metrics_updated", name) require.Equal(t, float64(15), ev["min_rtt"]) require.Equal(t, float64(25), ev["latest_rtt"]) require.Contains(t, ev, "smoothed_rtt") require.InDelta(t, rttStats.SmoothedRTT().Milliseconds(), ev["smoothed_rtt"], float64(1)) require.Contains(t, ev, "rtt_variance") require.InDelta(t, rttStats.MeanDeviation().Milliseconds(), ev["rtt_variance"], float64(1)) require.Equal(t, float64(4321), ev["congestion_window"]) require.Equal(t, float64(1234), ev["bytes_in_flight"]) require.Equal(t, float64(42), ev["packets_in_flight"]) } func TestPacketLost(t *testing.T) { name, ev := testEventEncoding(t, &PacketLost{ Header: PacketHeader{PacketType: PacketTypeHandshake, PacketNumber: 42}, Trigger: PacketLossReorderingThreshold, }) require.Equal(t, "recovery:packet_lost", name) require.Contains(t, ev, "header") require.Equal(t, "reordering_threshold", ev["trigger"]) } func TestSpuriousLoss(t *testing.T) { name, ev := testEventEncoding(t, &SpuriousLoss{ EncryptionLevel: protocol.Encryption1RTT, PacketNumber: 42, PacketReordering: 1, TimeReordering: 1337 * time.Millisecond, }) require.Equal(t, "recovery:spurious_loss", name) require.Contains(t, ev, "packet_number") require.Equal(t, float64(42), ev["packet_number"]) require.Contains(t, ev, "reordering_packets") require.Equal(t, float64(1), ev["reordering_packets"]) require.Contains(t, ev, "reordering_time") require.InDelta(t, 1337, ev["reordering_time"], float64(1)) } func TestMTUUpdated(t *testing.T) { name, ev := testEventEncoding(t, 
&MTUUpdated{ Value: 1337, Done: true, }) require.Equal(t, "recovery:mtu_updated", name) require.Equal(t, float64(1337), ev["mtu"]) require.Equal(t, true, ev["done"]) } func TestCongestionStateUpdated(t *testing.T) { name, ev := testEventEncoding(t, &CongestionStateUpdated{ State: CongestionStateCongestionAvoidance, }) require.Equal(t, "recovery:congestion_state_updated", name) require.Equal(t, "congestion_avoidance", ev["new"]) } func TestPTOCountUpdated(t *testing.T) { name, ev := testEventEncoding(t, &PTOCountUpdated{PTOCount: 42}) require.Equal(t, "recovery:metrics_updated", name) require.Equal(t, float64(42), ev["pto_count"]) } func TestKeyUpdatedTLS(t *testing.T) { name, ev := testEventEncoding(t, &KeyUpdated{ Trigger: KeyUpdateTLS, KeyType: KeyTypeClientHandshake, KeyPhase: 0, }) require.Equal(t, "security:key_updated", name) require.Equal(t, "client_handshake_secret", ev["key_type"]) require.Equal(t, "tls", ev["trigger"]) require.NotContains(t, ev, "key_phase") require.NotContains(t, ev, "old") require.NotContains(t, ev, "new") } func TestKeyUpdatedTLS1RTT(t *testing.T) { name, ev := testEventEncoding(t, &KeyUpdated{ Trigger: KeyUpdateTLS, KeyType: KeyTypeServer1RTT, KeyPhase: 0, }) require.Equal(t, "security:key_updated", name) require.Equal(t, "server_1rtt_secret", ev["key_type"]) require.Equal(t, "tls", ev["trigger"]) require.Equal(t, float64(0), ev["key_phase"]) require.NotContains(t, ev, "old") require.NotContains(t, ev, "new") } func TestKeyUpdated(t *testing.T) { name, ev := testEventEncoding(t, &KeyUpdated{ Trigger: KeyUpdateRemote, KeyType: KeyTypeClient1RTT, KeyPhase: 1337, }) require.Equal(t, "security:key_updated", name) require.Equal(t, float64(1337), ev["key_phase"]) require.Equal(t, "remote_update", ev["trigger"]) require.Contains(t, ev, "key_type") require.Equal(t, "client_1rtt_secret", ev["key_type"]) } func TestKeyDiscarded0RTT(t *testing.T) { name, ev := testEventEncoding(t, &KeyDiscarded{ KeyType: KeyTypeServer0RTT, KeyPhase: 0, }) 
require.Equal(t, "security:key_discarded", name) require.Equal(t, "tls", ev["trigger"]) require.Equal(t, "server_0rtt_secret", ev["key_type"]) } func TestKeyDiscarded(t *testing.T) { name, ev := testEventEncoding(t, &KeyDiscarded{ KeyType: KeyTypeClient1RTT, KeyPhase: 42, }) require.Equal(t, "security:key_discarded", name) require.Equal(t, float64(42), ev["key_phase"]) require.NotContains(t, ev, "trigger") require.Contains(t, ev, "key_type") require.Equal(t, "client_1rtt_secret", ev["key_type"]) } func TestLossTimerUpdated(t *testing.T) { synctest.Test(t, func(t *testing.T) { var buf bytes.Buffer tr := qlogwriter.NewConnectionFileSeq( nopWriteCloser(&buf), true, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), []string{EventSchema}, ) go tr.Run() producer := tr.AddProducer() synctest.Wait() time.Sleep(42 * time.Second) producer.RecordEvent(&LossTimerUpdated{ Type: LossTimerUpdateTypeSet, TimerType: TimerTypePTO, EncLevel: protocol.EncryptionHandshake, Time: time.Now().Add(1337 * time.Second), }) producer.Close() name, ev := decode(t, buf.String()) require.Equal(t, "recovery:loss_timer_updated", name) require.Len(t, ev, 4) require.Equal(t, "set", ev["event_type"]) require.Equal(t, "pto", ev["timer_type"]) require.Equal(t, "handshake", ev["packet_number_space"]) require.Contains(t, ev, "delta") delta := time.Duration(ev["delta"].(float64)*1e6) * time.Nanosecond require.Equal(t, 1337*time.Second, delta) }) } func TestLossTimerUpdatedExpired(t *testing.T) { name, ev := testEventEncoding(t, &LossTimerUpdated{ Type: LossTimerUpdateTypeExpired, TimerType: TimerTypeACK, EncLevel: protocol.Encryption1RTT, }) require.Equal(t, "recovery:loss_timer_updated", name) require.Len(t, ev, 3) require.Equal(t, "expired", ev["event_type"]) require.Equal(t, "ack", ev["timer_type"]) require.Equal(t, "application_data", ev["packet_number_space"]) } func TestLossTimerUpdatedCanceled(t *testing.T) { name, ev := testEventEncoding(t, &eventLossTimerCanceled{}) require.Equal(t, 
"recovery:loss_timer_updated", name) require.Len(t, ev, 1) require.Equal(t, "cancelled", ev["event_type"]) } func TestECNStateUpdated(t *testing.T) { name, ev := testEventEncoding(t, &ECNStateUpdated{ State: ECNStateUnknown, Trigger: "", }) require.Equal(t, "recovery:ecn_state_updated", name) require.Len(t, ev, 1) require.Equal(t, "unknown", ev["new"]) } func TestECNStateUpdatedWithTrigger(t *testing.T) { name, ev := testEventEncoding(t, &ECNStateUpdated{ State: ECNStateFailed, Trigger: "ACK doesn't contain ECN marks", }) require.Equal(t, "recovery:ecn_state_updated", name) require.Len(t, ev, 2) require.Equal(t, "failed", ev["new"]) require.Equal(t, "ACK doesn't contain ECN marks", ev["trigger"]) } func TestALPNInformation(t *testing.T) { name, ev := testEventEncoding(t, &ALPNInformation{ ChosenALPN: "h3", }) require.Equal(t, "transport:alpn_information", name) require.Len(t, ev, 1) require.Equal(t, "h3", ev["chosen_alpn"]) } func TestDebugEvent(t *testing.T) { t.Run("default name", func(t *testing.T) { name, ev := testEventEncoding(t, &DebugEvent{Message: "hello world"}) require.Equal(t, "transport:debug", name) require.Len(t, ev, 1) require.Equal(t, "hello world", ev["message"]) }) t.Run("custom name", func(t *testing.T) { name, ev := testEventEncoding(t, &DebugEvent{EventName: "foo", Message: "bar"}) require.Equal(t, "transport:foo", name) require.Len(t, ev, 1) require.Equal(t, "bar", ev["message"]) }) } quic-go-0.59.0/qlog/frame.go000066400000000000000000000371001513066070600155410ustar00rootroot00000000000000package qlog import ( "encoding/hex" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/qlogwriter/jsontext" ) type Frame struct { Frame any } type frames []Frame type ( // An AckFrame is an ACK frame. AckFrame = wire.AckFrame // A ConnectionCloseFrame is a CONNECTION_CLOSE frame. ConnectionCloseFrame = wire.ConnectionCloseFrame // A DataBlockedFrame is a DATA_BLOCKED frame. 
DataBlockedFrame = wire.DataBlockedFrame // A HandshakeDoneFrame is a HANDSHAKE_DONE frame. HandshakeDoneFrame = wire.HandshakeDoneFrame // A MaxDataFrame is a MAX_DATA frame. MaxDataFrame = wire.MaxDataFrame // A MaxStreamDataFrame is a MAX_STREAM_DATA frame. MaxStreamDataFrame = wire.MaxStreamDataFrame // A MaxStreamsFrame is a MAX_STREAMS_FRAME. MaxStreamsFrame = wire.MaxStreamsFrame // A NewConnectionIDFrame is a NEW_CONNECTION_ID frame. NewConnectionIDFrame = wire.NewConnectionIDFrame // A NewTokenFrame is a NEW_TOKEN frame. NewTokenFrame = wire.NewTokenFrame // A PathChallengeFrame is a PATH_CHALLENGE frame. PathChallengeFrame = wire.PathChallengeFrame // A PathResponseFrame is a PATH_RESPONSE frame. PathResponseFrame = wire.PathResponseFrame // A PingFrame is a PING frame. PingFrame = wire.PingFrame // A ResetStreamFrame is a RESET_STREAM frame. ResetStreamFrame = wire.ResetStreamFrame // A RetireConnectionIDFrame is a RETIRE_CONNECTION_ID frame. RetireConnectionIDFrame = wire.RetireConnectionIDFrame // A StopSendingFrame is a STOP_SENDING frame. StopSendingFrame = wire.StopSendingFrame // A StreamsBlockedFrame is a STREAMS_BLOCKED frame. StreamsBlockedFrame = wire.StreamsBlockedFrame // A StreamDataBlockedFrame is a STREAM_DATA_BLOCKED frame. StreamDataBlockedFrame = wire.StreamDataBlockedFrame // An AckFrequencyFrame is an ACK_FREQUENCY frame. AckFrequencyFrame = wire.AckFrequencyFrame // An ImmediateAckFrame is an IMMEDIATE_ACK frame. ImmediateAckFrame = wire.ImmediateAckFrame ) type AckRange = wire.AckRange // A CryptoFrame is a CRYPTO frame. type CryptoFrame struct { Offset int64 Length int64 } // A StreamFrame is a STREAM frame. type StreamFrame struct { StreamID StreamID Offset int64 Length int64 Fin bool } // A DatagramFrame is a DATAGRAM frame. 
type DatagramFrame struct { Length int64 } func (fs frames) encode(enc *jsontext.Encoder) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginArray) for _, f := range fs { if err := f.Encode(enc); err != nil { return err } } h.WriteToken(jsontext.EndArray) return h.err } func (f Frame) Encode(enc *jsontext.Encoder) error { switch frame := f.Frame.(type) { case *PingFrame: return encodePingFrame(enc, frame) case *AckFrame: return encodeAckFrame(enc, frame) case *ResetStreamFrame: return encodeResetStreamFrame(enc, frame) case *StopSendingFrame: return encodeStopSendingFrame(enc, frame) case *CryptoFrame: return encodeCryptoFrame(enc, frame) case *NewTokenFrame: return encodeNewTokenFrame(enc, frame) case *StreamFrame: return encodeStreamFrame(enc, frame) case *MaxDataFrame: return encodeMaxDataFrame(enc, frame) case *MaxStreamDataFrame: return encodeMaxStreamDataFrame(enc, frame) case *MaxStreamsFrame: return encodeMaxStreamsFrame(enc, frame) case *DataBlockedFrame: return encodeDataBlockedFrame(enc, frame) case *StreamDataBlockedFrame: return encodeStreamDataBlockedFrame(enc, frame) case *StreamsBlockedFrame: return encodeStreamsBlockedFrame(enc, frame) case *NewConnectionIDFrame: return encodeNewConnectionIDFrame(enc, frame) case *RetireConnectionIDFrame: return encodeRetireConnectionIDFrame(enc, frame) case *PathChallengeFrame: return encodePathChallengeFrame(enc, frame) case *PathResponseFrame: return encodePathResponseFrame(enc, frame) case *ConnectionCloseFrame: return encodeConnectionCloseFrame(enc, frame) case *HandshakeDoneFrame: return encodeHandshakeDoneFrame(enc, frame) case *DatagramFrame: return encodeDatagramFrame(enc, frame) case *AckFrequencyFrame: return encodeAckFrequencyFrame(enc, frame) case *ImmediateAckFrame: return encodeImmediateAckFrame(enc, frame) default: panic("unknown frame type") } } func encodePingFrame(enc *jsontext.Encoder, _ *PingFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) 
h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("ping")) h.WriteToken(jsontext.EndObject) return h.err } type ackRanges []wire.AckRange func (ars ackRanges) encode(enc *jsontext.Encoder) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginArray) for _, r := range ars { if err := ackRange(r).encode(enc); err != nil { return err } } h.WriteToken(jsontext.EndArray) return h.err } type ackRange wire.AckRange func (ar ackRange) encode(enc *jsontext.Encoder) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginArray) h.WriteToken(jsontext.Int(int64(ar.Smallest))) if ar.Smallest != ar.Largest { h.WriteToken(jsontext.Int(int64(ar.Largest))) } h.WriteToken(jsontext.EndArray) return h.err } func encodeAckFrame(enc *jsontext.Encoder, f *AckFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("ack")) if f.DelayTime > 0 { h.WriteToken(jsontext.String("ack_delay")) h.WriteToken(jsontext.Float(milliseconds(f.DelayTime))) } h.WriteToken(jsontext.String("acked_ranges")) if err := ackRanges(f.AckRanges).encode(enc); err != nil { return err } hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 if hasECN { h.WriteToken(jsontext.String("ect0")) h.WriteToken(jsontext.Uint(f.ECT0)) h.WriteToken(jsontext.String("ect1")) h.WriteToken(jsontext.Uint(f.ECT1)) h.WriteToken(jsontext.String("ce")) h.WriteToken(jsontext.Uint(f.ECNCE)) } h.WriteToken(jsontext.EndObject) return h.err } func encodeResetStreamFrame(enc *jsontext.Encoder, f *ResetStreamFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) if f.ReliableSize > 0 { h.WriteToken(jsontext.String("reset_stream_at")) } else { h.WriteToken(jsontext.String("reset_stream")) } h.WriteToken(jsontext.String("stream_id")) h.WriteToken(jsontext.Int(int64(f.StreamID))) h.WriteToken(jsontext.String("error_code")) 
h.WriteToken(jsontext.Int(int64(f.ErrorCode))) h.WriteToken(jsontext.String("final_size")) h.WriteToken(jsontext.Int(int64(f.FinalSize))) if f.ReliableSize > 0 { h.WriteToken(jsontext.String("reliable_size")) h.WriteToken(jsontext.Int(int64(f.ReliableSize))) } h.WriteToken(jsontext.EndObject) return h.err } func encodeStopSendingFrame(enc *jsontext.Encoder, f *StopSendingFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("stop_sending")) h.WriteToken(jsontext.String("stream_id")) h.WriteToken(jsontext.Int(int64(f.StreamID))) h.WriteToken(jsontext.String("error_code")) h.WriteToken(jsontext.Int(int64(f.ErrorCode))) h.WriteToken(jsontext.EndObject) return h.err } func encodeCryptoFrame(enc *jsontext.Encoder, f *CryptoFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("crypto")) h.WriteToken(jsontext.String("offset")) h.WriteToken(jsontext.Int(f.Offset)) h.WriteToken(jsontext.String("length")) h.WriteToken(jsontext.Int(f.Length)) h.WriteToken(jsontext.EndObject) return h.err } func encodeNewTokenFrame(enc *jsontext.Encoder, f *NewTokenFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("new_token")) h.WriteToken(jsontext.String("token")) if err := (Token{Raw: f.Token}).encode(enc); err != nil { return err } h.WriteToken(jsontext.EndObject) return h.err } func encodeStreamFrame(enc *jsontext.Encoder, f *StreamFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("stream")) h.WriteToken(jsontext.String("stream_id")) h.WriteToken(jsontext.Int(int64(f.StreamID))) h.WriteToken(jsontext.String("offset")) h.WriteToken(jsontext.Int(f.Offset)) h.WriteToken(jsontext.String("length")) 
h.WriteToken(jsontext.Int(f.Length)) if f.Fin { h.WriteToken(jsontext.String("fin")) h.WriteToken(jsontext.True) } h.WriteToken(jsontext.EndObject) return h.err } func encodeMaxDataFrame(enc *jsontext.Encoder, f *MaxDataFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("max_data")) h.WriteToken(jsontext.String("maximum")) h.WriteToken(jsontext.Int(int64(f.MaximumData))) h.WriteToken(jsontext.EndObject) return h.err } func encodeMaxStreamDataFrame(enc *jsontext.Encoder, f *MaxStreamDataFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("max_stream_data")) h.WriteToken(jsontext.String("stream_id")) h.WriteToken(jsontext.Int(int64(f.StreamID))) h.WriteToken(jsontext.String("maximum")) h.WriteToken(jsontext.Int(int64(f.MaximumStreamData))) h.WriteToken(jsontext.EndObject) return h.err } func encodeMaxStreamsFrame(enc *jsontext.Encoder, f *MaxStreamsFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("max_streams")) h.WriteToken(jsontext.String("stream_type")) h.WriteToken(jsontext.String(streamType(f.Type).String())) h.WriteToken(jsontext.String("maximum")) h.WriteToken(jsontext.Int(int64(f.MaxStreamNum))) h.WriteToken(jsontext.EndObject) return h.err } func encodeDataBlockedFrame(enc *jsontext.Encoder, f *DataBlockedFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("data_blocked")) h.WriteToken(jsontext.String("limit")) h.WriteToken(jsontext.Int(int64(f.MaximumData))) h.WriteToken(jsontext.EndObject) return h.err } func encodeStreamDataBlockedFrame(enc *jsontext.Encoder, f *StreamDataBlockedFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) 
h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("stream_data_blocked")) h.WriteToken(jsontext.String("stream_id")) h.WriteToken(jsontext.Int(int64(f.StreamID))) h.WriteToken(jsontext.String("limit")) h.WriteToken(jsontext.Int(int64(f.MaximumStreamData))) h.WriteToken(jsontext.EndObject) return h.err } func encodeStreamsBlockedFrame(enc *jsontext.Encoder, f *StreamsBlockedFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("streams_blocked")) h.WriteToken(jsontext.String("stream_type")) h.WriteToken(jsontext.String(streamType(f.Type).String())) h.WriteToken(jsontext.String("limit")) h.WriteToken(jsontext.Int(int64(f.StreamLimit))) h.WriteToken(jsontext.EndObject) return h.err } func encodeNewConnectionIDFrame(enc *jsontext.Encoder, f *NewConnectionIDFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("new_connection_id")) h.WriteToken(jsontext.String("sequence_number")) h.WriteToken(jsontext.Uint(f.SequenceNumber)) h.WriteToken(jsontext.String("retire_prior_to")) h.WriteToken(jsontext.Uint(f.RetirePriorTo)) h.WriteToken(jsontext.String("length")) h.WriteToken(jsontext.Int(int64(f.ConnectionID.Len()))) h.WriteToken(jsontext.String("connection_id")) h.WriteToken(jsontext.String(f.ConnectionID.String())) h.WriteToken(jsontext.String("stateless_reset_token")) h.WriteToken(jsontext.String(hex.EncodeToString(f.StatelessResetToken[:]))) h.WriteToken(jsontext.EndObject) return h.err } func encodeRetireConnectionIDFrame(enc *jsontext.Encoder, f *RetireConnectionIDFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("retire_connection_id")) h.WriteToken(jsontext.String("sequence_number")) h.WriteToken(jsontext.Uint(f.SequenceNumber)) 
h.WriteToken(jsontext.EndObject) return h.err } func encodePathChallengeFrame(enc *jsontext.Encoder, f *PathChallengeFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("path_challenge")) h.WriteToken(jsontext.String("data")) h.WriteToken(jsontext.String(hex.EncodeToString(f.Data[:]))) h.WriteToken(jsontext.EndObject) return h.err } func encodePathResponseFrame(enc *jsontext.Encoder, f *PathResponseFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("path_response")) h.WriteToken(jsontext.String("data")) h.WriteToken(jsontext.String(hex.EncodeToString(f.Data[:]))) h.WriteToken(jsontext.EndObject) return h.err } func encodeConnectionCloseFrame(enc *jsontext.Encoder, f *ConnectionCloseFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("connection_close")) h.WriteToken(jsontext.String("error_space")) errorSpace := "transport" if f.IsApplicationError { errorSpace = "application" } h.WriteToken(jsontext.String(errorSpace)) errName := transportError(f.ErrorCode).String() if len(errName) > 0 { h.WriteToken(jsontext.String("error_code")) h.WriteToken(jsontext.String(errName)) } else { h.WriteToken(jsontext.String("error_code")) h.WriteToken(jsontext.Uint(f.ErrorCode)) } h.WriteToken(jsontext.String("raw_error_code")) h.WriteToken(jsontext.Uint(f.ErrorCode)) h.WriteToken(jsontext.String("reason")) h.WriteToken(jsontext.String(f.ReasonPhrase)) h.WriteToken(jsontext.EndObject) return h.err } func encodeHandshakeDoneFrame(enc *jsontext.Encoder, _ *HandshakeDoneFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("handshake_done")) h.WriteToken(jsontext.EndObject) return h.err } func 
encodeDatagramFrame(enc *jsontext.Encoder, f *DatagramFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("datagram")) h.WriteToken(jsontext.String("length")) h.WriteToken(jsontext.Int(f.Length)) h.WriteToken(jsontext.EndObject) return h.err } func encodeAckFrequencyFrame(enc *jsontext.Encoder, f *AckFrequencyFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("ack_frequency")) h.WriteToken(jsontext.String("sequence_number")) h.WriteToken(jsontext.Uint(f.SequenceNumber)) h.WriteToken(jsontext.String("ack_eliciting_threshold")) h.WriteToken(jsontext.Uint(f.AckElicitingThreshold)) h.WriteToken(jsontext.String("request_max_ack_delay")) h.WriteToken(jsontext.Float(milliseconds(f.RequestMaxAckDelay))) h.WriteToken(jsontext.String("reordering_threshold")) h.WriteToken(jsontext.Int(int64(f.ReorderingThreshold))) h.WriteToken(jsontext.EndObject) return h.err } func encodeImmediateAckFrame(enc *jsontext.Encoder, _ *ImmediateAckFrame) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("frame_type")) h.WriteToken(jsontext.String("immediate_ack")) h.WriteToken(jsontext.EndObject) return h.err } quic-go-0.59.0/qlog/frame_test.go000066400000000000000000000210071513066070600165770ustar00rootroot00000000000000package qlog import ( "bytes" "encoding/json" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/qlogwriter/jsontext" "github.com/stretchr/testify/require" ) func check(t *testing.T, f any, expected map[string]any) { var buf bytes.Buffer enc := jsontext.NewEncoder(&buf) require.NoError(t, (Frame{Frame: f}).Encode(enc)) data := buf.Bytes() require.True(t, json.Valid(data)) checkEncoding(t, data, expected) } func TestPingFrame(t *testing.T) { 
check(t, &PingFrame{}, map[string]any{"frame_type": "ping"}) } func TestAckFrame(t *testing.T) { tests := []struct { name string frame *AckFrame expected map[string]any }{ { name: "with delay and single packet range", frame: &AckFrame{ DelayTime: 86 * time.Millisecond, AckRanges: []AckRange{{Smallest: 120, Largest: 120}}, }, expected: map[string]any{ "frame_type": "ack", "ack_delay": 86, "acked_ranges": [][]float64{{120}}, }, }, { name: "without delay", frame: &AckFrame{ AckRanges: []AckRange{{Smallest: 120, Largest: 120}}, }, expected: map[string]any{ "frame_type": "ack", "acked_ranges": [][]float64{{120}}, }, }, { name: "with ECN counts", frame: &AckFrame{ AckRanges: []AckRange{{Smallest: 120, Largest: 120}}, ECT0: 10, ECT1: 100, ECNCE: 1000, }, expected: map[string]any{ "frame_type": "ack", "acked_ranges": [][]float64{{120}}, "ect0": 10, "ect1": 100, "ce": 1000, }, }, { name: "with multiple ranges", frame: &AckFrame{ DelayTime: 86 * time.Millisecond, AckRanges: []AckRange{ {Smallest: 5, Largest: 50}, {Smallest: 100, Largest: 120}, }, }, expected: map[string]any{ "frame_type": "ack", "ack_delay": 86, "acked_ranges": [][]float64{ {5, 50}, {100, 120}, }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { check(t, tt.frame, tt.expected) }) } } func TestResetStreamFrame(t *testing.T) { check(t, &ResetStreamFrame{ StreamID: 987, FinalSize: 1234, ErrorCode: 42, }, map[string]any{ "frame_type": "reset_stream", "stream_id": 987, "error_code": 42, "final_size": 1234, }, ) } func TestResetStreamAtFrame(t *testing.T) { check(t, &ResetStreamFrame{ StreamID: 987, FinalSize: 1234, ErrorCode: 42, ReliableSize: 999, }, map[string]any{ "frame_type": "reset_stream_at", "stream_id": 987, "error_code": 42, "final_size": 1234, "reliable_size": 999, }, ) } func TestAckFrequencyFrame(t *testing.T) { check(t, &AckFrequencyFrame{ SequenceNumber: 1337, AckElicitingThreshold: 123, RequestMaxAckDelay: 42 * time.Millisecond, ReorderingThreshold: 1234, }, map[string]any{ 
"frame_type": "ack_frequency", "sequence_number": 1337, "ack_eliciting_threshold": 123, "request_max_ack_delay": 42, "reordering_threshold": 1234, }, ) } func TestImmediateAckFrame(t *testing.T) { check(t, &ImmediateAckFrame{}, map[string]any{ "frame_type": "immediate_ack", }, ) } func TestStopSendingFrame(t *testing.T) { check(t, &StopSendingFrame{StreamID: 987, ErrorCode: 42}, map[string]any{ "frame_type": "stop_sending", "stream_id": 987, "error_code": 42, }, ) } func TestCryptoFrame(t *testing.T) { check(t, &CryptoFrame{Offset: 1337, Length: 6}, map[string]any{ "frame_type": "crypto", "offset": 1337, "length": 6, }, ) } func TestNewTokenFrame(t *testing.T) { check(t, &NewTokenFrame{Token: []byte{0xde, 0xad, 0xbe, 0xef}}, map[string]any{ "frame_type": "new_token", "token": map[string]any{"data": "deadbeef"}, }, ) } func TestStreamFrame(t *testing.T) { tests := []struct { name string frame *StreamFrame expected map[string]any }{ { name: "with FIN", frame: &StreamFrame{ StreamID: 42, Offset: 1337, Fin: true, Length: 9876, }, expected: map[string]any{ "frame_type": "stream", "stream_id": 42, "offset": 1337, "fin": true, "length": 9876, }, }, { name: "without FIN", frame: &StreamFrame{ StreamID: 42, Offset: 1337, Length: 3, }, expected: map[string]any{ "frame_type": "stream", "stream_id": 42, "offset": 1337, "length": 3, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { check(t, tt.frame, tt.expected) }) } } func TestMaxDataFrame(t *testing.T) { check(t, &MaxDataFrame{MaximumData: 1337}, map[string]any{ "frame_type": "max_data", "maximum": 1337, }, ) } func TestMaxStreamDataFrame(t *testing.T) { check(t, &MaxStreamDataFrame{StreamID: 1234, MaximumStreamData: 1337}, map[string]any{ "frame_type": "max_stream_data", "stream_id": 1234, "maximum": 1337, }, ) } func TestMaxStreamsFrame(t *testing.T) { check(t, &MaxStreamsFrame{ Type: protocol.StreamTypeBidi, MaxStreamNum: 42, }, map[string]any{ "frame_type": "max_streams", "stream_type": 
"bidirectional", "maximum": 42, }, ) } func TestDataBlockedFrame(t *testing.T) { check(t, &DataBlockedFrame{MaximumData: 1337}, map[string]any{ "frame_type": "data_blocked", "limit": 1337, }, ) } func TestStreamDataBlockedFrame(t *testing.T) { check(t, &StreamDataBlockedFrame{ StreamID: 42, MaximumStreamData: 1337, }, map[string]any{ "frame_type": "stream_data_blocked", "stream_id": 42, "limit": 1337, }, ) } func TestStreamsBlockedFrame(t *testing.T) { check(t, &StreamsBlockedFrame{ Type: protocol.StreamTypeUni, StreamLimit: 123, }, map[string]any{ "frame_type": "streams_blocked", "stream_type": "unidirectional", "limit": 123, }, ) } func TestNewConnectionIDFrame(t *testing.T) { check(t, &NewConnectionIDFrame{ SequenceNumber: 42, RetirePriorTo: 24, ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}, }, map[string]any{ "frame_type": "new_connection_id", "sequence_number": 42, "retire_prior_to": 24, "length": 4, "connection_id": "deadbeef", "stateless_reset_token": "000102030405060708090a0b0c0d0e0f", }, ) } func TestRetireConnectionIDFrame(t *testing.T) { check(t, &RetireConnectionIDFrame{SequenceNumber: 1337}, map[string]any{ "frame_type": "retire_connection_id", "sequence_number": 1337, }, ) } func TestPathChallengeFrame(t *testing.T) { check(t, &PathChallengeFrame{Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0xc0, 0x01}}, map[string]any{ "frame_type": "path_challenge", "data": "deadbeefcafec001", }, ) } func TestPathResponseFrame(t *testing.T) { check(t, &PathResponseFrame{Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0xc0, 0x01}}, map[string]any{ "frame_type": "path_response", "data": "deadbeefcafec001", }, ) } func TestConnectionCloseFrame(t *testing.T) { tests := []struct { name string frame *ConnectionCloseFrame expected map[string]any }{ { name: "application error code", frame: &ConnectionCloseFrame{ IsApplicationError: 
true, ErrorCode: 1337, ReasonPhrase: "lorem ipsum", }, expected: map[string]any{ "frame_type": "connection_close", "error_space": "application", "error_code": 1337, "raw_error_code": 1337, "reason": "lorem ipsum", }, }, { name: "transport error code", frame: &ConnectionCloseFrame{ ErrorCode: uint64(qerr.FlowControlError), ReasonPhrase: "lorem ipsum", }, expected: map[string]any{ "frame_type": "connection_close", "error_space": "transport", "error_code": "flow_control_error", "raw_error_code": int(qerr.FlowControlError), "reason": "lorem ipsum", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { check(t, tt.frame, tt.expected) }) } } func TestHandshakeDoneFrame(t *testing.T) { check(t, &HandshakeDoneFrame{}, map[string]any{ "frame_type": "handshake_done", }, ) } func TestDatagramFrame(t *testing.T) { check(t, &DatagramFrame{Length: 1337}, map[string]any{ "frame_type": "datagram", "length": 1337, }, ) } quic-go-0.59.0/qlog/json_helper_test.go000066400000000000000000000020001513066070600200050ustar00rootroot00000000000000package qlog import ( "encoding/json" "testing" "github.com/stretchr/testify/require" ) func checkEncoding(t *testing.T, data []byte, expected map[string]any) { t.Helper() m := make(map[string]any) require.NoError(t, json.Unmarshal(data, &m)) require.Len(t, m, len(expected)) for key, value := range expected { switch v := value.(type) { case bool, string, map[string]any: require.Equal(t, v, m[key]) case int: require.Equal(t, float64(v), m[key]) case [][]float64: // used in the ACK frame require.Contains(t, m, key) outerSlice, ok := m[key].([]any) require.True(t, ok) require.Len(t, outerSlice, len(v)) for i, innerExpected := range v { innerSlice, ok := outerSlice[i].([]any) require.True(t, ok) require.Len(t, innerSlice, len(innerExpected)) for j, expectedValue := range innerExpected { v, ok := innerSlice[j].(float64) require.True(t, ok) require.Equal(t, expectedValue, v) } } default: t.Fatalf("unexpected type: %T", v) } } } 
quic-go-0.59.0/qlog/packet_header.go000066400000000000000000000062121513066070600172260ustar00rootroot00000000000000package qlog import ( "encoding/hex" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/qlogwriter/jsontext" ) type Token struct { Raw []byte } func (t Token) encode(enc *jsontext.Encoder) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("data")) h.WriteToken(jsontext.String(hex.EncodeToString(t.Raw))) h.WriteToken(jsontext.EndObject) return h.err } // PacketHeader is a QUIC packet header. type PacketHeader struct { PacketType PacketType KeyPhaseBit KeyPhaseBit PacketNumber PacketNumber Version Version SrcConnectionID ConnectionID DestConnectionID ConnectionID Token *Token } func (h PacketHeader) encode(enc *jsontext.Encoder) error { helper := encoderHelper{enc: enc} helper.WriteToken(jsontext.BeginObject) helper.WriteToken(jsontext.String("packet_type")) helper.WriteToken(jsontext.String(string(h.PacketType))) if h.PacketType != PacketTypeRetry && h.PacketType != PacketTypeVersionNegotiation && h.PacketType != "" && h.PacketNumber != protocol.InvalidPacketNumber { helper.WriteToken(jsontext.String("packet_number")) helper.WriteToken(jsontext.Int(int64(h.PacketNumber))) } if h.Version != 0 { helper.WriteToken(jsontext.String("version")) helper.WriteToken(jsontext.String(version(h.Version).String())) } if h.PacketType != PacketType1RTT { helper.WriteToken(jsontext.String("scil")) helper.WriteToken(jsontext.Int(int64(h.SrcConnectionID.Len()))) if h.SrcConnectionID.Len() > 0 { helper.WriteToken(jsontext.String("scid")) helper.WriteToken(jsontext.String(h.SrcConnectionID.String())) } } helper.WriteToken(jsontext.String("dcil")) helper.WriteToken(jsontext.Int(int64(h.DestConnectionID.Len()))) if h.DestConnectionID.Len() > 0 { helper.WriteToken(jsontext.String("dcid")) helper.WriteToken(jsontext.String(h.DestConnectionID.String())) } if h.KeyPhaseBit == KeyPhaseZero || h.KeyPhaseBit 
== KeyPhaseOne { helper.WriteToken(jsontext.String("key_phase_bit")) helper.WriteToken(jsontext.String(h.KeyPhaseBit.String())) } if h.Token != nil { helper.WriteToken(jsontext.String("token")) if err := h.Token.encode(enc); err != nil { return err } } helper.WriteToken(jsontext.EndObject) return helper.err } type PacketHeaderVersionNegotiation struct { SrcConnectionID ArbitraryLenConnectionID DestConnectionID ArbitraryLenConnectionID } func (h PacketHeaderVersionNegotiation) encode(enc *jsontext.Encoder) error { helper := encoderHelper{enc: enc} helper.WriteToken(jsontext.BeginObject) helper.WriteToken(jsontext.String("packet_type")) helper.WriteToken(jsontext.String("version_negotiation")) helper.WriteToken(jsontext.String("scil")) helper.WriteToken(jsontext.Int(int64(h.SrcConnectionID.Len()))) helper.WriteToken(jsontext.String("scid")) helper.WriteToken(jsontext.String(h.SrcConnectionID.String())) helper.WriteToken(jsontext.String("dcil")) helper.WriteToken(jsontext.Int(int64(h.DestConnectionID.Len()))) helper.WriteToken(jsontext.String("dcid")) helper.WriteToken(jsontext.String(h.DestConnectionID.String())) helper.WriteToken(jsontext.EndObject) return helper.err } quic-go-0.59.0/qlog/packet_header_test.go000066400000000000000000000064121513066070600202670ustar00rootroot00000000000000package qlog import ( "bytes" "encoding/json" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/qlogwriter/jsontext" "github.com/stretchr/testify/require" ) func checkHeader(t *testing.T, hdr *PacketHeader, expected map[string]any) { t.Helper() var buf bytes.Buffer enc := jsontext.NewEncoder(&buf) require.NoError(t, hdr.encode(enc)) data := buf.Bytes() require.True(t, json.Valid(data)) checkEncoding(t, data, expected) } func TestHeaderInitial(t *testing.T) { checkHeader(t, &PacketHeader{ PacketType: PacketTypeInitial, PacketNumber: 42, Version: protocol.Version(0xdecafbad), }, map[string]any{ "packet_type": "initial", "packet_number": 42, "dcil": 
0, "scil": 0, "version": "decafbad", }, ) } func TestHeaderInitialWithToken(t *testing.T) { checkHeader(t, &PacketHeader{ PacketType: PacketTypeInitial, PacketNumber: 1337, SrcConnectionID: protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44}), DestConnectionID: protocol.ParseConnectionID([]byte{0x55, 0x66, 0x77, 0x88}), Version: protocol.Version(0xdecafbad), Token: &Token{Raw: []byte{0xde, 0xad, 0xbe, 0xef}}, }, map[string]any{ "packet_type": "initial", "packet_number": 1337, "dcil": 4, "dcid": "55667788", "scil": 4, "scid": "11223344", "version": "decafbad", "token": map[string]any{"data": "deadbeef"}, }, ) } func TestHeaderLongPacketNumbers(t *testing.T) { t.Run("packet 0", func(t *testing.T) { testHeaderPacketNumbers(t, 0) }) // This is used for events where the packet number is not yet known, // e.g. the packet_buffered event. t.Run("no packet number", func(t *testing.T) { testHeaderPacketNumbers(t, 1) }) } func testHeaderPacketNumbers(t *testing.T, pn protocol.PacketNumber) { expected := map[string]any{ "packet_type": "handshake", "dcil": 0, "scil": 0, "version": "1", } if pn != protocol.InvalidPacketNumber { expected["packet_number"] = int(pn) } checkHeader(t, &PacketHeader{ PacketType: PacketTypeHandshake, PacketNumber: pn, Version: protocol.Version1, }, expected, ) } func TestHeaderRetry(t *testing.T) { checkHeader(t, &PacketHeader{ PacketType: PacketTypeRetry, SrcConnectionID: protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44}), DestConnectionID: protocol.ParseConnectionID([]byte{0x55, 0x66, 0x77, 0x88, 0x99}), Version: protocol.Version(0xdecafbad), Token: &Token{Raw: []byte{0xde, 0xad, 0xbe, 0xef}}, }, map[string]any{ "packet_type": "retry", "dcil": 5, "dcid": "5566778899", "scil": 4, "scid": "11223344", "token": map[string]any{"data": "deadbeef"}, "version": "decafbad", }, ) } func TestHeader1RTT(t *testing.T) { checkHeader(t, &PacketHeader{ PacketType: PacketType1RTT, PacketNumber: 42, DestConnectionID: 
protocol.ParseConnectionID([]byte{0x55, 0x66, 0x77, 0x88}), KeyPhaseBit: KeyPhaseZero, }, map[string]any{ "packet_type": "1RTT", "packet_number": 42, "dcil": 4, "dcid": "55667788", "key_phase_bit": "0", }, ) } quic-go-0.59.0/qlog/qlog_dir.go000066400000000000000000000034331513066070600162510ustar00rootroot00000000000000package qlog import ( "bufio" "context" "fmt" "log" "os" "slices" "strings" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/qlogwriter" ) // EventSchema is the qlog event schema for QUIC const EventSchema = "urn:ietf:params:qlog:events:quic-12" // DefaultConnectionTracer creates a qlog file in the qlog directory specified by the QLOGDIR environment variable. // File names are _.sqlog. // Returns nil if QLOGDIR is not set. func DefaultConnectionTracer(_ context.Context, isClient bool, connID ConnectionID) qlogwriter.Trace { return defaultConnectionTracerWithSchemas(isClient, connID, []string{EventSchema}) } func DefaultConnectionTracerWithSchemas(_ context.Context, isClient bool, connID ConnectionID, eventSchemas []string) qlogwriter.Trace { if !slices.Contains(eventSchemas, EventSchema) { eventSchemas = append([]string{EventSchema}, eventSchemas...) 
} return defaultConnectionTracerWithSchemas(isClient, connID, eventSchemas) } func defaultConnectionTracerWithSchemas(isClient bool, connID ConnectionID, eventSchemas []string) qlogwriter.Trace { qlogDir := os.Getenv("QLOGDIR") if qlogDir == "" { return nil } if _, err := os.Stat(qlogDir); os.IsNotExist(err) { if err := os.MkdirAll(qlogDir, 0o755); err != nil { log.Fatalf("failed to create qlog dir %s: %v", qlogDir, err) } } label := "server" if isClient { label = "client" } path := fmt.Sprintf("%s/%s_%s.sqlog", strings.TrimRight(qlogDir, "/"), connID, label) f, err := os.Create(path) if err != nil { log.Printf("Failed to create qlog file %s: %s", path, err.Error()) return nil } fileSeq := qlogwriter.NewConnectionFileSeq( utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), isClient, connID, eventSchemas, ) go fileSeq.Run() return fileSeq } quic-go-0.59.0/qlog/qlog_dir_test.go000066400000000000000000000041411513066070600173050ustar00rootroot00000000000000package qlog import ( "context" "encoding/json" "os" "path/filepath" "strings" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/qlogwriter" "github.com/stretchr/testify/require" ) func TestQLOGDIRSet(t *testing.T) { tmpDir := t.TempDir() connID, _ := protocol.GenerateConnectionIDForInitial() qlogDir := filepath.Join(tmpDir, "qlogs") t.Setenv("QLOGDIR", qlogDir) t.Run("default connection tracer", func(t *testing.T) { tracer := DefaultConnectionTracer(context.Background(), true, connID) testQLOGDIRSet(t, qlogDir, tracer, []string{EventSchema}) }) t.Run("default connection tracer with schemas", func(t *testing.T) { tracer := DefaultConnectionTracerWithSchemas(context.Background(), true, connID, []string{"urn:ietf:params:qlog:events:foobar"}) testQLOGDIRSet(t, qlogDir, tracer, []string{EventSchema, "urn:ietf:params:qlog:events:foobar"}) }) } func testQLOGDIRSet(t *testing.T, qlogDir string, tracer qlogwriter.Trace, expectedEventSchemas []string) { require.NotNil(t, tracer) // adding 
and closing a producer makes the tracer close the file recorder := tracer.AddProducer() recorder.Close() _, err := os.Stat(qlogDir) qlogDirCreated := !os.IsNotExist(err) require.True(t, qlogDirCreated) entries, err := os.ReadDir(qlogDir) require.NoError(t, err) require.Len(t, entries, 1) data, err := os.ReadFile(filepath.Join(qlogDir, entries[0].Name())) require.NoError(t, err) var obj map[string]any require.NoError(t, json.Unmarshal([]byte(strings.Split(string(data), "\n")[0])[1:], &obj)) require.Contains(t, obj, "trace") require.IsType(t, obj["trace"], map[string]any{}) require.Contains(t, obj["trace"], "event_schemas") var eventSchemas []string for _, v := range obj["trace"].(map[string]any)["event_schemas"].([]any) { eventSchemas = append(eventSchemas, v.(string)) } require.Equal(t, eventSchemas, expectedEventSchemas) } func TestQLOGDIRNotSet(t *testing.T) { connID, _ := protocol.GenerateConnectionIDForInitial() t.Setenv("QLOGDIR", "") tracer := DefaultConnectionTracer(context.Background(), true, connID) require.Nil(t, tracer) } quic-go-0.59.0/qlog/types.go000066400000000000000000000246551513066070600156260ustar00rootroot00000000000000package qlog import ( "fmt" "hash/crc32" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" ) type ( ConnectionID = protocol.ConnectionID ArbitraryLenConnectionID = protocol.ArbitraryLenConnectionID Version = protocol.Version PacketNumber = protocol.PacketNumber EncryptionLevel = protocol.EncryptionLevel KeyPhaseBit = protocol.KeyPhaseBit KeyPhase = protocol.KeyPhase StreamID = protocol.StreamID TransportErrorCode = qerr.TransportErrorCode ApplicationErrorCode = qerr.ApplicationErrorCode ) const ( // KeyPhaseZero is key phase bit 0 KeyPhaseZero = protocol.KeyPhaseZero // KeyPhaseOne is key phase bit 1 KeyPhaseOne = protocol.KeyPhaseOne ) // ECN represents the Explicit Congestion Notification value. 
type ECN string const ( // ECNUnsupported means that no ECN value was set / received ECNUnsupported ECN = "" // ECTNot is Not-ECT ECTNot ECN = "Not-ECT" // ECT0 is ECT(0) ECT0 ECN = "ECT(0)" // ECT1 is ECT(1) ECT1 ECN = "ECT(1)" // ECNCE is CE ECNCE ECN = "CE" ) type Initiator string const ( InitiatorLocal Initiator = "local" InitiatorRemote Initiator = "remote" ) type streamType protocol.StreamType func (s streamType) String() string { switch protocol.StreamType(s) { case protocol.StreamTypeUni: return "unidirectional" case protocol.StreamTypeBidi: return "bidirectional" default: return "unknown stream type" } } type version protocol.Version func (v version) String() string { return fmt.Sprintf("%x", uint32(v)) } func encLevelToPacketNumberSpace(encLevel protocol.EncryptionLevel) string { switch encLevel { case protocol.EncryptionInitial: return "initial" case protocol.EncryptionHandshake: return "handshake" case protocol.Encryption0RTT, protocol.Encryption1RTT: return "application_data" default: return "unknown encryption level" } } // KeyType represents the type of cryptographic key used in QUIC connections. type KeyType string const ( // KeyTypeServerInitial represents the server's initial secret key. KeyTypeServerInitial KeyType = "server_initial_secret" // KeyTypeClientInitial represents the client's initial secret key. KeyTypeClientInitial KeyType = "client_initial_secret" // KeyTypeServerHandshake represents the server's handshake secret key. KeyTypeServerHandshake KeyType = "server_handshake_secret" // KeyTypeClientHandshake represents the client's handshake secret key. KeyTypeClientHandshake KeyType = "client_handshake_secret" // KeyTypeServer0RTT represents the server's 0-RTT secret key. KeyTypeServer0RTT KeyType = "server_0rtt_secret" // KeyTypeClient0RTT represents the client's 0-RTT secret key. KeyTypeClient0RTT KeyType = "client_0rtt_secret" // KeyTypeServer1RTT represents the server's 1-RTT secret key. 
KeyTypeServer1RTT KeyType = "server_1rtt_secret" // KeyTypeClient1RTT represents the client's 1-RTT secret key. KeyTypeClient1RTT KeyType = "client_1rtt_secret" ) // KeyUpdateTrigger describes what caused a key update event. type KeyUpdateTrigger string const ( // KeyUpdateTLS indicates the key update was triggered by TLS. KeyUpdateTLS KeyUpdateTrigger = "tls" // KeyUpdateRemote indicates the key update was triggered by the remote peer. KeyUpdateRemote KeyUpdateTrigger = "remote_update" // KeyUpdateLocal indicates the key update was triggered locally. KeyUpdateLocal KeyUpdateTrigger = "local_update" ) type transportError uint64 func (e transportError) String() string { switch qerr.TransportErrorCode(e) { case qerr.NoError: return "no_error" case qerr.InternalError: return "internal_error" case qerr.ConnectionRefused: return "connection_refused" case qerr.FlowControlError: return "flow_control_error" case qerr.StreamLimitError: return "stream_limit_error" case qerr.StreamStateError: return "stream_state_error" case qerr.FinalSizeError: return "final_size_error" case qerr.FrameEncodingError: return "frame_encoding_error" case qerr.TransportParameterError: return "transport_parameter_error" case qerr.ConnectionIDLimitError: return "connection_id_limit_error" case qerr.ProtocolViolation: return "protocol_violation" case qerr.InvalidToken: return "invalid_token" case qerr.ApplicationErrorErrorCode: return "application_error" case qerr.CryptoBufferExceeded: return "crypto_buffer_exceeded" case qerr.KeyUpdateError: return "key_update_error" case qerr.AEADLimitReached: return "aead_limit_reached" case qerr.NoViablePathError: return "no_viable_path" default: return "" } } type PacketType string const ( // PacketTypeInitial represents an Initial packet PacketTypeInitial PacketType = "initial" // PacketTypeHandshake represents a Handshake packet PacketTypeHandshake PacketType = "handshake" // PacketTypeRetry represents a Retry packet PacketTypeRetry PacketType = "retry" // 
PacketType0RTT represents a 0-RTT packet PacketType0RTT PacketType = "0RTT" // PacketTypeVersionNegotiation represents a Version Negotiation packet PacketTypeVersionNegotiation PacketType = "version_negotiation" // PacketTypeStatelessReset represents a Stateless Reset packet PacketTypeStatelessReset PacketType = "stateless_reset" // PacketType1RTT represents a 1-RTT packet PacketType1RTT PacketType = "1RTT" // // PacketTypeNotDetermined represents a packet type that could not be determined // PacketTypeNotDetermined packetType = "" ) func EncryptionLevelToPacketType(l EncryptionLevel) PacketType { switch l { case protocol.EncryptionInitial: return PacketTypeInitial case protocol.EncryptionHandshake: return PacketTypeHandshake case protocol.Encryption0RTT: return PacketType0RTT case protocol.Encryption1RTT: return PacketType1RTT default: panic(fmt.Sprintf("unknown encryption level: %d", l)) } } type PacketLossReason string const ( // PacketLossReorderingThreshold is used when a packet is declared lost due to reordering threshold PacketLossReorderingThreshold PacketLossReason = "reordering_threshold" // PacketLossTimeThreshold is used when a packet is declared lost due to time threshold PacketLossTimeThreshold PacketLossReason = "time_threshold" ) type PacketDropReason string const ( // PacketDropKeyUnavailable is used when a packet is dropped because keys are unavailable PacketDropKeyUnavailable PacketDropReason = "key_unavailable" // PacketDropUnknownConnectionID is used when a packet is dropped because the connection ID is unknown PacketDropUnknownConnectionID PacketDropReason = "unknown_connection_id" // PacketDropHeaderParseError is used when a packet is dropped because header parsing failed PacketDropHeaderParseError PacketDropReason = "header_parse_error" // PacketDropPayloadDecryptError is used when a packet is dropped because decrypting the payload failed PacketDropPayloadDecryptError PacketDropReason = "payload_decrypt_error" // PacketDropProtocolViolation 
is used when a packet is dropped due to a protocol violation PacketDropProtocolViolation PacketDropReason = "protocol_violation" // PacketDropDOSPrevention is used when a packet is dropped to mitigate a DoS attack PacketDropDOSPrevention PacketDropReason = "dos_prevention" // PacketDropUnsupportedVersion is used when a packet is dropped because the version is not supported PacketDropUnsupportedVersion PacketDropReason = "unsupported_version" // PacketDropUnexpectedPacket is used when an unexpected packet is received PacketDropUnexpectedPacket PacketDropReason = "unexpected_packet" // PacketDropUnexpectedSourceConnectionID is used when a packet with an unexpected source connection ID is received PacketDropUnexpectedSourceConnectionID PacketDropReason = "unexpected_source_connection_id" // PacketDropUnexpectedVersion is used when a packet with an unexpected version is received PacketDropUnexpectedVersion PacketDropReason = "unexpected_version" // PacketDropDuplicate is used when a duplicate packet is received PacketDropDuplicate PacketDropReason = "duplicate" ) type LossTimerUpdateType string const ( LossTimerUpdateTypeSet LossTimerUpdateType = "set" LossTimerUpdateTypeExpired LossTimerUpdateType = "expired" LossTimerUpdateTypeCancelled LossTimerUpdateType = "cancelled" ) type TimerType string const ( // TimerTypeACK represents an ACK timer TimerTypeACK TimerType = "ack" // TimerTypePTO represents a PTO (Probe Timeout) timer TimerTypePTO TimerType = "pto" // TimerTypePathProbe represents a path probe timer TimerTypePathProbe TimerType = "path_probe" ) type CongestionState string const ( // CongestionStateSlowStart is the slow start phase of Reno / Cubic CongestionStateSlowStart CongestionState = "slow_start" // CongestionStateCongestionAvoidance is the congestion avoidance phase of Reno / Cubic CongestionStateCongestionAvoidance CongestionState = "congestion_avoidance" // CongestionStateRecovery is the recovery phase of Reno / Cubic CongestionStateRecovery 
CongestionState = "recovery" // CongestionStateApplicationLimited means that the congestion controller is application limited CongestionStateApplicationLimited CongestionState = "application_limited" ) func (s CongestionState) String() string { return string(s) } // ECNState is the state of the ECN state machine (see Appendix A.4 of RFC 9000) type ECNState string const ( // ECNStateTesting is the testing state ECNStateTesting ECNState = "testing" // ECNStateUnknown is the unknown state ECNStateUnknown ECNState = "unknown" // ECNStateFailed is the failed state ECNStateFailed ECNState = "failed" // ECNStateCapable is the capable state ECNStateCapable ECNState = "capable" ) type ConnectionCloseTrigger string const ( // IdleTimeout indicates the connection was closed due to idle timeout ConnectionCloseTriggerIdleTimeout ConnectionCloseTrigger = "idle_timeout" // Application indicates the connection was closed by the application ConnectionCloseTriggerApplication ConnectionCloseTrigger = "application" // VersionMismatch indicates the connection was closed due to a QUIC version mismatch ConnectionCloseTriggerVersionMismatch ConnectionCloseTrigger = "version_mismatch" // StatelessReset indicates the connection was closed due to receiving a stateless reset from the peer ConnectionCloseTriggerStatelessReset ConnectionCloseTrigger = "stateless_reset" ) // DatagramID is a unique identifier for a datagram type DatagramID uint32 // CalculateDatagramID computes a DatagramID for a given packet func CalculateDatagramID(packet []byte) DatagramID { return DatagramID(crc32.ChecksumIEEE(packet)) } quic-go-0.59.0/qlog/types_test.go000066400000000000000000000012511513066070600166500ustar00rootroot00000000000000package qlog import ( "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestEncryptionLevelToPacketType(t *testing.T) { require.Equal(t, "initial", string(EncryptionLevelToPacketType(protocol.EncryptionInitial))) require.Equal(t, 
"handshake", string(EncryptionLevelToPacketType(protocol.EncryptionHandshake))) require.Equal(t, "0RTT", string(EncryptionLevelToPacketType(protocol.Encryption0RTT))) require.Equal(t, "1RTT", string(EncryptionLevelToPacketType(protocol.Encryption1RTT))) } func TestCalculateDatagramID(t *testing.T) { require.Equal(t, DatagramID(0xcbf43926), CalculateDatagramID([]byte("123456789"))) } quic-go-0.59.0/qlogwriter/000077500000000000000000000000001513066070600153545ustar00rootroot00000000000000quic-go-0.59.0/qlogwriter/jsontext/000077500000000000000000000000001513066070600172325ustar00rootroot00000000000000quic-go-0.59.0/qlogwriter/jsontext/encoder.go000066400000000000000000000161101513066070600211770ustar00rootroot00000000000000// Package jsontext provides a fast JSON encoder providing only the necessary features // for qlog encoding. No efforts are made to add any features beyond qlog's requirements. // // The API aims to be compatible with the standard library's encoding/json/jsontext package. package jsontext import ( "fmt" "io" "strconv" "unsafe" ) type kind uint8 const ( kindString kind = iota kindInt kindUint kindFloat kindBool kindNull kindObjectStart kindObjectEnd kindArrayStart kindArrayEnd ) // Token represents a JSON token. type Token struct { kind kind str string i64 int64 u64 uint64 f64 float64 b bool } // String creates a string token. func String(s string) Token { return Token{kind: kindString, str: s} } // Int creates an int token. func Int(i int64) Token { return Token{kind: kindInt, i64: i} } // Uint creates a uint token. func Uint(u uint64) Token { return Token{kind: kindUint, u64: u} } // Float creates a float token. func Float(f float64) Token { return Token{kind: kindFloat, f64: f} } // Bool creates a bool token. func Bool(b bool) Token { return Token{kind: kindBool, b: b} } // Null is a null token. var Null Token = Token{kind: kindNull} // BeginObject is the begin object token. 
var BeginObject Token = Token{kind: kindObjectStart} // EndObject is the end object token. var EndObject Token = Token{kind: kindObjectEnd} // BeginArray is the begin array token. var BeginArray Token = Token{kind: kindArrayStart} // EndArray is the end array token. var EndArray Token = Token{kind: kindArrayEnd} // True is a true token. var True Token = Bool(true) // False is a false token. var False Token = Bool(false) var hexDigits = [16]byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'} var ( commaByte = []byte(",") quoteByte = []byte(`"`) colonByte = []byte(":") trueByte = []byte("true") falseByte = []byte("false") nullByte = []byte("null") openObjectByte = []byte("{") closeObjectByte = []byte("}") openArrayByte = []byte("[") closeArrayByte = []byte("]") newlineByte = []byte("\n") escapeQuote = []byte(`\"`) escapeBackslash = []byte(`\\`) escapeBackspace = []byte(`\b`) escapeFormfeed = []byte(`\f`) escapeNewline = []byte(`\n`) escapeCarriage = []byte(`\r`) escapeTab = []byte(`\t`) escapeUnicode = []byte(`\u00`) ) type context struct { isObject bool needsComma bool expectKey bool } // Encoder encodes JSON to an io.Writer. type Encoder struct { w io.Writer buf [64]byte // scratch buffer for number formatting stack []context } // NewEncoder creates a new Encoder. func NewEncoder(w io.Writer) *Encoder { stack := make([]context, 0, 8) stack = append(stack, context{isObject: false, needsComma: false, expectKey: false}) return &Encoder{ w: w, stack: stack, } } // WriteToken writes a token to the encoder. 
func (e *Encoder) WriteToken(t Token) error { if len(e.stack) == 0 { return fmt.Errorf("empty stack") } curr := &e.stack[len(e.stack)-1] isClosing := t.kind == kindObjectEnd || t.kind == kindArrayEnd if !isClosing && curr.needsComma { if _, err := e.w.Write(commaByte); err != nil { return err } curr.needsComma = false } var err error switch t.kind { case kindString: data := stringToBytes(t.str) needsEscape := false for _, b := range data { if b == '"' || b == '\\' || b < 0x20 { needsEscape = true break } } if !needsEscape { if _, err = e.w.Write(quoteByte); err != nil { return err } if _, err = e.w.Write(data); err != nil { return err } if _, err = e.w.Write(quoteByte); err != nil { return err } } else { if _, err = e.w.Write(quoteByte); err != nil { return err } for i := 0; i < len(t.str); i++ { c := t.str[i] switch c { case '"': if _, err = e.w.Write(escapeQuote); err != nil { return err } case '\\': if _, err = e.w.Write(escapeBackslash); err != nil { return err } case '\b': if _, err = e.w.Write(escapeBackspace); err != nil { return err } case '\f': if _, err = e.w.Write(escapeFormfeed); err != nil { return err } case '\n': if _, err = e.w.Write(escapeNewline); err != nil { return err } case '\r': if _, err = e.w.Write(escapeCarriage); err != nil { return err } case '\t': if _, err = e.w.Write(escapeTab); err != nil { return err } default: if c < 0x20 { if _, err = e.w.Write(escapeUnicode); err != nil { return err } if _, err = e.w.Write([]byte{hexDigits[c>>4], hexDigits[c&0xf]}); err != nil { return err } } else { if _, err = e.w.Write([]byte{c}); err != nil { return err } } } } if _, err = e.w.Write(quoteByte); err != nil { return err } } if curr.isObject { if curr.expectKey { // key if _, err = e.w.Write(colonByte); err != nil { return err } curr.expectKey = false return nil // do not call afterValue for keys } else { // value e.afterValue() } } else { e.afterValue() } case kindInt: b := strconv.AppendInt(e.buf[:0], t.i64, 10) if _, err = e.w.Write(b); err 
!= nil { return err } e.afterValue() case kindUint: b := strconv.AppendUint(e.buf[:0], t.u64, 10) if _, err = e.w.Write(b); err != nil { return err } e.afterValue() case kindFloat: b := strconv.AppendFloat(e.buf[:0], t.f64, 'g', -1, 64) if _, err = e.w.Write(b); err != nil { return err } e.afterValue() case kindBool: if t.b { if _, err = e.w.Write(trueByte); err != nil { return err } } else { if _, err = e.w.Write(falseByte); err != nil { return err } } e.afterValue() case kindNull: if _, err = e.w.Write(nullByte); err != nil { return err } e.afterValue() case kindObjectStart: if _, err = e.w.Write(openObjectByte); err != nil { return err } e.stack = append(e.stack, context{isObject: true, needsComma: false, expectKey: true}) return nil case kindObjectEnd: if _, err = e.w.Write(closeObjectByte); err != nil { return err } e.stack = e.stack[:len(e.stack)-1] e.afterValue() if len(e.stack) == 1 { if _, err = e.w.Write(newlineByte); err != nil { return err } } return nil case kindArrayStart: if _, err = e.w.Write(openArrayByte); err != nil { return err } e.stack = append(e.stack, context{isObject: false, needsComma: false, expectKey: false}) return nil case kindArrayEnd: if _, err = e.w.Write(closeArrayByte); err != nil { return err } e.stack = e.stack[:len(e.stack)-1] e.afterValue() if len(e.stack) == 1 { if _, err = e.w.Write(newlineByte); err != nil { return err } } return nil default: return fmt.Errorf("unknown token kind") } return err } // afterValue updates the state after encoding a value func (e *Encoder) afterValue() { if len(e.stack) > 1 { curr := &e.stack[len(e.stack)-1] curr.needsComma = true if curr.isObject { curr.expectKey = true } } } func stringToBytes(s string) []byte { return unsafe.Slice(unsafe.StringData(s), len(s)) } quic-go-0.59.0/qlogwriter/jsontext/encoder_test.go000066400000000000000000000252531513066070600222460ustar00rootroot00000000000000package jsontext_test import ( "bytes" "encoding/json" "testing" 
"github.com/quic-go/quic-go/qlogwriter/jsontext" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestEncoderSimpleObject(t *testing.T) { buf := bytes.NewBuffer(nil) enc := jsontext.NewEncoder(buf) enc.WriteToken(jsontext.BeginObject) enc.WriteToken(jsontext.String("foo")) enc.WriteToken(jsontext.String("bar")) enc.WriteToken(jsontext.String("foo2")) enc.WriteToken(jsontext.String("bar2")) enc.WriteToken(jsontext.EndObject) output := buf.String() var got map[string]string require.NoError(t, json.Unmarshal([]byte(output), &got)) require.Equal(t, map[string]string{"foo": "bar", "foo2": "bar2"}, got) } func TestEncoderArrayInts(t *testing.T) { buf := bytes.NewBuffer(nil) enc := jsontext.NewEncoder(buf) enc.WriteToken(jsontext.BeginArray) enc.WriteToken(jsontext.Int(1)) enc.WriteToken(jsontext.Int(2)) enc.WriteToken(jsontext.Int(3)) enc.WriteToken(jsontext.EndArray) output := buf.String() var got []int require.NoError(t, json.Unmarshal([]byte(output), &got)) require.Equal(t, []int{1, 2, 3}, got) } func TestEncoderArrayStrings(t *testing.T) { buf := bytes.NewBuffer(nil) enc := jsontext.NewEncoder(buf) enc.WriteToken(jsontext.BeginArray) enc.WriteToken(jsontext.String("one")) enc.WriteToken(jsontext.String("two")) enc.WriteToken(jsontext.EndArray) output := buf.String() var got []string err := json.Unmarshal([]byte(output), &got) require.NoError(t, err) require.Equal(t, []string{"one", "two"}, got) } func TestEncoderNestedObject(t *testing.T) { buf := bytes.NewBuffer(nil) enc := jsontext.NewEncoder(buf) enc.WriteToken(jsontext.BeginObject) enc.WriteToken(jsontext.String("outer")) enc.WriteToken(jsontext.BeginObject) enc.WriteToken(jsontext.String("inner")) enc.WriteToken(jsontext.String("value")) enc.WriteToken(jsontext.EndObject) enc.WriteToken(jsontext.EndObject) output := buf.String() var got map[string]map[string]string require.NoError(t, json.Unmarshal([]byte(output), &got)) require.Equal(t, map[string]map[string]string{"outer": 
{"inner": "value"}}, got) } func TestEncoderNumbersAndBool(t *testing.T) { buf := bytes.NewBuffer(nil) enc := jsontext.NewEncoder(buf) enc.WriteToken(jsontext.BeginObject) enc.WriteToken(jsontext.String("int")) enc.WriteToken(jsontext.Int(42)) enc.WriteToken(jsontext.String("uint")) enc.WriteToken(jsontext.Uint(100)) enc.WriteToken(jsontext.String("float")) enc.WriteToken(jsontext.Float(3.14)) enc.WriteToken(jsontext.String("true")) enc.WriteToken(jsontext.True) enc.WriteToken(jsontext.String("false")) enc.WriteToken(jsontext.False) enc.WriteToken(jsontext.String("nullv")) enc.WriteToken(jsontext.Null) enc.WriteToken(jsontext.EndObject) output := buf.String() var got map[string]any require.NoError(t, json.Unmarshal([]byte(output), &got)) require.Equal(t, map[string]any{ "int": float64(42), // json.Unmarshal decodes numbers as float64 "uint": float64(100), "float": 3.14, "true": true, "false": false, "nullv": nil, }, got) } func TestEncoderEmptyObject(t *testing.T) { buf := bytes.NewBuffer(nil) enc := jsontext.NewEncoder(buf) enc.WriteToken(jsontext.BeginObject) enc.WriteToken(jsontext.EndObject) output := buf.String() var got map[string]any require.NoError(t, json.Unmarshal([]byte(output), &got)) require.Equal(t, map[string]any{}, got) } func TestEncoderEmptyArray(t *testing.T) { buf := bytes.NewBuffer(nil) enc := jsontext.NewEncoder(buf) enc.WriteToken(jsontext.BeginArray) enc.WriteToken(jsontext.EndArray) output := buf.String() var got []any require.NoError(t, json.Unmarshal([]byte(output), &got)) require.Equal(t, []any{}, got) } func TestEncoderArrayWithNulls(t *testing.T) { buf := bytes.NewBuffer(nil) enc := jsontext.NewEncoder(buf) enc.WriteToken(jsontext.BeginArray) enc.WriteToken(jsontext.Null) enc.WriteToken(jsontext.String("x")) enc.WriteToken(jsontext.Null) enc.WriteToken(jsontext.EndArray) output := buf.String() var got []any require.NoError(t, json.Unmarshal([]byte(output), &got)) require.Equal(t, []any{nil, "x", nil}, got) } func 
TestEncoderEscapedStrings(t *testing.T) { t.Run("no escapes", func(t *testing.T) { testEncoderEscapedStrings(t, "simplekey", "simplevalue") }) t.Run("basic escapes", func(t *testing.T) { key := `key"\/` value := `value"\/` testEncoderEscapedStrings(t, key, value) }) t.Run("control characters", func(t *testing.T) { key := "key\b\f\n\r\t" value := "value\b\f\n\r\t" testEncoderEscapedStrings(t, key, value) }) t.Run("unicode low", func(t *testing.T) { key := "key\u0007\u001f" value := "value\u0007\u001f" testEncoderEscapedStrings(t, key, value) }) t.Run("mixed all", func(t *testing.T) { key := `key"\\\/\b\f\n\r\t\u0007\u001f` value := `value"\\\/\b\f\n\r\t\u0007\u001f` testEncoderEscapedStrings(t, key, value) }) } func testEncoderEscapedStrings(t *testing.T, key, value string) { buf := bytes.NewBuffer(nil) enc := jsontext.NewEncoder(buf) enc.WriteToken(jsontext.BeginObject) enc.WriteToken(jsontext.String(key)) enc.WriteToken(jsontext.String(value)) enc.WriteToken(jsontext.EndObject) output := buf.String() var got map[string]string err := json.Unmarshal([]byte(output), &got) require.NoError(t, err) expected := map[string]string{key: value} require.Equal(t, expected, got) } func encodeValue(t testing.TB, enc *jsontext.Encoder, v any) (isSupported bool) { t.Helper() switch val := v.(type) { case map[string]any: require.NoError(t, enc.WriteToken(jsontext.BeginObject)) for k, vv := range val { require.NoError(t, enc.WriteToken(jsontext.String(k))) if !encodeValue(t, enc, vv) { return false } } require.NoError(t, enc.WriteToken(jsontext.EndObject)) return true case []any: require.NoError(t, enc.WriteToken(jsontext.BeginArray)) for _, vv := range val { if !encodeValue(t, enc, vv) { return false // Propagate unsupported if any nested value fails } } require.NoError(t, enc.WriteToken(jsontext.EndArray)) return true case string: require.NoError(t, enc.WriteToken(jsontext.String(val))) return true case int64: require.NoError(t, enc.WriteToken(jsontext.Int(val))) return true case 
uint64: require.NoError(t, enc.WriteToken(jsontext.Uint(val))) return true case float64: require.NoError(t, enc.WriteToken(jsontext.Float(val))) return true case bool: require.NoError(t, enc.WriteToken(jsontext.Bool(val))) return true case nil: require.NoError(t, enc.WriteToken(jsontext.Null)) return true default: return false } } type errorWriter struct { N int } func (w *errorWriter) Write(p []byte) (int, error) { n := min(len(p), w.N) w.N -= n if w.N <= 0 { return n, assert.AnError } return n, nil } func TestEncoderComprehensive(t *testing.T) { // encodes an object with all token types and nested structures encode := func(enc *jsontext.Encoder) error { if err := enc.WriteToken(jsontext.BeginObject); err != nil { return err } if err := enc.WriteToken(jsontext.String("simple")); err != nil { return err } if err := enc.WriteToken(jsontext.String("value")); err != nil { return err } if err := enc.WriteToken(jsontext.String("escaped")); err != nil { return err } if err := enc.WriteToken(jsontext.String(`"quoted\"string"`)); err != nil { return err } if err := enc.WriteToken(jsontext.String("int")); err != nil { return err } if err := enc.WriteToken(jsontext.Int(-42)); err != nil { return err } if err := enc.WriteToken(jsontext.String("uint")); err != nil { return err } if err := enc.WriteToken(jsontext.Uint(100)); err != nil { return err } if err := enc.WriteToken(jsontext.String("float")); err != nil { return err } if err := enc.WriteToken(jsontext.Float(3.14)); err != nil { return err } if err := enc.WriteToken(jsontext.String("true")); err != nil { return err } if err := enc.WriteToken(jsontext.True); err != nil { return err } if err := enc.WriteToken(jsontext.String("false")); err != nil { return err } if err := enc.WriteToken(jsontext.False); err != nil { return err } if err := enc.WriteToken(jsontext.String("array")); err != nil { return err } if err := enc.WriteToken(jsontext.BeginArray); err != nil { return err } if err := 
enc.WriteToken(jsontext.String("item1")); err != nil { return err } if err := enc.WriteToken(jsontext.Int(1)); err != nil { return err } if err := enc.WriteToken(jsontext.EndArray); err != nil { return err } if err := enc.WriteToken(jsontext.String("nested")); err != nil { return err } if err := enc.WriteToken(jsontext.BeginObject); err != nil { return err } if err := enc.WriteToken(jsontext.EndObject); err != nil { return err } if err := enc.WriteToken(jsontext.EndObject); err != nil { return err } return nil } buf := bytes.NewBuffer(nil) enc := jsontext.NewEncoder(buf) require.NoError(t, encode(enc)) for i := range buf.Len() { enc := jsontext.NewEncoder(&errorWriter{N: i}) require.ErrorIs(t, encode(enc), assert.AnError) } } func FuzzEncoder(f *testing.F) { examples := []string{ `{"hello": "world"}`, `{"foo": 123, "bar": [1, 2, 3]}`, `{"nested": {"a": 1, "b": [true, false, "foobar", null]}}`, `[{"x": 1}, {"y": "foo"}]`, `["foo", "bar"]`, `["a", {"b": [1, 2, {"c": "d"}]}, 3]`, `{"emptyObj": {}, "emptyArr": []}`, `{"mixed": [1, "two", {"three": 3}]}`, `[null]`, } for _, tc := range examples { // first test that // 1. it's valid JSON d := json.NewDecoder(bytes.NewReader([]byte(tc))) var expected any require.NoError(f, d.Decode(&expected), "corpus entry `%s` is not valid JSON", tc) // 2. 
the jsontext encoder can handle enc := jsontext.NewEncoder(&bytes.Buffer{}) require.True(f, encodeValue(f, enc, expected), "expected `%s` to be supported", tc) f.Add([]byte(tc)) } var stdlibBuf, ourBuf bytes.Buffer f.Fuzz(func(t *testing.T, b []byte) { stdlibBuf.Truncate(0) ourBuf.Truncate(0) stdlibBuf.Grow(len(b)) ourBuf.Grow(len(b)) d := json.NewDecoder(bytes.NewReader(b)) var expected any if err := d.Decode(&expected); err != nil { return // invalid JSON } // only attempt to handle inputs that the standard library can handle stdlibEnc := json.NewEncoder(&stdlibBuf) require.NoError(t, stdlibEnc.Encode(expected)) if !json.Valid(stdlibBuf.Bytes()) { return } // then encode using the jsontext encoder enc := jsontext.NewEncoder(&ourBuf) if isSupported := encodeValue(t, enc, expected); !isSupported { return } output := ourBuf.Bytes() require.Truef(t, json.Valid(output), "produced invalid JSON: %s", output) var got any require.NoError(t, json.Unmarshal(output, &got)) require.JSONEq(t, ourBuf.String(), stdlibBuf.String()) }) } quic-go-0.59.0/qlogwriter/trace.go000066400000000000000000000067521513066070600170130ustar00rootroot00000000000000package qlogwriter import ( "runtime/debug" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/qlogwriter/jsontext" ) type ConnectionID = protocol.ConnectionID // Setting of this only works when quic-go is used as a library. // When building a binary from this repository, the version can be set using the following go build flag: // -ldflags="-X github.com/quic-go/quic-go/qlogwriter.quicGoVersion=foobar" var quicGoVersion = "(devel)" func init() { if quicGoVersion != "(devel)" { // variable set by ldflags return } info, ok := debug.ReadBuildInfo() if !ok { // no build info available. This happens when quic-go is not used as a library. 
return } for _, d := range info.Deps { if d.Path == "github.com/quic-go/quic-go" { quicGoVersion = d.Version if d.Replace != nil { if len(d.Replace.Version) > 0 { quicGoVersion = d.Version } else { quicGoVersion += " (replaced)" } } break } } } type encoderHelper struct { enc *jsontext.Encoder err error } func (h *encoderHelper) WriteToken(t jsontext.Token) { if h.err != nil { return } h.err = h.enc.WriteToken(t) } type traceHeader struct { VantagePointType string GroupID *ConnectionID ReferenceTime time.Time EventSchemas []string } func (l traceHeader) Encode(enc *jsontext.Encoder) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("file_schema")) h.WriteToken(jsontext.String("urn:ietf:params:qlog:file:sequential")) h.WriteToken(jsontext.String("serialization_format")) h.WriteToken(jsontext.String("application/qlog+json-seq")) h.WriteToken(jsontext.String("title")) h.WriteToken(jsontext.String("quic-go qlog")) h.WriteToken(jsontext.String("code_version")) h.WriteToken(jsontext.String(quicGoVersion)) h.WriteToken(jsontext.String("trace")) // trace h.WriteToken(jsontext.BeginObject) if len(l.EventSchemas) > 0 { h.WriteToken(jsontext.String("event_schemas")) h.WriteToken(jsontext.BeginArray) for _, schema := range l.EventSchemas { h.WriteToken(jsontext.String(schema)) } h.WriteToken(jsontext.EndArray) } h.WriteToken(jsontext.String("vantage_point")) // -- vantage_point h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("type")) h.WriteToken(jsontext.String(l.VantagePointType)) // -- end vantage_point h.WriteToken(jsontext.EndObject) h.WriteToken(jsontext.String("common_fields")) // -- common_fields h.WriteToken(jsontext.BeginObject) if l.GroupID != nil { h.WriteToken(jsontext.String("group_id")) h.WriteToken(jsontext.String(l.GroupID.String())) } h.WriteToken(jsontext.String("reference_time")) // ---- reference_time h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("clock_type")) 
h.WriteToken(jsontext.String("monotonic")) h.WriteToken(jsontext.String("epoch")) h.WriteToken(jsontext.String("unknown")) h.WriteToken(jsontext.String("wall_clock_time")) h.WriteToken(jsontext.String(l.ReferenceTime.Format(time.RFC3339Nano))) // ---- end reference_time h.WriteToken(jsontext.EndObject) // -- end common_fields h.WriteToken(jsontext.EndObject) // end trace h.WriteToken(jsontext.EndObject) // The following fields are not required by the qlog draft anymore, // but qvis still requires them to be present. h.WriteToken(jsontext.String("qlog_format")) h.WriteToken(jsontext.String("JSON-SEQ")) h.WriteToken(jsontext.String("qlog_version")) h.WriteToken(jsontext.String("0.3")) h.WriteToken(jsontext.EndObject) return h.err } quic-go-0.59.0/qlogwriter/trace_test.go000066400000000000000000000065361513066070600200520ustar00rootroot00000000000000package qlogwriter import ( "bytes" "encoding/json" "io" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/synctest" "github.com/stretchr/testify/require" ) type nopWriteCloserImpl struct{ io.Writer } func (nopWriteCloserImpl) Close() error { return nil } func nopWriteCloser(w io.Writer) io.WriteCloser { return &nopWriteCloserImpl{Writer: w} } func unmarshal(data []byte, v any) error { if bytes.Equal(data[:1], recordSeparator) { data = data[1:] } return json.Unmarshal(data, v) } func TestTraceMetadata(t *testing.T) { t.Run("non-connection trace", func(t *testing.T) { synctest.Test(t, func(t *testing.T) { buf := &bytes.Buffer{} trace := NewFileSeq(nopWriteCloser(buf)) go trace.Run() producer := trace.AddProducer() producer.Close() testTraceMetadata(t, buf, "transport", "", []string{}) }) }) t.Run("connection trace", func(t *testing.T) { synctest.Test(t, func(t *testing.T) { buf := &bytes.Buffer{} trace := NewConnectionFileSeq( nopWriteCloser(buf), false, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), []string{"urn:ietf:params:qlog:events:foo", 
"urn:ietf:params:qlog:events:bar"}, ) require.False(t, trace.SupportsSchemas("urn:ietf:params:qlog:events:baz")) require.True(t, trace.SupportsSchemas("urn:ietf:params:qlog:events:foo")) require.True(t, trace.SupportsSchemas("urn:ietf:params:qlog:events:bar")) go trace.Run() producer := trace.AddProducer() producer.Close() testTraceMetadata(t, buf, "server", "deadbeef", []string{"urn:ietf:params:qlog:events:foo", "urn:ietf:params:qlog:events:bar"}, ) }) }) } func testTraceMetadata(t *testing.T, buf *bytes.Buffer, expectedVantagePoint, expectedGroupID string, expectedEventSchemas []string, ) { var m map[string]any require.NoError(t, unmarshal(buf.Bytes(), &m)) require.Equal(t, "0.3", m["qlog_version"]) require.Contains(t, m, "title") require.Contains(t, m, "trace") tr := m["trace"].(map[string]any) require.Contains(t, tr, "common_fields") commonFields := tr["common_fields"].(map[string]any) if expectedGroupID != "" { require.Contains(t, commonFields, "group_id") require.Equal(t, expectedGroupID, commonFields["group_id"]) } else { require.NotContains(t, commonFields, "group_id") } require.Contains(t, commonFields, "reference_time") referenceTimeMap := commonFields["reference_time"].(map[string]any) require.Contains(t, referenceTimeMap, "clock_type") require.Equal(t, "monotonic", referenceTimeMap["clock_type"]) require.Contains(t, referenceTimeMap, "epoch") require.Equal(t, "unknown", referenceTimeMap["epoch"]) require.Contains(t, referenceTimeMap, "wall_clock_time") wallClockTimeStr := referenceTimeMap["wall_clock_time"].(string) wallClockTime, err := time.Parse(time.RFC3339Nano, wallClockTimeStr) require.NoError(t, err) require.Equal(t, time.Now().UTC(), wallClockTime.UTC()) require.Contains(t, tr, "vantage_point") vantagePoint := tr["vantage_point"].(map[string]any) require.Equal(t, expectedVantagePoint, vantagePoint["type"]) if len(expectedEventSchemas) > 0 { require.Contains(t, tr, "event_schemas") eventSchemas := tr["event_schemas"].([]any) for i, schema := 
range eventSchemas { require.Equal(t, expectedEventSchemas[i], schema) } } else { require.NotContains(t, tr, "event_schemas") } } quic-go-0.59.0/qlogwriter/writer.go000066400000000000000000000126151513066070600172240ustar00rootroot00000000000000package qlogwriter import ( "bytes" "fmt" "io" "log" "slices" "sync" "time" "github.com/quic-go/quic-go/qlogwriter/jsontext" ) // Trace represents a qlog trace that can have multiple event producers. // Each producer can record events to the trace independently. // When the last producer is closed, the underlying trace is closed as well. type Trace interface { // AddProducer creates a new Recorder for this trace. // Each Recorder can record events independently. AddProducer() Recorder // SupportsSchemas returns true if the trace supports the given schema. SupportsSchemas(schema string) bool } // Recorder is used to record events to a qlog trace. // It is safe for concurrent use by multiple goroutines. type Recorder interface { // RecordEvent records a single Event to the trace. // It must not be called after Close. RecordEvent(Event) // Close signals that this producer is done recording events. // When all producers are closed, the underlying trace is closed. // It must not be called concurrently with RecordEvent. io.Closer } // Event represents a qlog event that can be encoded to JSON. // Each event must provide its name and a method to encode itself using a jsontext.Encoder. 
type Event interface {
	// Name returns the name of the event, as it should appear in the qlog output.
	Name() string
	// Encode writes the event's data to the provided jsontext.Encoder.
	// FileSeq invokes it directly after writing the "data" key of the event
	// record and closes the record right afterwards, so Encode is expected to
	// write exactly one JSON value (e.g. an object).
	// eventTime is the time at which the event was recorded.
	Encode(encoder *jsontext.Encoder, eventTime time.Time) error
}

// RecordSeparator is the record separator byte for the JSON-SEQ format (ASCII RS).
// A FileSeq writes it before the trace header and before every event record.
const RecordSeparator byte = 0x1e

// recordSeparator is RecordSeparator as a byte slice, ready to be passed to Write.
var recordSeparator = []byte{RecordSeparator}

// event is a queued qlog event together with the time at which it was recorded.
type event struct {
	Time  time.Time
	Event Event
}

// eventChanSize is the capacity of the FileSeq event queue.
// Once the queue is full, recording an event blocks until Run dequeues one.
const eventChanSize = 50

// FileSeq represents a qlog trace using the JSON-SEQ format,
// https://www.ietf.org/archive/id/draft-ietf-quic-qlog-main-schema-12.html#section-5
// qlog event producers can be created by calling AddProducer.
// The underlying io.WriteCloser is closed when the last producer is removed.
type FileSeq struct {
	// w receives the trace header and all encoded events; it is closed once
	// the last producer has been removed.
	w io.WriteCloser
	// enc encodes events directly into w.
	enc *jsontext.Encoder
	// referenceTime is the trace's reference time; event timestamps are
	// written as milliseconds elapsed since this instant.
	referenceTime time.Time
	// runStopped is closed when Run returns.
	runStopped chan struct{}
	// encodeErr holds the first error encountered while writing; once set,
	// no further events are written and Run logs it before returning.
	encodeErr error
	// events is the queue of recorded events waiting to be encoded by Run.
	events chan event
	// done is closed when the last producer is removed, telling Run to drain
	// the queue and exit.
	done chan struct{}
	// mx guards producers and closed.
	mx sync.Mutex
	// producers is the number of Recorders that have not been closed yet.
	producers int
	// closed is set once the last producer has been removed; afterwards no
	// new producers are handed out and recorded events are dropped.
	closed bool
	// eventSchemas lists the event schemas advertised in the trace header
	// and consulted by SupportsSchemas.
	eventSchemas []string
}

var _ Trace = &FileSeq{}

// NewFileSeq creates a new JSON-SEQ qlog trace to log transport events.
// The trace uses the "transport" vantage point, and its header carries no
// group ID and no event schemas. Queued events are only written while Run is
// executing, so callers are expected to start it (e.g. `go trace.Run()`).
func NewFileSeq(w io.WriteCloser) *FileSeq {
	return newFileSeq(w, "transport", nil, nil)
}

// NewConnectionFileSeq creates a new qlog trace to log connection events.
func NewConnectionFileSeq(w io.WriteCloser, isClient bool, odcid ConnectionID, eventSchemas []string) *FileSeq { pers := "server" if isClient { pers = "client" } return newFileSeq(w, pers, &odcid, eventSchemas) } func newFileSeq(w io.WriteCloser, pers string, odcid *ConnectionID, eventSchemas []string) *FileSeq { now := time.Now() buf := &bytes.Buffer{} enc := jsontext.NewEncoder(buf) if _, err := buf.Write(recordSeparator); err != nil { panic(fmt.Sprintf("qlog encoding into a bytes.Buffer failed: %s", err)) } if err := (&traceHeader{ VantagePointType: pers, GroupID: odcid, ReferenceTime: now, EventSchemas: eventSchemas, }).Encode(enc); err != nil { panic(fmt.Sprintf("qlog encoding into a bytes.Buffer failed: %s", err)) } _, encodeErr := w.Write(buf.Bytes()) return &FileSeq{ w: w, referenceTime: now, enc: jsontext.NewEncoder(w), runStopped: make(chan struct{}), encodeErr: encodeErr, events: make(chan event, eventChanSize), done: make(chan struct{}), eventSchemas: eventSchemas, } } func (t *FileSeq) SupportsSchemas(schema string) bool { return slices.Contains(t.eventSchemas, schema) } func (t *FileSeq) AddProducer() Recorder { t.mx.Lock() defer t.mx.Unlock() if t.closed { return nil } t.producers++ return &Writer{t: t} } func (t *FileSeq) record(eventTime time.Time, details Event) { t.mx.Lock() if t.closed { t.mx.Unlock() return } t.mx.Unlock() t.events <- event{Time: eventTime, Event: details} } func (t *FileSeq) Run() { defer close(t.runStopped) for { select { case <-t.done: for { select { case e := <-t.events: t.encodeEvent(e) default: if t.encodeErr != nil { log.Printf("exporting qlog failed: %s\n", t.encodeErr) } return } } case e := <-t.events: t.encodeEvent(e) } } } func (t *FileSeq) encodeEvent(e event) { if t.encodeErr != nil { return } if _, err := t.w.Write(recordSeparator); err != nil { t.encodeErr = err return } h := encoderHelper{enc: t.enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("time")) 
h.WriteToken(jsontext.Float(float64(e.Time.Sub(t.referenceTime).Nanoseconds()) / 1e6)) h.WriteToken(jsontext.String("name")) h.WriteToken(jsontext.String(e.Event.Name())) h.WriteToken(jsontext.String("data")) if err := e.Event.Encode(t.enc, e.Time); err != nil { t.encodeErr = err return } h.WriteToken(jsontext.EndObject) if h.err != nil { t.encodeErr = h.err } } func (t *FileSeq) removeProducer() { t.mx.Lock() t.producers-- last := t.producers == 0 if last { t.closed = true } t.mx.Unlock() if last { close(t.done) <-t.runStopped // wait for Run to drain and exit _ = t.w.Close() } } type Writer struct { t *FileSeq } func (w *Writer) Close() error { w.t.removeProducer() return nil } func (w *Writer) RecordEvent(ev Event) { w.t.record(time.Now(), ev) } quic-go-0.59.0/qlogwriter/writer_test.go000066400000000000000000000050101513066070600202520ustar00rootroot00000000000000package qlogwriter import ( "bytes" "errors" "fmt" "io" "log" "os" "testing" "time" "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/qlogwriter/jsontext" "github.com/stretchr/testify/require" ) type testEvent struct { message string } func (e testEvent) Name() string { return "transport:test_event" } func (e testEvent) Encode(enc *jsontext.Encoder, _ time.Time) error { h := encoderHelper{enc: enc} h.WriteToken(jsontext.BeginObject) h.WriteToken(jsontext.String("message")) h.WriteToken(jsontext.String(e.message)) h.WriteToken(jsontext.EndObject) return h.err } type limitedWriter struct { io.WriteCloser N int written int } func (w *limitedWriter) Write(p []byte) (int, error) { if w.written+len(p) > w.N { return 0, errors.New("writer full") } n, err := w.WriteCloser.Write(p) w.written += n return n, err } func TestWritingStopping(t *testing.T) { buf := &bytes.Buffer{} fileSeq := NewFileSeq(&limitedWriter{WriteCloser: nopWriteCloser(buf), N: 250}) writer := fileSeq.AddProducer() go fileSeq.Run() for i := range 1000 { writer.RecordEvent(testEvent{message: fmt.Sprintf("test message 
%d", i)}) } var logBuf bytes.Buffer log.SetOutput(&logBuf) defer log.SetOutput(os.Stdout) writer.Close() require.Contains(t, logBuf.String(), "writer full") // events after closing are ignored logBuf.Reset() writer.RecordEvent(testEvent{message: "foobar"}) require.Empty(t, logBuf.String()) } type blockingWriter struct { bytes.Buffer block bool unblock chan struct{} } func (w *blockingWriter) Write(b []byte) (int, error) { if w.block { <-w.unblock } return w.Buffer.Write(b) } // TestRecordCloseRace triggers a race between record and Close. func TestRecordCloseRace(t *testing.T) { synctest.Test(t, func(t *testing.T) { w := &blockingWriter{unblock: make(chan struct{})} trace := NewFileSeq(nopWriteCloser(w)) go trace.Run() synctest.Wait() // Run is blocked waiting for events producer := trace.AddProducer() require.NotNil(t, producer) w.block = true const numEvents = eventChanSize + 1 for i := range numEvents { producer.RecordEvent(testEvent{message: fmt.Sprintf("event %d", i)}) } go producer.RecordEvent(testEvent{message: "last event"}) synctest.Wait() // goroutine is blocked on full channel close(w.unblock) // let Run() finish producer.Close() for i := range numEvents { require.Contains(t, w.String(), fmt.Sprintf(`"message":"event %d"`, i)) } require.Contains(t, w.String(), `"message":"last event"`) }) } quic-go-0.59.0/quic_linux_test.go000066400000000000000000000002331513066070600167210ustar00rootroot00000000000000//go:build linux package quic import ( "fmt" ) func init() { major, minor := kernelVersion() fmt.Printf("Kernel Version: %d.%d\n\n", major, minor) } quic-go-0.59.0/quic_test.go000066400000000000000000000043561513066070600155140ustar00rootroot00000000000000package quic import ( "bytes" "fmt" "net" "os" "runtime/pprof" "strconv" "strings" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/require" ) // in the tests for the stream deadlines we set a deadline // and wait to 
make an assertion when Read / Write was unblocked // on the CIs, the timing is a lot less precise, so scale every duration by this factor func scaleDuration(t time.Duration) time.Duration { scaleFactor := 1 if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set scaleFactor = f } if scaleFactor == 0 { panic("TIMESCALE_FACTOR is 0") } return time.Duration(scaleFactor) * t } func newUDPConnLocalhost(t testing.TB) *net.UDPConn { t.Helper() conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) require.NoError(t, err) t.Cleanup(func() { conn.Close() }) return conn } func getPacket(t *testing.T, connID protocol.ConnectionID) []byte { return getPacketWithPacketType(t, connID, protocol.PacketTypeHandshake, 2) } func getPacketWithPacketType(t *testing.T, connID protocol.ConnectionID, typ protocol.PacketType, length protocol.ByteCount) []byte { t.Helper() b, err := (&wire.ExtendedHeader{ Header: wire.Header{ Type: typ, DestConnectionID: connID, Length: length, Version: protocol.Version1, }, PacketNumberLen: protocol.PacketNumberLen2, }).Append(nil, protocol.Version1) require.NoError(t, err) return append(b, bytes.Repeat([]byte{42}, int(length)-2)...) 
} func areConnsRunning() bool { var b bytes.Buffer pprof.Lookup("goroutine").WriteTo(&b, 1) return strings.Contains(b.String(), "quic-go.(*connection).run") } func areTransportsRunning() bool { var b bytes.Buffer pprof.Lookup("goroutine").WriteTo(&b, 1) return strings.Contains(b.String(), "quic-go.(*Transport).listen") } func TestMain(m *testing.M) { status := m.Run() if status != 0 { os.Exit(status) } if areConnsRunning() { fmt.Println("stray connection goroutines found") os.Exit(1) } if areTransportsRunning() { fmt.Println("stray transport goroutines found") os.Exit(1) } os.Exit(status) } quic-go-0.59.0/quicvarint/000077500000000000000000000000001513066070600153425ustar00rootroot00000000000000quic-go-0.59.0/quicvarint/io.go000066400000000000000000000036071513066070600163060ustar00rootroot00000000000000package quicvarint import ( "bytes" "io" ) // Reader implements both the io.ByteReader and io.Reader interfaces. type Reader interface { io.ByteReader io.Reader } var _ Reader = &bytes.Reader{} // A Peeker can peek bytes without consuming them. type Peeker interface { Peek(b []byte) (int, error) } // Peek reads a number in the QUIC varint format without consuming bytes. func Peek(p Peeker) (uint64, error) { var b [8]byte // first peek 1 byte to determine the varint length if _, err := p.Peek(b[:1]); err != nil { return 0, err } l := 1 << (b[0] >> 6) // 1, 2, 4, or 8 bytes if l == 1 { return uint64(b[0] & 0b00111111), nil } if _, err := p.Peek(b[:l]); err != nil { return 0, err } val, _, err := Parse(b[:l]) return val, err } type byteReader struct { io.Reader } var _ Reader = &byteReader{} // NewReader returns a Reader for r. // If r already implements both io.ByteReader and io.Reader, NewReader returns r. // Otherwise, r is wrapped to add the missing interfaces. 
func NewReader(r io.Reader) Reader { if r, ok := r.(Reader); ok { return r } return &byteReader{r} } func (r *byteReader) ReadByte() (byte, error) { var b [1]byte var n int var err error for n == 0 && err == nil { n, err = r.Read(b[:]) } if n == 1 && err == io.EOF { err = nil } return b[0], err } // Writer implements both the io.ByteWriter and io.Writer interfaces. type Writer interface { io.ByteWriter io.Writer } var _ Writer = &bytes.Buffer{} type byteWriter struct { io.Writer } var _ Writer = &byteWriter{} // NewWriter returns a Writer for w. // If w already implements both io.ByteWriter and io.Writer, NewWriter returns w. // Otherwise, w is wrapped to add the missing interfaces. func NewWriter(w io.Writer) Writer { if w, ok := w.(Writer); ok { return w } return &byteWriter{w} } func (w *byteWriter) WriteByte(c byte) error { _, err := w.Write([]byte{c}) return err } quic-go-0.59.0/quicvarint/io_test.go000066400000000000000000000070511513066070600173420ustar00rootroot00000000000000package quicvarint import ( "bytes" "fmt" "io" "testing" "github.com/stretchr/testify/require" ) type nopReader struct{} func (r *nopReader) Read(_ []byte) (int, error) { return 0, io.ErrUnexpectedEOF } var _ io.Reader = &nopReader{} type nopWriter struct{} func (r *nopWriter) Write(_ []byte) (int, error) { return 0, io.ErrShortBuffer } // eofReader is a reader that returns data and the io.EOF at the same time in the last Read call type eofReader struct { Data []byte pos int } func (r *eofReader) Read(b []byte) (int, error) { n := copy(b, r.Data[r.pos:]) r.pos += n if r.pos >= len(r.Data) { return n, io.EOF } return n, nil } var _ io.Writer = &nopWriter{} func TestReaderPassesThroughUnchanged(t *testing.T) { b := bytes.NewReader([]byte{0}) r := NewReader(b) require.Equal(t, b, r) } func TestReaderWrapsIOReader(t *testing.T) { n := &nopReader{} r := NewReader(n) require.NotEqual(t, n, r) } func TestReaderFailure(t *testing.T) { r := NewReader(&nopReader{}) val, err := r.ReadByte() 
require.Equal(t, io.ErrUnexpectedEOF, err) require.Equal(t, byte(0), val) } func TestReaderHandlesEOF(t *testing.T) { // test that the eofReader behaves as we expect r := &eofReader{Data: []byte("foobar")} b := make([]byte, 3) n, err := r.Read(b) require.Equal(t, 3, n) require.NoError(t, err) require.Equal(t, "foo", string(b)) n, err = r.Read(b) require.Equal(t, 3, n) require.Equal(t, io.EOF, err) require.Equal(t, "bar", string(b)) n, err = r.Read(b) require.Equal(t, io.EOF, err) require.Zero(t, n) // now test using it to read varints reader := NewReader(&eofReader{Data: Append(nil, 1337)}) n2, err := Read(reader) require.NoError(t, err) require.EqualValues(t, 1337, n2) } // Regression test: empty reads were being converted to successful // reads of a zero value. func TestReaderHandlesEmptyRead(t *testing.T) { r, w := io.Pipe() go func() { // io.Pipe turns empty writes into empty reads. w.Write(nil) w.Close() }() br := NewReader(r) _, err := Read(br) require.ErrorIs(t, err, io.EOF) } func TestWriterPassesThroughUnchanged(t *testing.T) { b := &bytes.Buffer{} w := NewWriter(b) require.Equal(t, b, w) } func TestWriterWrapsIOWriter(t *testing.T) { n := &nopWriter{} w := NewWriter(n) require.NotEqual(t, n, w) } func TestWriterFailure(t *testing.T) { w := NewWriter(&nopWriter{}) err := w.WriteByte(0) require.Equal(t, io.ErrShortBuffer, err) } type bufPeeker []byte func (p bufPeeker) Peek(b []byte) (int, error) { if len(p) < len(b) { return copy(b, p), io.ErrUnexpectedEOF } return copy(b, p), nil } func TestPeek(t *testing.T) { for _, c := range []bufPeeker{ {0b00011001}, // 1-byte {0b01111011, 0xbd}, // 2-byte {0b10011101, 0x7f, 0x3e, 0x7d}, // 4-byte {0xc2, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c}, // 8-byte } { t.Run(fmt.Sprintf("%d bytes", len(c)), func(t *testing.T) { peekVal, err := Peek(append(c, []byte("foobar")...)) // append some data, which doesn't matter require.NoError(t, err) parseVal, _, err := Parse(c) require.NoError(t, err) require.Equal(t, parseVal, 
peekVal) }) } } func TestPeekErrors(t *testing.T) { errorCases := []struct { name string input bufPeeker }{ {"empty input", bufPeeker{}}, {"2-byte, missing 1", bufPeeker{0b01000001}}, {"4-byte, missing 1", bufPeeker{0b10000000, 0, 0}}, {"8-byte, missing 1", bufPeeker{0b11000000, 0, 0, 0, 0, 0, 0}}, } for _, tc := range errorCases { t.Run(tc.name, func(t *testing.T) { _, err := Peek(tc.input) require.ErrorIs(t, err, io.ErrUnexpectedEOF) }) } } quic-go-0.59.0/quicvarint/varint.go000066400000000000000000000101271513066070600171750ustar00rootroot00000000000000package quicvarint import ( "encoding/binary" "fmt" "io" ) // taken from the QUIC draft const ( // Min is the minimum value allowed for a QUIC varint. Min = 0 // Max is the maximum allowed value for a QUIC varint (2^62-1). Max = maxVarInt8 maxVarInt1 = 63 maxVarInt2 = 16383 maxVarInt4 = 1073741823 maxVarInt8 = 4611686018427387903 ) type varintLengthError struct { Num uint64 } func (e *varintLengthError) Error() string { return fmt.Sprintf("value doesn't fit into 62 bits: %d", e.Num) } // Read reads a number in the QUIC varint format from r. 
// It consumes exactly 1, 2, 4 or 8 bytes from r, as selected by the two most
// significant bits of the first byte. If r returns an error (including io.EOF
// in the middle of an encoding), that error is returned unchanged together
// with a zero value.
func Read(r io.ByteReader) (uint64, error) {
	firstByte, err := r.ReadByte()
	if err != nil {
		return 0, err
	}
	// the first two bits of the first byte encode the length (1, 2, 4 or 8 bytes)
	l := 1 << (firstByte >> 6)
	// the remaining 6 bits are the most significant bits of the value
	b1 := firstByte & 0b00111111
	if l == 1 {
		return uint64(b1), nil
	}
	b2, err := r.ReadByte()
	if err != nil {
		return 0, err
	}
	if l == 2 {
		return uint64(b2) | uint64(b1)<<8, nil
	}
	b3, err := r.ReadByte()
	if err != nil {
		return 0, err
	}
	b4, err := r.ReadByte()
	if err != nil {
		return 0, err
	}
	if l == 4 {
		return uint64(b4) | uint64(b3)<<8 | uint64(b2)<<16 | uint64(b1)<<24, nil
	}
	b5, err := r.ReadByte()
	if err != nil {
		return 0, err
	}
	b6, err := r.ReadByte()
	if err != nil {
		return 0, err
	}
	b7, err := r.ReadByte()
	if err != nil {
		return 0, err
	}
	b8, err := r.ReadByte()
	if err != nil {
		return 0, err
	}
	return uint64(b8) | uint64(b7)<<8 | uint64(b6)<<16 | uint64(b5)<<24 |
		uint64(b4)<<32 | uint64(b3)<<40 | uint64(b2)<<48 | uint64(b1)<<56, nil
}

// Parse reads a number in the QUIC varint format from the start of b.
// It returns the value and the number of bytes consumed (1, 2, 4 or 8);
// any bytes following the encoded number are ignored.
// It returns io.EOF if b is empty, and io.ErrUnexpectedEOF if b is shorter
// than the length announced by its first byte.
func Parse(b []byte) (uint64 /* value */, int /* bytes consumed */, error) {
	if len(b) == 0 {
		return 0, 0, io.EOF
	}
	first := b[0]
	switch first >> 6 {
	case 0: // 1-byte encoding: 00xxxxxx
		return uint64(first & 0b00111111), 1, nil
	case 1: // 2-byte encoding: 01xxxxxx
		if len(b) < 2 {
			return 0, 0, io.ErrUnexpectedEOF
		}
		return uint64(b[1]) | uint64(first&0b00111111)<<8, 2, nil
	case 2: // 4-byte encoding: 10xxxxxx
		if len(b) < 4 {
			return 0, 0, io.ErrUnexpectedEOF
		}
		return uint64(b[3]) | uint64(b[2])<<8 | uint64(b[1])<<16 | uint64(first&0b00111111)<<24, 4, nil
	case 3: // 8-byte encoding: 11xxxxxx
		if len(b) < 8 {
			return 0, 0, io.ErrUnexpectedEOF
		}
		// binary.BigEndian.Uint64 only reads the first 8 bytes. Passing the full slice avoids slicing overhead.
		// The mask clears the two length bits, leaving the 62-bit value.
		return binary.BigEndian.Uint64(b) & 0x3fffffffffffffff, 8, nil
	}
	panic("unreachable") // first>>6 is always in [0, 3]
}

// Append appends i in the QUIC varint format.
func Append(b []byte, i uint64) []byte { if i <= maxVarInt1 { return append(b, uint8(i)) } if i <= maxVarInt2 { return append(b, []byte{uint8(i>>8) | 0x40, uint8(i)}...) } if i <= maxVarInt4 { return append(b, []byte{uint8(i>>24) | 0x80, uint8(i >> 16), uint8(i >> 8), uint8(i)}...) } if i <= maxVarInt8 { return append(b, []byte{ uint8(i>>56) | 0xc0, uint8(i >> 48), uint8(i >> 40), uint8(i >> 32), uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), }...) } panic(&varintLengthError{Num: i}) } // AppendWithLen append i in the QUIC varint format with the desired length. func AppendWithLen(b []byte, i uint64, length int) []byte { if length != 1 && length != 2 && length != 4 && length != 8 { panic("invalid varint length") } l := Len(i) if l == length { return Append(b, i) } if l > length { panic(fmt.Sprintf("cannot encode %d in %d bytes", i, length)) } switch length { case 2: b = append(b, 0b01000000) case 4: b = append(b, 0b10000000) case 8: b = append(b, 0b11000000) } for range length - l - 1 { b = append(b, 0) } for j := range l { b = append(b, uint8(i>>(8*(l-1-j)))) } return b } // Len determines the number of bytes that will be needed to write the number i. // //gcassert:inline func Len(i uint64) int { if i <= maxVarInt1 { return 1 } if i <= maxVarInt2 { return 2 } if i <= maxVarInt4 { return 4 } if i <= maxVarInt8 { return 8 } // Don't use a fmt.Sprintf here to format the error message. // The function would then exceed the inlining budget. 
panic(&varintLengthError{Num: i}) } quic-go-0.59.0/quicvarint/varint_test.go000066400000000000000000000231721513066070600202400ustar00rootroot00000000000000package quicvarint import ( "bytes" "fmt" "io" "math/rand/v2" "testing" "github.com/stretchr/testify/require" ) func TestLimits(t *testing.T) { require.Equal(t, 0, Min) require.Equal(t, uint64(1<<62-1), uint64(Max)) } func TestRead(t *testing.T) { tests := []struct { name string input []byte expected uint64 }{ {"1 byte", []byte{0b00011001}, 25}, {"2 byte", []byte{0b01111011, 0xbd}, 15293}, {"4 byte", []byte{0b10011101, 0x7f, 0x3e, 0x7d}, 494878333}, {"8 byte", []byte{0b11000010, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c}, 151288809941952652}, {"too long", []byte{0b01000000, 0x25}, 37}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { b := bytes.NewReader(tt.input) val, err := Read(b) require.NoError(t, err) require.Equal(t, tt.expected, val) require.Zero(t, b.Len()) }) } } func TestParse(t *testing.T) { tests := []struct { name string input []byte expectedValue uint64 expectedLen int }{ {"1 byte", []byte{0b00011001}, 25, 1}, {"2 byte", []byte{0b01111011, 0xbd}, 15293, 2}, {"4 byte", []byte{0b10011101, 0x7f, 0x3e, 0x7d}, 494878333, 4}, {"8 byte", []byte{0b11000010, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c}, 151288809941952652, 8}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { value, l, err := Parse(tt.input) require.Equal(t, tt.expectedValue, value) require.Equal(t, tt.expectedLen, l) require.Nil(t, err) }) } } func TestParsingFailures(t *testing.T) { tests := []struct { name string input []byte expectedErr error }{ { name: "empty slice", input: []byte{}, expectedErr: io.EOF, }, { name: "2-byte encoding: not enough bytes", input: []byte{0b01000001}, expectedErr: io.ErrUnexpectedEOF, }, { name: "4-byte encoding: not enough bytes", input: []byte{0b10000000, 0x0, 0x0}, expectedErr: io.ErrUnexpectedEOF, }, { name: "8-byte encoding: not enough bytes", input: []byte{0b11000000, 0x0, 
0x0, 0x0, 0x0, 0x0, 0x0}, expectedErr: io.ErrUnexpectedEOF, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { value, l, err := Parse(tt.input) require.Equal(t, uint64(0), value) require.Equal(t, 0, l) require.Equal(t, tt.expectedErr, err) }) } } func TestVarintEncoding(t *testing.T) { tests := []struct { name string value uint64 expected []byte }{ {"1 byte number", 37, []byte{0x25}}, {"maximum 1 byte number", maxVarInt1, []byte{0b00111111}}, {"minimum 2 byte number", maxVarInt1 + 1, []byte{0x40, maxVarInt1 + 1}}, {"2 byte number", 15293, []byte{0b01000000 ^ 0x3b, 0xbd}}, {"maximum 2 byte number", maxVarInt2, []byte{0b01111111, 0xff}}, {"minimum 4 byte number", maxVarInt2 + 1, []byte{0b10000000, 0, 0x40, 0}}, {"4 byte number", 494878333, []byte{0b10000000 ^ 0x1d, 0x7f, 0x3e, 0x7d}}, {"maximum 4 byte number", maxVarInt4, []byte{0b10111111, 0xff, 0xff, 0xff}}, {"minimum 8 byte number", maxVarInt4 + 1, []byte{0b11000000, 0, 0, 0, 0x40, 0, 0, 0}}, {"8 byte number", 151288809941952652, []byte{0xc2, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c}}, {"maximum 8 byte number", maxVarInt8, []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { require.Equal(t, tt.expected, Append(nil, tt.value)) }) } t.Run("panics when given a too large number (> 62 bit)", func(t *testing.T) { require.PanicsWithError(t, fmt.Sprintf("value doesn't fit into 62 bits: %d", maxVarInt8+1), func() { Append(nil, maxVarInt8+1) }, ) }) } func TestAppendWithLen(t *testing.T) { tests := []struct { name string value uint64 length int expected []byte }{ {"1-byte number in minimal encoding", 37, 1, []byte{0x25}}, {"1-byte number in 2 bytes", 37, 2, []byte{0b01000000, 0x25}}, {"1-byte number in 4 bytes", 37, 4, []byte{0b10000000, 0, 0, 0x25}}, {"1-byte number in 8 bytes", 37, 8, []byte{0b11000000, 0, 0, 0, 0, 0, 0, 0x25}}, {"2-byte number in 4 bytes", 15293, 4, []byte{0b10000000, 0, 0x3b, 0xbd}}, {"4-byte number in 8 bytes", 
494878333, 8, []byte{0b11000000, 0, 0, 0, 0x1d, 0x7f, 0x3e, 0x7d}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { b := AppendWithLen(nil, tt.value, tt.length) require.Equal(t, tt.expected, b) if tt.length > 1 { v, n, err := Parse(b) require.NoError(t, err) require.Equal(t, tt.length, n) require.Equal(t, tt.value, v) } }) } } func TestAppendWithLenFailures(t *testing.T) { tests := []struct { name string value uint64 length int }{ {"invalid length", 25, 3}, {"too short for 2 bytes", maxVarInt1 + 1, 1}, {"too short for 4 bytes", maxVarInt2 + 1, 2}, {"too short for 8 bytes", maxVarInt4 + 1, 4}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { require.Panics(t, func() { AppendWithLen(nil, tt.value, tt.length) }) }) } } func TestLen(t *testing.T) { tests := []struct { name string input uint64 expected int }{ {"zero", 0, 1}, {"max 1 byte", maxVarInt1, 1}, {"min 2 bytes", maxVarInt1 + 1, 2}, {"max 2 bytes", maxVarInt2, 2}, {"min 4 bytes", maxVarInt2 + 1, 4}, {"max 4 bytes", maxVarInt4, 4}, {"min 8 bytes", maxVarInt4 + 1, 8}, {"max 8 bytes", maxVarInt8, 8}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { require.Equal(t, tt.expected, Len(tt.input)) }) } t.Run("panics on too large number", func(t *testing.T) { require.PanicsWithError(t, fmt.Sprintf("value doesn't fit into 62 bits: %d", maxVarInt8+1), func() { Len(maxVarInt8 + 1) }, ) }) } type benchmarkValue struct { b []byte v uint64 } func randomValues(maxValue uint64) []benchmarkValue { r := rand.New(rand.NewPCG(13, 37)) const num = 1025 bv := make([]benchmarkValue, num) for i := range num { v := r.Uint64() % maxValue bv[i].v = v bv[i].b = Append([]byte{}, v) } return bv } // using a reader that is also an io.ByteReader func BenchmarkReadBytesReader(b *testing.B) { b.Run("1-byte", func(b *testing.B) { benchmarkRead(b, randomValues(maxVarInt1), false) }) b.Run("2-byte", func(b *testing.B) { benchmarkRead(b, randomValues(maxVarInt2), false) }) b.Run("4-byte", func(b 
*testing.B) { benchmarkRead(b, randomValues(maxVarInt4), false) }) b.Run("8-byte", func(b *testing.B) { benchmarkRead(b, randomValues(maxVarInt8), false) }) } // using a reader that is not an io.ByteReader func BenchmarkReadSimpleReader(b *testing.B) { b.Run("1-byte", func(b *testing.B) { benchmarkRead(b, randomValues(maxVarInt1), true) }) b.Run("2-byte", func(b *testing.B) { benchmarkRead(b, randomValues(maxVarInt2), true) }) b.Run("4-byte", func(b *testing.B) { benchmarkRead(b, randomValues(maxVarInt4), true) }) b.Run("8-byte", func(b *testing.B) { benchmarkRead(b, randomValues(maxVarInt8), true) }) } // simpleReader satisfies io.Reader, but not io.ByteReader // This means that NewReader will need to wrap the reader. type simpleReader struct { io.Reader } func benchmarkRead(b *testing.B, inputs []benchmarkValue, wrapBytesReader bool) { r := bytes.NewReader([]byte{}) var vr Reader if wrapBytesReader { vr = NewReader(&simpleReader{r}) } else { vr = NewReader(r) } var i int for b.Loop() { index := i % len(inputs) i++ r.Reset(inputs[index].b) val, err := Read(vr) if err != nil { b.Fatal(err) } if val != inputs[index].v { b.Fatalf("expected %d, got %d", inputs[index].v, val) } } } func BenchmarkParse(b *testing.B) { b.Run("1-byte", func(b *testing.B) { benchmarkParse(b, randomValues(maxVarInt1)) }) b.Run("2-byte", func(b *testing.B) { benchmarkParse(b, randomValues(maxVarInt2)) }) b.Run("4-byte", func(b *testing.B) { benchmarkParse(b, randomValues(maxVarInt4)) }) b.Run("8-byte", func(b *testing.B) { benchmarkParse(b, randomValues(maxVarInt8)) }) } func benchmarkParse(b *testing.B, inputs []benchmarkValue) { var i int for b.Loop() { index := i % len(inputs) i++ val, n, err := Parse(inputs[index].b) if err != nil { b.Fatal(err) } if n != len(inputs[index].b) { b.Fatalf("expected to consume %d bytes, consumed %d", len(inputs[i].b), n) } if val != inputs[index].v { b.Fatalf("expected %d, got %d", inputs[index].v, val) } } } func BenchmarkAppend(b *testing.B) { 
b.Run("1-byte", func(b *testing.B) { benchmarkAppend(b, randomValues(maxVarInt1)) }) b.Run("2-byte", func(b *testing.B) { benchmarkAppend(b, randomValues(maxVarInt2)) }) b.Run("4-byte", func(b *testing.B) { benchmarkAppend(b, randomValues(maxVarInt4)) }) b.Run("8-byte", func(b *testing.B) { benchmarkAppend(b, randomValues(maxVarInt8)) }) } func benchmarkAppend(b *testing.B, inputs []benchmarkValue) { buf := make([]byte, 8) var i int for b.Loop() { buf = buf[:0] index := i % len(inputs) i++ buf = Append(buf, inputs[index].v) if !bytes.Equal(buf, inputs[index].b) { b.Fatalf("expected to write %v, wrote %v", inputs[index].b, buf) } } } func BenchmarkAppendWithLen(b *testing.B) { b.Run("1-byte", func(b *testing.B) { benchmarkAppendWithLen(b, randomValues(maxVarInt1)) }) b.Run("2-byte", func(b *testing.B) { benchmarkAppendWithLen(b, randomValues(maxVarInt2)) }) b.Run("4-byte", func(b *testing.B) { benchmarkAppendWithLen(b, randomValues(maxVarInt4)) }) b.Run("8-byte", func(b *testing.B) { benchmarkAppendWithLen(b, randomValues(maxVarInt8)) }) } func benchmarkAppendWithLen(b *testing.B, inputs []benchmarkValue) { buf := make([]byte, 8) var i int for b.Loop() { buf = buf[:0] index := i % len(inputs) i++ buf = AppendWithLen(buf, inputs[index].v, len(inputs[index].b)) if !bytes.Equal(buf, inputs[index].b) { b.Fatalf("expected to write %v, wrote %v", inputs[index].b, buf) } } } quic-go-0.59.0/receive_stream.go000066400000000000000000000362271513066070600165130ustar00rootroot00000000000000package quic import ( "fmt" "io" "sync" "time" "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/flowcontrol" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/wire" ) // A ReceiveStream is a unidirectional Receive Stream. 
type ReceiveStream struct {
	// mutex guards the stream state below. Read, Peek, CancelRead, SetReadDeadline,
	// closeForShutdown, getControlFrame and the frame handlers lock it.
	mutex sync.Mutex

	streamID protocol.StreamID

	sender streamSender

	// frameQueue buffers received STREAM data that has not yet become currentFrame:
	// handleStreamFrameImpl pushes into it, dequeueNextFrame pops from it.
	frameQueue *frameSorter
	// finalOffset is protocol.MaxByteCount until the final size is learned from a
	// STREAM frame with the FIN bit or from a RESET_STREAM frame.
	finalOffset protocol.ByteCount

	// currentFrame is the data currently being consumed, starting at readPosInFrame.
	// currentFrameDone, if non-nil, releases that frame's buffer.
	currentFrame       []byte
	currentFrameDone   func()
	readPosInFrame     int
	currentFrameIsLast bool // is the currentFrame the last frame on this stream

	// Control frames pending transmission; handed out by getControlFrame.
	queuedStopSending   bool
	queuedMaxStreamData bool

	// Set once we read the io.EOF or the cancellation error.
	// Note that for local cancellations, this doesn't necessarily mean that we know the final offset yet.
	errorRead           bool
	completed           bool // set once we've called streamSender.onStreamCompleted
	cancelledRemotely   bool
	cancelledLocally    bool
	cancelErr           *StreamError
	closeForShutdownErr error

	// readPos counts the bytes consumed by Read (Peek does not advance it).
	// reliableSize comes from RESET_STREAM: data below this offset is still
	// delivered to Read before the reset error is returned.
	readPos      protocol.ByteCount
	reliableSize protocol.ByteCount

	// readChan is signaled (non-blocking, see signalRead) to wake a blocked Read or Peek.
	readChan chan struct{}
	readOnce chan struct{} // cap: 1, to protect against concurrent use of Read
	// deadline is set by SetReadDeadline; the zero value means no deadline.
	deadline monotime.Time

	flowController flowcontrol.StreamFlowController
}

// Compile-time assertions that *ReceiveStream satisfies these interfaces.
var (
	_ streamControlFrameGetter  = &ReceiveStream{}
	_ receiveStreamFrameHandler = &ReceiveStream{}
)

// newReceiveStream creates a ReceiveStream for streamID.
// The final offset starts out unknown (protocol.MaxByteCount), and both
// readChan and readOnce are created with a capacity of 1.
func newReceiveStream(
	streamID protocol.StreamID,
	sender streamSender,
	flowController flowcontrol.StreamFlowController,
) *ReceiveStream {
	return &ReceiveStream{
		streamID:       streamID,
		sender:         sender,
		flowController: flowController,
		frameQueue:     newFrameSorter(),
		readChan:       make(chan struct{}, 1),
		readOnce:       make(chan struct{}, 1),
		finalOffset:    protocol.MaxByteCount,
	}
}

// StreamID returns the stream ID.
func (s *ReceiveStream) StreamID() protocol.StreamID {
	return s.streamID
}

// Read reads data from the stream.
// Read can be made to time out using [ReceiveStream.SetReadDeadline].
// If the stream was canceled, the error is a [StreamError].
func (s *ReceiveStream) Read(p []byte) (int, error) {
	// Concurrent use of Read is not permitted (and doesn't make any sense),
	// but sometimes people do it anyway.
	// Make sure that we only execute one call at any given time to avoid hard to debug failures.
	s.readOnce <- struct{}{}
	defer func() { <-s.readOnce }()

	s.mutex.Lock()
	queuedStreamWindowUpdate, queuedConnWindowUpdate, n, err := s.readImpl(p)
	completed := s.isNewlyCompleted()
	s.mutex.Unlock()

	// The sender is notified only after the mutex is released
	// (presumably to avoid calling into the connection while holding the stream lock).
	if completed {
		s.sender.onStreamCompleted(s.streamID)
	}
	if queuedStreamWindowUpdate {
		s.sender.onHasStreamControlFrame(s.streamID, s)
	}
	if queuedConnWindowUpdate {
		s.sender.onHasConnectionData()
	}
	return n, err
}

// isNewlyCompleted reports whether the stream has just become complete: the final
// offset is known and either the stream was cancelled locally or the terminal error
// (io.EOF or the reset error) has been read. It sets s.completed, so it returns true
// at most once. All callers in this file hold s.mutex when calling it.
func (s *ReceiveStream) isNewlyCompleted() bool {
	if s.completed {
		return false
	}
	// We need to know the final offset (either via FIN or RESET_STREAM) for flow control accounting.
	if s.finalOffset == protocol.MaxByteCount {
		return false
	}
	// We're done with the stream if it was cancelled locally...
	if s.cancelledLocally {
		s.completed = true
		return true
	}
	// ... or if the error (either io.EOF or the reset error) was read
	if s.errorRead {
		s.completed = true
		return true
	}
	return false
}

// readImpl implements Read. It is called with s.mutex held, but releases it
// while blocking for data or for the deadline, and re-acquires it afterwards.
// The two booleans report whether a stream-level (MAX_STREAM_DATA queued) or a
// connection-level flow control window update became pending during this call.
func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnWindowUpdate bool, _ int, _ error) {
	if s.currentFrameIsLast && s.currentFrame == nil {
		s.errorRead = true
		return false, false, 0, io.EOF
	}
	if s.cancelledLocally || s.isRemoteCancellationEffective() {
		s.errorRead = true
		return false, false, 0, s.cancelErr
	}
	if s.closeForShutdownErr != nil {
		return false, false, 0, s.closeForShutdownErr
	}

	var bytesRead int
	var deadlineTimer *time.Timer
	for bytesRead < len(p) {
		if s.currentFrame == nil || s.readPosInFrame >= len(s.currentFrame) {
			s.dequeueNextFrame()
		}
		// Some data was already copied and nothing more is available right now:
		// return what we have instead of blocking.
		if s.currentFrame == nil && bytesRead > 0 {
			return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.closeForShutdownErr
		}

		// Wait until there is a frame to read from, the stream ends, an error occurs,
		// or the deadline expires.
		for {
			// Stop waiting on errors
			if s.closeForShutdownErr != nil {
				return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.closeForShutdownErr
			}
			if s.cancelledLocally || s.isRemoteCancellationEffective() {
				s.errorRead = true
				return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.cancelErr
			}

			deadline := s.deadline
			if !deadline.IsZero() && !monotime.Now().Before(deadline) {
				return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, errDeadline
			}

			if s.currentFrame != nil || s.currentFrameIsLast {
				break
			}

			// Release the lock while blocking, so that the frame handlers (which call
			// signalRead under the lock) can deliver data and wake us up.
			s.mutex.Unlock()
			if deadline.IsZero() {
				<-s.readChan
			} else {
				// The timer is created once per call and reused on later iterations.
				if deadlineTimer == nil {
					deadlineTimer = time.NewTimer(monotime.Until(deadline))
					defer deadlineTimer.Stop()
				} else {
					deadlineTimer.Reset(monotime.Until(deadline))
				}
				select {
				case <-s.readChan:
				case <-deadlineTimer.C:
				}
			}
			s.mutex.Lock()
			s.dequeueNextFrame()
		}

		if bytesRead > len(p) {
			return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
		}
		if s.readPosInFrame > len(s.currentFrame) {
			return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame))
		}

		m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:])
		// when a RESET_STREAM was received, the flow controller was already
		// informed about the final offset for this stream
		if !s.isRemoteCancellationEffective() {
			hasStream, hasConn := s.flowController.AddBytesRead(protocol.ByteCount(m))
			if hasStream {
				s.queuedMaxStreamData = true
				hasStreamWindowUpdate = true
			}
			if hasConn {
				hasConnWindowUpdate = true
			}
		}
		s.readPosInFrame += m
		s.readPos += protocol.ByteCount(m)
		bytesRead += m

		// Reading this chunk may have consumed the last reliable byte of a reset stream.
		if s.isRemoteCancellationEffective() {
			s.flowController.Abandon()
		}

		if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast {
			s.currentFrame = nil
			if s.currentFrameDone != nil {
				s.currentFrameDone()
			}
			s.errorRead = true
			return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, io.EOF
		}
	}
	if s.isRemoteCancellationEffective() {
		s.errorRead = true
		return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.cancelErr
	}
	return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, nil
}

// isRemoteCancellationEffective returns whether the stream was cancelled remotely
// and all reliable data has been read.
func (s *ReceiveStream) isRemoteCancellationEffective() bool {
	return s.cancelledRemotely && s.readPos >= s.reliableSize
}

// Peek fills b with stream data, without consuming the stream data.
// It blocks until len(b) bytes are available, or an error occurs.
// It respects the stream deadline set by SetReadDeadline.
// If the stream ends before len(b) bytes are available,
// it returns the number of bytes peeked along with io.EOF.
// If the peer reset the stream, data up to the reliable size can still be
// peeked; a request extending past it returns that data along with the reset error.
// Peek must not be called concurrently with Read: both are serialized via readOnce.
func (s *ReceiveStream) Peek(b []byte) (int, error) {
	if len(b) == 0 {
		return 0, nil
	}
	// prevent concurrent use with Read
	s.readOnce <- struct{}{}
	defer func() { <-s.readOnce }()
	return s.peekImpl(b)
}

// peekImpl implements Peek. It holds s.mutex for its whole duration, except while
// waiting on readChan or the deadline timer. It never advances readPos or
// readPosInFrame; it only moves to the next queued frame once the current one
// has been fully consumed.
func (s *ReceiveStream) peekImpl(b []byte) (int, error) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	var deadlineTimer *time.Timer
	for {
		if s.currentFrameIsLast && s.currentFrame == nil {
			return 0, io.EOF
		}
		if s.cancelledLocally || s.isRemoteCancellationEffective() {
			return 0, s.cancelErr
		}
		if s.closeForShutdownErr != nil {
			return 0, s.closeForShutdownErr
		}
		deadline := s.deadline
		if !deadline.IsZero() && !monotime.Now().Before(deadline) {
			return 0, errDeadline
		}

		if s.currentFrame == nil || s.readPosInFrame >= len(s.currentFrame) {
			s.dequeueNextFrame()
		}

		if s.currentFrame != nil && s.readPosInFrame < len(s.currentFrame) {
			availableInCurrentFrame := len(s.currentFrame) - s.readPosInFrame
			// Fast path: the current frame alone satisfies the request.
			if availableInCurrentFrame >= len(b) {
				copy(b, s.currentFrame[s.readPosInFrame:])
				return len(b), nil
			}

			// offset is the stream offset right after the unread part of the current frame.
			offset := s.readPos + protocol.ByteCount(availableInCurrentFrame)
			// First peek, then copy.
			// This avoids copying data if there's not enough data in the queue.
			// NOTE(review): this assumes frameQueue.Peek fails unless the requested
			// range is fully buffered; confirm against frameSorter.Peek.
			if err := s.frameQueue.Peek(offset, b[availableInCurrentFrame:]); err == nil {
				copy(b[:availableInCurrentFrame], s.currentFrame[s.readPosInFrame:])
				return len(b), nil
			}

			if s.currentFrameIsLast {
				copy(b[:availableInCurrentFrame], s.currentFrame[s.readPosInFrame:])
				return availableInCurrentFrame, io.EOF
			}

			// If the stream was remotely cancelled and the request extends beyond the reliable size,
			// return the data available with the cancel error (once it's all received).
			if s.cancelledRemotely && s.readPos+protocol.ByteCount(len(b)) > s.reliableSize {
				total := int(s.reliableSize - s.readPos)
				needed := total - availableInCurrentFrame
				// only return once all available data is contiguous
				if needed <= 0 || s.frameQueue.Peek(offset, b[availableInCurrentFrame:total]) == nil {
					copy(b[:availableInCurrentFrame], s.currentFrame[s.readPosInFrame:])
					return total, s.cancelErr
				}
			}

			// If the request extends beyond the stream's final offset,
			// return the data available with EOF (once it's all received).
			if s.readPos+protocol.ByteCount(len(b)) > s.finalOffset {
				total := int(s.finalOffset - s.readPos)
				needed := total - availableInCurrentFrame
				// only return once all available data is contiguous
				if needed <= 0 || s.frameQueue.Peek(offset, b[availableInCurrentFrame:total]) == nil {
					copy(b[:availableInCurrentFrame], s.currentFrame[s.readPosInFrame:])
					return total, io.EOF
				}
			}
		}

		if s.currentFrameIsLast || s.readPos >= s.finalOffset {
			return 0, io.EOF
		}

		// Not enough data yet: release the lock and wait for a signal or the deadline.
		s.mutex.Unlock()
		if deadline.IsZero() {
			<-s.readChan
		} else {
			// The timer is created once per call and reused on later iterations.
			if deadlineTimer == nil {
				deadlineTimer = time.NewTimer(monotime.Until(deadline))
				defer deadlineTimer.Stop()
			} else {
				deadlineTimer.Reset(monotime.Until(deadline))
			}
			select {
			case <-s.readChan:
			case <-deadlineTimer.C:
			}
		}
		s.mutex.Lock()
		if s.currentFrame == nil || s.readPosInFrame >= len(s.currentFrame) {
			s.dequeueNextFrame()
		}
	}
}

// dequeueNextFrame releases the current frame's buffer (if any), pops the next
// frame from frameQueue into currentFrame and resets readPosInFrame.
// currentFrameIsLast is set if the new frame reaches finalOffset, unless the
// stream was reset by the peer.
func (s *ReceiveStream) dequeueNextFrame() {
	var offset protocol.ByteCount
	// We're done with the last frame. Release the buffer.
	if s.currentFrameDone != nil {
		s.currentFrameDone()
	}
	offset, s.currentFrame, s.currentFrameDone = s.frameQueue.Pop()
	s.currentFrameIsLast = offset+protocol.ByteCount(len(s.currentFrame)) >= s.finalOffset && !s.cancelledRemotely
	s.readPosInFrame = 0
}

// CancelRead aborts receiving on this stream.
// It instructs the peer to stop transmitting stream data.
// Read will unblock immediately, and future Read calls will fail.
// When called multiple times or after reading the io.EOF it is a no-op.
func (s *ReceiveStream) CancelRead(errorCode StreamErrorCode) {
	s.mutex.Lock()
	queuedNewControlFrame := s.cancelReadImpl(errorCode)
	completed := s.isNewlyCompleted()
	s.mutex.Unlock()

	if queuedNewControlFrame {
		s.sender.onHasStreamControlFrame(s.streamID, s)
	}
	if completed {
		s.flowController.Abandon()
		s.sender.onStreamCompleted(s.streamID)
	}
}

// cancelReadImpl marks the stream as cancelled locally and reports whether a
// STOP_SENDING frame was queued. Nothing is queued if the stream was already
// cancelled locally, was closed for shutdown, the terminal error was already read,
// or the peer already reset the stream. Callers hold s.mutex.
func (s *ReceiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) (queuedNewControlFrame bool) {
	if s.cancelledLocally { // duplicate call to CancelRead
		return false
	}
	if s.closeForShutdownErr != nil {
		return false
	}
	s.cancelledLocally = true
	if s.errorRead || s.cancelledRemotely {
		return false
	}
	s.queuedStopSending = true
	s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false}
	s.signalRead()
	return true
}

// handleStreamFrame processes an incoming STREAM frame. If this completes the
// stream, the flow controller is abandoned and the sender is notified after the
// mutex has been released.
func (s *ReceiveStream) handleStreamFrame(frame *wire.StreamFrame, now monotime.Time) error {
	s.mutex.Lock()
	err := s.handleStreamFrameImpl(frame, now)
	completed := s.isNewlyCompleted()
	s.mutex.Unlock()

	if completed {
		s.flowController.Abandon()
		s.sender.onStreamCompleted(s.streamID)
	}
	return err
}

// handleStreamFrameImpl updates flow control (and the final offset on FIN), then
// queues the frame's data and wakes a blocked reader. After a local cancellation
// the data is dropped, but flow control is still updated.
func (s *ReceiveStream) handleStreamFrameImpl(frame *wire.StreamFrame, now monotime.Time) error {
	maxOffset := frame.Offset + frame.DataLen()
	if err := s.flowController.UpdateHighestReceived(maxOffset, frame.Fin, now); err != nil {
		return err
	}
	if frame.Fin {
		s.finalOffset = maxOffset
	}
	if s.cancelledLocally {
		return nil
	}
	if err := s.frameQueue.Push(frame.Data, frame.Offset, frame.PutBack); err != nil {
		return err
	}
	s.signalRead()
	return nil
}

// handleResetStreamFrame processes an incoming RESET_STREAM frame and notifies the
// sender (after releasing the mutex) if this completes the stream.
func (s *ReceiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame, now monotime.Time) error {
	s.mutex.Lock()
	err := s.handleResetStreamFrameImpl(frame, now)
	completed := s.isNewlyCompleted()
	s.mutex.Unlock()

	if completed {
		s.sender.onStreamCompleted(s.streamID)
	}
	return err
}

// handleResetStreamFrameImpl records the final size and reliable size from a
// RESET_STREAM frame. Once all reliable data has been read, the flow controller is
// abandoned. The remote error is stored only for the first RESET_STREAM, and only
// if the stream was not already cancelled locally.
func (s *ReceiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame, now monotime.Time) error {
	if s.closeForShutdownErr != nil {
		return nil
	}
	if err := s.flowController.UpdateHighestReceived(frame.FinalSize, true, now); err != nil {
		return err
	}
	s.finalOffset = frame.FinalSize

	// senders are allowed to reduce the reliable size, but frames might have been reordered
	if (!s.cancelledRemotely && s.reliableSize == 0) || frame.ReliableSize < s.reliableSize {
		s.reliableSize = frame.ReliableSize
	}
	if s.readPos >= s.reliableSize {
		// calling Abandon multiple times is a no-op
		s.flowController.Abandon()
	}
	// ignore duplicate RESET_STREAM frames for this stream (after checking their final offset)
	if s.cancelledRemotely {
		return nil
	}

	// don't save the error if the RESET_STREAM frame was received after CancelRead was called
	if s.cancelledLocally {
		return nil
	}
	s.cancelledRemotely = true
	s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: frame.ErrorCode, Remote: true}
	s.signalRead()
	return nil
}

// getControlFrame returns the next pending control frame for this stream.
// STOP_SENDING takes precedence over MAX_STREAM_DATA; hasMore reports whether
// another control frame is still queued after the returned one.
func (s *ReceiveStream) getControlFrame(now monotime.Time) (_ ackhandler.Frame, ok, hasMore bool) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if !s.queuedStopSending && !s.queuedMaxStreamData {
		return ackhandler.Frame{}, false, false
	}
	if s.queuedStopSending {
		s.queuedStopSending = false
		return ackhandler.Frame{
			Frame: &wire.StopSendingFrame{StreamID: s.streamID, ErrorCode: s.cancelErr.ErrorCode},
		}, true, s.queuedMaxStreamData
	}

	s.queuedMaxStreamData = false
	return ackhandler.Frame{
		Frame: &wire.MaxStreamDataFrame{
			StreamID:          s.streamID,
			MaximumStreamData: s.flowController.GetWindowUpdate(now),
		},
	}, true, false
}

// SetReadDeadline sets the
deadline for future Read calls and // any currently-blocked Read call. // A zero value for t means Read will not time out. func (s *ReceiveStream) SetReadDeadline(t time.Time) error { s.mutex.Lock() s.deadline = monotime.FromTime(t) s.mutex.Unlock() s.signalRead() return nil } // CloseForShutdown closes a stream abruptly. // It makes Read unblock (and return the error) immediately. // The peer will NOT be informed about this: the stream is closed without sending a FIN or RESET. func (s *ReceiveStream) closeForShutdown(err error) { s.mutex.Lock() s.closeForShutdownErr = err s.mutex.Unlock() s.signalRead() } // signalRead performs a non-blocking send on the readChan func (s *ReceiveStream) signalRead() { select { case s.readChan <- struct{}{}: default: } } quic-go-0.59.0/receive_stream_test.go000066400000000000000000001132451513066070600175460ustar00rootroot00000000000000package quic import ( "fmt" "io" "os" "sync/atomic" "testing" "time" "github.com/quic-go/quic-go/internal/mocks" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) type readerWithTimeout struct { io.Reader Timeout time.Duration } func (r *readerWithTimeout) Read(p []byte) (n int, err error) { done := make(chan struct{}) go func() { defer close(done) n, err = r.Reader.Read(p) }() select { case <-done: return n, err case <-time.After(r.Timeout): return 0, fmt.Errorf("read timeout after %s", r.Timeout) } } type peeker interface { Peek(b []byte) (int, error) } type peekerWithTimeout struct { Peeker peeker Timeout time.Duration } func (p *peekerWithTimeout) Peek(b []byte) (n int, err error) { done := make(chan struct{}) go func() { defer close(done) n, err = p.Peeker.Peek(b) }() select { case <-done: return n, err case <-time.After(p.Timeout): return 0, fmt.Errorf("peek 
timeout after %s", p.Timeout) } } func TestReceiveStreamReadData(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) str := newReceiveStream(42, nil, mockFC) // read an entire frame now := monotime.Now() mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false, now) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte{0xde, 0xad, 0xbe, 0xef}}, now)) b := make([]byte, 4) n, err := (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(b) require.NoError(t, err) require.Equal(t, 4, n) require.Equal(t, []byte{0xde, 0xad, 0xbe, 0xef}, b) // split a frame across multiple reads mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), false, now) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2) require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Offset: 4, Data: []byte{0xca, 0xfe, 0xba, 0xbe}}, now)) b = make([]byte, 2) n, err = (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(b) require.NoError(t, err) require.Equal(t, 2, n) require.Equal(t, []byte{0xca, 0xfe}, b) n, err = (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(b) require.NoError(t, err) require.Equal(t, 2, n) require.Equal(t, []byte{0xba, 0xbe}, b) // combine two frames gomock.InOrder( mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(11), false, now), mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(14), false, now), mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3)).Times(2), ) require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Offset: 8, Data: []byte{'f', 'o', 'o'}}, now)) require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Offset: 11, Data: []byte{'b', 'a', 'r'}}, now)) b = make([]byte, 6) n, err = (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(b) require.NoError(t, err) require.Equal(t, 6, n) require.Equal(t, []byte{'f', 'o', 'o', 'b', 'a', 'r'}, b) // reordered frames 
gomock.InOrder(
	mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(20), false, now),
	mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(17), false, now),
	mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3)).Times(2),
)
require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Offset: 17, Data: []byte{'b', 'a', 'z'}}, now))
require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Offset: 14, Data: []byte{'f', 'o', 'o'}}, now))
b = make([]byte, 6)
n, err = (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(b)
require.NoError(t, err)
require.Equal(t, 6, n)
require.Equal(t, []byte{'f', 'o', 'o', 'b', 'a', 'z'}, b)
}

// TestReceiveStreamPeekData checks that Peek returns buffered data without
// consuming it: repeated peeks see the same bytes, and only Read advances the
// position Peek starts from.
func TestReceiveStreamPeekData(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	mockFC := mocks.NewMockStreamFlowController(mockCtrl)
	mockFC.EXPECT().UpdateHighestReceived(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
	mockFC.EXPECT().AddBytesRead(gomock.Any()).AnyTimes()
	str := newReceiveStream(42, nil, mockFC)
	require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foo")}, monotime.Now()))
	b := make([]byte, 2)
	n, err := (&peekerWithTimeout{Peeker: str, Timeout: time.Second}).Peek(b)
	require.NoError(t, err)
	require.Equal(t, 2, n)
	require.Equal(t, []byte("fo"), b)
	require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("bar"), Offset: 3}, monotime.Now()))
	// peeking again still starts at offset 0, now spanning both frames
	b = make([]byte, 6)
	n, err = (&peekerWithTimeout{Peeker: str, Timeout: time.Second}).Peek(b)
	require.NoError(t, err)
	require.Equal(t, 6, n)
	require.Equal(t, []byte("foobar"), b)
	// consume "fo"; subsequent peeks start after the consumed bytes
	_, err = str.Read([]byte{0, 0})
	require.NoError(t, err)
	b = make([]byte, 2)
	n, err = (&peekerWithTimeout{Peeker: str, Timeout: time.Second}).Peek(b)
	require.NoError(t, err)
	require.Equal(t, 2, n)
	require.Equal(t, []byte("ob"), b)
	b = make([]byte, 4)
	n, err = (&peekerWithTimeout{Peeker: str, Timeout: time.Second}).Peek(b)
	require.NoError(t, err)
	require.Equal(t, 4, n)
	require.Equal(t, []byte("obar"), b)
}

// TestReceiveStreamBlockRead checks that Read blocks until a STREAM frame
// arrives. Inside synctest.Test the clock is virtual, which is what allows the
// test to assert that exactly one hour elapsed.
func TestReceiveStreamBlockRead(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		mockCtrl := gomock.NewController(t)
		mockFC := mocks.NewMockStreamFlowController(mockCtrl)
		mockSender := NewMockStreamSender(mockCtrl)
		str := newReceiveStream(42, mockSender, mockFC)
		mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false, gomock.Any())
		mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
		errChan := make(chan error, 1)
		start := monotime.Now()
		go func() {
			frame := &wire.StreamFrame{Data: []byte{0xde, 0xad}}
			time.Sleep(time.Hour)
			errChan <- str.handleStreamFrame(frame, monotime.Now())
		}()
		n, err := (&readerWithTimeout{Reader: str, Timeout: 2 * time.Hour}).Read(make([]byte, 2))
		require.NoError(t, err)
		require.Equal(t, 2, n)
		require.Equal(t, time.Hour, monotime.Since(start))
		require.NoError(t, <-errChan)
	})
}

// TestReceiveStreamBlockPeek checks that Peek blocks until the requested number
// of bytes is available, both when they arrive in one frame and in two.
func TestReceiveStreamBlockPeek(t *testing.T) {
	t.Run("single STREAM frame", func(t *testing.T) {
		testReceiveStreamBlockPeek(t, false)
	})
	t.Run("multiple STREAM frames", func(t *testing.T) {
		testReceiveStreamBlockPeek(t, true)
	})
}

// testReceiveStreamBlockPeek delivers "foobar" either as two 3-byte frames
// 30 minutes apart (smallWrites) or as one 6-byte frame after an hour. In both
// cases a 6-byte Peek must only return once all bytes are there, i.e. after
// one hour of (virtual) time.
func testReceiveStreamBlockPeek(t *testing.T, smallWrites bool) {
	synctest.Test(t, func(t *testing.T) {
		mockCtrl := gomock.NewController(t)
		mockFC := mocks.NewMockStreamFlowController(mockCtrl)
		mockFC.EXPECT().UpdateHighestReceived(gomock.Any(), false, gomock.Any()).AnyTimes()
		mockSender := NewMockStreamSender(mockCtrl)
		str := newReceiveStream(42, mockSender, mockFC)
		// capacity 2: the smallWrites path sends two results
		errChan := make(chan error, 2)
		start := monotime.Now()
		go func() {
			if smallWrites {
				time.Sleep(30 * time.Minute)
				errChan <- str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foo")}, monotime.Now())
				time.Sleep(30 * time.Minute)
				errChan <- str.handleStreamFrame(&wire.StreamFrame{Offset: 3, Data: []byte("bar")}, monotime.Now())
			} else {
				time.Sleep(time.Hour)
				errChan <- str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, monotime.Now())
			}
		}()
		b := make([]byte, 6)
		n, err := (&peekerWithTimeout{Peeker: str, Timeout: 2 * time.Hour}).Peek(b)
		require.NoError(t, err)
		require.Equal(t, 6, n)
		require.Equal(t, []byte("foobar"), b)
		require.Equal(t, time.Hour, monotime.Since(start))
		require.NoError(t, <-errChan)
		if smallWrites {
			require.NoError(t, <-errChan)
		}
	})
}

// TestReceiveStreamReadOverlappingData checks that duplicate and overlapping
// STREAM frames are reassembled correctly, and that each byte is reported to
// the flow controller via AddBytesRead only once.
func TestReceiveStreamReadOverlappingData(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	mockFC := mocks.NewMockStreamFlowController(mockCtrl)
	str := newReceiveStream(42, nil, mockFC)
	// receive the same frame multiple times
	now := monotime.Now()
	mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false, now).Times(3)
	mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
	for range 3 {
		require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte{0xde, 0xad, 0xbe, 0xef}}, now))
	}
	b := make([]byte, 4)
	n, err := (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(b)
	require.NoError(t, err)
	require.Equal(t, 4, n)
	require.Equal(t, []byte{0xde, 0xad, 0xbe, 0xef}, b)
	// receive overlapping data
	gomock.InOrder(
		mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(8), false, now),
		mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), false, now),
		mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)),
		mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)),
	)
	require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Offset: 4, Data: []byte("foob")}, now))
	require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Offset: 6, Data: []byte("obar")}, now))
	b = make([]byte, 6)
	n, err = (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(b)
	require.NoError(t, err)
	require.Equal(t, 6, n)
	require.Equal(t, []byte("foobar"), b)
}

// TestReceiveStreamFlowControlUpdates checks how the stream reacts to the two
// results of AddBytesRead: a stream-level or a connection-level window update.
func TestReceiveStreamFlowControlUpdates(t *testing.T) {
	t.Run("stream", func(t *testing.T) {
		testReceiveStreamFlowControlUpdates(t, true, false)
	})
	t.Run("connection", func(t *testing.T) {
		testReceiveStreamFlowControlUpdates(t, false, true)
	})
}

// testReceiveStreamFlowControlUpdates makes AddBytesRead return the given
// flags. A stream window update must notify the sender via
// onHasStreamControlFrame and later yield a MAX_STREAM_DATA frame from
// getControlFrame. A connection window update must call onHasConnectionData
// and must not produce a stream control frame.
func testReceiveStreamFlowControlUpdates(t *testing.T, hasStreamWindowUpdate, hasConnWindowUpdate bool) {
	const streamID protocol.StreamID = 42
	mockCtrl := gomock.NewController(t)
	mockFC := mocks.NewMockStreamFlowController(mockCtrl)
	mockSender := NewMockStreamSender(mockCtrl)
	str := newReceiveStream(streamID, mockSender, mockFC)
	now := monotime.Now()
	mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), false, now)
	require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte{0xde, 0xad, 0xbe, 0xef}}, now))
	mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3)).Return(hasStreamWindowUpdate, hasConnWindowUpdate)
	if hasStreamWindowUpdate {
		mockSender.EXPECT().onHasStreamControlFrame(streamID, str)
	}
	if hasConnWindowUpdate {
		mockSender.EXPECT().onHasConnectionData()
	}
	n, err := (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(make([]byte, 3))
	require.NoError(t, err)
	require.Equal(t, 3, n)
	require.True(t, mockCtrl.Satisfied())
	if hasStreamWindowUpdate {
		now = now.Add(time.Second)
		mockFC.EXPECT().GetWindowUpdate(now).Return(protocol.ByteCount(1337))
		f, ok, hasMore := str.getControlFrame(now)
		require.True(t, ok)
		require.Equal(t, &wire.MaxStreamDataFrame{StreamID: streamID, MaximumStreamData: 1337}, f.Frame)
		require.False(t, hasMore)
	}
	if hasConnWindowUpdate {
		_, ok, hasMore := str.getControlFrame(now)
		require.False(t, ok)
		require.False(t, hasMore)
	}
}

// TestReceiveStreamDeadlineInThePast checks Read and Peek against a read
// deadline that has already expired.
func TestReceiveStreamDeadlineInThePast(t *testing.T) {
	t.Run("read", func(t *testing.T) {
		testReceiveStreamDeadlineInThePast(t, true, func(str *ReceiveStream, b []byte) (int, error) { return str.Read(b) })
	})
	t.Run("peek", func(t *testing.T) {
		testReceiveStreamDeadlineInThePast(t, false, func(str *ReceiveStream, b []byte) (int, error) { return str.Peek(b) })
	})
}

// testReceiveStreamDeadlineInThePast runs op with an expired deadline (it must
// fail with errDeadline and return no data), then with a future deadline (it
// must return the buffered data). consumesBytes is true for Read, which
// reports the consumed bytes to the flow controller, and false for Peek.
func testReceiveStreamDeadlineInThePast(t *testing.T, consumesBytes bool, op func(*ReceiveStream, []byte) (int, error)) {
	mockCtrl := gomock.NewController(t)
	mockFC := mocks.NewMockStreamFlowController(mockCtrl)
	str := newReceiveStream(42, nil, mockFC)
	// no data is read when the deadline is in the past
	mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any()).AnyTimes()
	require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, monotime.Now()))
	require.NoError(t, str.SetReadDeadline(time.Now().Add(-time.Second)))
	b := make([]byte, 6)
	n, err := op(str, b)
	require.Error(t, err)
	require.Zero(t, n)
	require.ErrorIs(t, err, errDeadline)
	// data is read when the deadline is in the future
	require.NoError(t, str.SetReadDeadline(time.Now().Add(time.Second)))
	if consumesBytes {
		mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6))
	}
	n, err = op(str, b)
	require.NoError(t, err)
	require.Equal(t, 6, n)
}

// TestReceiveStreamDeadlineRemoval checks that clearing the read deadline
// (setting the zero time) lets a blocked Read or Peek wait indefinitely.
func TestReceiveStreamDeadlineRemoval(t *testing.T) {
	t.Run("read", func(t *testing.T) {
		testReceiveStreamDeadlineRemoval(t, func(str *ReceiveStream) error {
			_, err := str.Read([]byte{0})
			return err
		})
	})
	t.Run("peek", func(t *testing.T) {
		testReceiveStreamDeadlineRemoval(t, func(str *ReceiveStream) error {
			_, err := str.Peek([]byte{0})
			return err
		})
	})
}

// testReceiveStreamDeadlineRemoval blocks op on an empty stream. It removes
// the deadline before it fires, verifies op keeps blocking past the original
// deadline, then sets a past deadline and expects os.ErrDeadlineExceeded.
func testReceiveStreamDeadlineRemoval(t *testing.T, op func(*ReceiveStream) error) {
	synctest.Test(t, func(t *testing.T) {
		mockCtrl := gomock.NewController(t)
		mockFC := mocks.NewMockStreamFlowController(mockCtrl)
		str := newReceiveStream(42, nil, mockFC)
		const deadline = time.Minute
		require.NoError(t, str.SetReadDeadline(time.Now().Add(deadline)))
		errChan := make(chan error, 1)
		go func() { errChan <- op(str) }()
		select {
		case err := <-errChan:
			t.Fatalf("should not have returned yet: %v", err)
		case <-time.After(deadline / 2):
		}
		// remove the deadline after a while (but before it expires)
		require.NoError(t, str.SetReadDeadline(time.Time{}))
		// no deadline set: should not return at all
		select {
		case err := <-errChan:
			t.Fatalf("should not have returned yet: %v", err)
		case <-time.After(2 * deadline):
		}
		// now set the deadline to the past to make it return immediately
		require.NoError(t, str.SetReadDeadline(time.Now().Add(-time.Second)))
		// wait until the blocked goroutine has reacted before checking errChan
		synctest.Wait()
		select {
		case err := <-errChan:
			require.ErrorIs(t, err, os.ErrDeadlineExceeded)
		default:
			t.Fatal("timeout")
		}
	})
}

// TestReceiveStreamDeadlineExtension checks that moving the read deadline
// further out while Read or Peek is blocked postpones its expiry.
func TestReceiveStreamDeadlineExtension(t *testing.T) {
	t.Run("read", func(t *testing.T) {
		testReceiveStreamDeadlineExtension(t, func(str *ReceiveStream) error {
			_, err := str.Read([]byte{0})
			return err
		})
	})
	t.Run("peek", func(t *testing.T) {
		testReceiveStreamDeadlineExtension(t, func(str *ReceiveStream) error {
			_, err := str.Peek([]byte{0})
			return err
		})
	})
}

// testReceiveStreamDeadlineExtension re-arms a 5s deadline after 2.5s, so op
// must fail with os.ErrDeadlineExceeded exactly 7.5s after the start.
func testReceiveStreamDeadlineExtension(t *testing.T, op func(*ReceiveStream) error) {
	synctest.Test(t, func(t *testing.T) {
		mockCtrl := gomock.NewController(t)
		mockFC := mocks.NewMockStreamFlowController(mockCtrl)
		str := newReceiveStream(42, nil, mockFC)
		start := monotime.Now()
		deadline := 5 * time.Second
		require.NoError(t, str.SetReadDeadline(time.Now().Add(deadline)))
		errChan := make(chan error, 1)
		go func() { errChan <- op(str) }()
		select {
		case err := <-errChan:
			t.Fatalf("should not have returned yet: %v", err)
		case <-time.After(deadline / 2):
		}
		// extend the deadline
		require.NoError(t, str.SetReadDeadline(time.Now().Add(deadline)))
		select {
		case err := <-errChan:
			require.ErrorIs(t, err, os.ErrDeadlineExceeded)
			require.Equal(t, start.Add(deadline*3/2), monotime.Now())
		case <-time.After(deadline + time.Nanosecond):
			t.Fatal("timeout")
		}
	})
}

// TestReceiveStreamEOFWithData checks EOF handling when the FIN frame arrives
// before the preceding data: Peek only reports io.EOF when asked for more than
// the final size, and Read returns the remaining data together with io.EOF.
func TestReceiveStreamEOFWithData(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	mockFC := mocks.NewMockStreamFlowController(mockCtrl)
	mockSender := NewMockStreamSender(mockCtrl)
	str := newReceiveStream(42, mockSender, mockFC)
	now := monotime.Now()
	mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true, now)
	mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), false, now)
	require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Offset: 2, Data: []byte{0xbe, 0xef}, Fin: true}, now))
	require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte{0xde, 0xad}}, now))
	mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)).Times(2)
	mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
	// peeking doesn't return an EOF
	b := make([]byte, 4)
	n, err := (&peekerWithTimeout{Peeker: str, Timeout: time.Second}).Peek(b)
	require.NoError(t, err)
	require.Equal(t, 4, n)
	require.Equal(t, []byte{0xde, 0xad, 0xbe, 0xef}, b)
	// peeking returns the EOF, if more data is being peeked
	b = make([]byte, 6)
	n, err = (&peekerWithTimeout{Peeker: str, Timeout: time.Second}).Peek(b)
	require.ErrorIs(t, err, io.EOF)
	require.Equal(t, 4, n)
	require.Equal(t, []byte{0xde, 0xad, 0xbe, 0xef}, b[:n])
	// reading returns the EOF
	strWithTimeout := &readerWithTimeout{Reader: str, Timeout: time.Second}
	b = make([]byte, 6)
	n, err = strWithTimeout.Read(b)
	require.ErrorIs(t, err, io.EOF)
	require.Equal(t, 4, n)
	require.Equal(t, []byte{0xde, 0xad, 0xbe, 0xef}, b[:n])
	n, err = strWithTimeout.Read(b)
	require.Zero(t, n)
	require.ErrorIs(t, err, io.EOF)
}

// TestReceiveStreamPeekEOF checks that a blocked Peek waits for the gap before
// the FIN to be filled, for peeks longer than and equal to the final size.
func TestReceiveStreamPeekEOF(t *testing.T) {
	t.Run("long peek", func(t *testing.T) {
		testReceiveStreamPeekEOF(t, true)
	})
	t.Run("exact peek", func(t *testing.T) {
		testReceiveStreamPeekEOF(t, false)
	})
}

// testReceiveStreamPeekEOF receives "bar"+FIN at offset 3 first, then fills
// the gap with "f" and "oo". Peek must stay blocked until all 6 bytes are
// present. It then returns io.EOF only when it asked for more (8) than the
// final size (6).
func testReceiveStreamPeekEOF(t *testing.T, longPeek bool) {
	synctest.Test(t, func(t *testing.T) {
		mockCtrl := gomock.NewController(t)
		mockFC := mocks.NewMockStreamFlowController(mockCtrl)
		mockSender := NewMockStreamSender(mockCtrl)
		str := newReceiveStream(42, mockSender, mockFC)
		mockFC.EXPECT().UpdateHighestReceived(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
		require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Offset: 3, Data: []byte("bar"), Fin: true}, monotime.Now()))
		type result struct {
			err  error
			data []byte
		}
		resultChan := make(chan result, 1)
		go func() {
			b := make([]byte, 6)
			if longPeek {
				b = make([]byte, 8)
			}
			n, err := (&peekerWithTimeout{Peeker: str, Timeout: time.Hour}).Peek(b)
			resultChan <- result{err: err, data: b[:n]}
		}()
		synctest.Wait()
		select {
		case result := <-resultChan:
			t.Fatalf("peek should not have returned yet: %v", result.err)
		default:
		}
		require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("f")}, monotime.Now()))
		synctest.Wait()
		select {
		case result := <-resultChan:
			t.Fatalf("peek should not have returned yet: %v", result.err)
		default:
		}
		require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("oo"), Offset: 1}, monotime.Now()))
		synctest.Wait()
		select {
		case result := <-resultChan:
			if longPeek {
				assert.ErrorIs(t, result.err, io.EOF)
			} else {
				assert.NoError(t, result.err)
			}
			require.Equal(t, []byte("foobar"), result.data)
		default:
			t.Fatal("peek should have returned")
		}
	})
}

// TestReceiveStreamImmediateFINs checks a stream whose only frame is an empty
// FIN: both Peek and Read return io.EOF with zero bytes.
func TestReceiveStreamImmediateFINs(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	mockFC := mocks.NewMockStreamFlowController(mockCtrl)
	mockSender := NewMockStreamSender(mockCtrl)
	str := newReceiveStream(42, mockSender, mockFC)
	mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true, gomock.Any())
	require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Fin: true}, monotime.Now()))
	mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
	// peeking returns the EOF
	n, err := (&peekerWithTimeout{Peeker: str, Timeout: time.Second}).Peek(make([]byte, 4))
	require.ErrorIs(t, err, io.EOF)
	require.Equal(t, 0, n)
	// and so does reading
	n, err = (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(make([]byte, 4))
	require.Zero(t, n)
	require.ErrorIs(t, err, io.EOF)
}

// TestReceiveStreamCloseForShutdown checks that closeForShutdown unblocks
// pending Read/Peek calls with the given error. Later Read/Peek calls keep
// returning that error, and a later RESET_STREAM or CancelRead does not change it.
func TestReceiveStreamCloseForShutdown(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		mockCtrl := gomock.NewController(t)
		mockFC := mocks.NewMockStreamFlowController(mockCtrl)
		mockSender := NewMockStreamSender(mockCtrl)
		str := newReceiveStream(42, mockSender, mockFC)
		strWithTimeout := &readerWithTimeout{Reader: str, Timeout: time.Minute}
		// Test immediate return of reads
		readErrChan := make(chan error, 1)
		peekErrChan := make(chan error, 1)
		go func() {
			_, err := strWithTimeout.Read([]byte{0})
			readErrChan <- err
		}()
		go func() {
			_, err := (&peekerWithTimeout{Peeker: str, Timeout: time.Minute}).Peek([]byte{0})
			peekErrChan <- err
		}()
		synctest.Wait()
		select {
		case err := <-readErrChan:
			t.Fatalf("read returned before closeForShutdown: %v", err)
		case err := <-peekErrChan:
			t.Fatalf("peek returned before closeForShutdown: %v", err)
		default:
		}
		str.closeForShutdown(assert.AnError)
		synctest.Wait()
		select {
		case err := <-readErrChan:
			require.ErrorIs(t, err, assert.AnError)
		default:
			t.Fatal("read should have returned")
		}
		select {
		case err := <-peekErrChan:
			require.ErrorIs(t, err, assert.AnError)
		default:
			t.Fatal("peek should have returned")
		}
		// following calls to Peek should return the error
		n, err := (&peekerWithTimeout{Peeker: str, Timeout: time.Minute}).Peek([]byte{0})
		require.Zero(t, n)
		require.ErrorIs(t, err, assert.AnError)
		// following calls to Read should return the error
		n, err = strWithTimeout.Read([]byte{0})
		require.Zero(t, n)
		require.ErrorIs(t, err, assert.AnError)
		// receiving a RESET_STREAM frame after closeForShutdown does nothing
		require.NoError(t, str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1234, FinalSize: 42}, monotime.Now()))
		n, err = strWithTimeout.Read([]byte{0})
		require.Zero(t, n)
		require.ErrorIs(t, err, assert.AnError)
		// calling CancelRead after closeForShutdown does nothing
		str.CancelRead(1234)
		n, err = strWithTimeout.Read([]byte{0})
		require.Zero(t, n)
		require.ErrorIs(t, err, assert.AnError)
	})
}

// TestReceiveStreamCancellation checks CancelRead. It queues a STOP_SENDING
// frame and unblocks pending Read/Peek calls with a local StreamError. Later
// CancelRead calls, FIN frames and RESET_STREAM frames do not change the
// reported error.
func TestReceiveStreamCancellation(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		mockCtrl := gomock.NewController(t)
		mockFC := mocks.NewMockStreamFlowController(mockCtrl)
		mockSender := NewMockStreamSender(mockCtrl)
		str := newReceiveStream(42, mockSender, mockFC)
		strWithTimeout := &readerWithTimeout{Reader: str, Timeout: 2 * time.Second}
		// expected exactly once; the repeated CancelRead calls below must not trigger it again
		mockSender.EXPECT().onHasStreamControlFrame(str.StreamID(), gomock.Any())
		readErrChan := make(chan error, 1)
		peekErrChan := make(chan error, 1)
		go func() {
			_, err := strWithTimeout.Read([]byte{0})
			readErrChan <- err
		}()
		go func() {
			_, err := (&peekerWithTimeout{Peeker: str, Timeout: 2 * time.Second}).Peek([]byte{0})
			peekErrChan <- err
		}()
		synctest.Wait()
		str.CancelRead(1234)
		// this queues a STOP_SENDING frame
		f, ok, hasMore := str.getControlFrame(monotime.Now())
		require.True(t, ok)
		require.Equal(t, &wire.StopSendingFrame{StreamID: 42, ErrorCode: 1234}, f.Frame)
		require.False(t, hasMore)
		require.True(t, mockCtrl.Satisfied())
		synctest.Wait()
		select {
		case err := <-readErrChan:
			require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: false})
		default:
			t.Fatal("Read was not unblocked")
		}
		select {
		case err := <-peekErrChan:
			require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: false})
		default:
			t.Fatal("Peek was not unblocked")
		}
		// further calls to Peek return the error
		n, err := (&peekerWithTimeout{Peeker: str, Timeout: 2 * time.Second}).Peek([]byte{0})
		require.Zero(t, n)
		require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: false})
		// further Read calls return the error
		n, err = strWithTimeout.Read([]byte{0})
		require.Zero(t, n)
		require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: false})
		// calling CancelRead again does nothing
		// especially:
		// 1. no more calls to onHasStreamControlFrame
		// 2. no changes of the error code returned by Read
		str.CancelRead(1234)
		str.CancelRead(4321)
		n, err = strWithTimeout.Read([]byte{0})
		require.Zero(t, n)
		// error code unchanged
		require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: false})
		require.True(t, mockCtrl.Satisfied())
		// receiving the FIN bit has no effect
		mockFC.EXPECT().Abandon()
		mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any()).Times(2)
		mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
		// receive two of them, to make sure onStreamCompleted is not called twice
		require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, monotime.Now()))
		require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, monotime.Now()))
		require.True(t, mockCtrl.Satisfied())
		// receiving a RESET_STREAM frame after CancelRead has no effect
		mockFC.EXPECT().Abandon()
		mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true, gomock.Any())
		require.NoError(t, str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 4321, FinalSize: 42}, monotime.Now()))
		n, err = strWithTimeout.Read([]byte{0})
		require.Zero(t, n)
		require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: false})
	})
}

// TestReceiveStreamCancelReadAfterFIN checks CancelRead after the FIN was
// received, both before and after the application read it.
func TestReceiveStreamCancelReadAfterFIN(t *testing.T) {
	t.Run("FIN not read", func(t *testing.T) {
		testReceiveStreamCancelReadAfterFIN(t, false)
	})
	t.Run("FIN read", func(t *testing.T) {
		testReceiveStreamCancelReadAfterFIN(t, true)
	})
}

// testReceiveStreamCancelReadAfterFIN checks both cases.
// If the EOF was already read (finRead), CancelRead queues nothing and
// Read/Peek keep returning io.EOF.
// Otherwise CancelRead abandons flow control, queues STOP_SENDING, and
// Read/Peek return the local StreamError.
func testReceiveStreamCancelReadAfterFIN(t *testing.T, finRead bool) {
	mockCtrl := gomock.NewController(t)
	mockFC := mocks.NewMockStreamFlowController(mockCtrl)
	mockSender := NewMockStreamSender(mockCtrl)
	str := newReceiveStream(42, mockSender, mockFC)
	mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any())
	mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
	require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, monotime.Now()))
	if finRead {
		mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6))
		n, err := str.Read(make([]byte, 10))
		require.ErrorIs(t, err, io.EOF)
		require.Equal(t, 6, n)
	}
	// if the FIN was received, but not read yet, a STOP_SENDING frame is queued
	if !finRead {
		mockFC.EXPECT().Abandon()
		mockSender.EXPECT().onHasStreamControlFrame(str.StreamID(), str)
	}
	str.CancelRead(1337)
	f, ok, hasMore := str.getControlFrame(monotime.Now())
	// if the EOF was already read, no STOP_SENDING frame is queued
	if finRead {
		require.False(t, ok)
		require.False(t, hasMore)
	} else {
		require.True(t, ok)
		require.Equal(t, &wire.StopSendingFrame{StreamID: 42, ErrorCode: 1337}, f.Frame)
		require.False(t, hasMore)
	}
	// Read returns the error...
	n, err := str.Read([]byte{0})
	require.Zero(t, n)
	// ... and Peek returns the same error
	n, peekErr := (&peekerWithTimeout{Peeker: str, Timeout: time.Second}).Peek([]byte{0})
	require.Zero(t, n)
	if finRead {
		assert.ErrorIs(t, err, io.EOF)
		assert.ErrorIs(t, peekErr, io.EOF)
	} else {
		assert.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: false})
		assert.ErrorIs(t, peekErr, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: false})
	}
}

// TestReceiveStreamReset checks that a RESET_STREAM frame unblocks pending
// Read/Peek calls with a remote StreamError. Later RESET_STREAM frames and
// CancelRead calls do not change the error code or the Remote flag.
func TestReceiveStreamReset(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		mockCtrl := gomock.NewController(t)
		mockFC := mocks.NewMockStreamFlowController(mockCtrl)
		mockSender := NewMockStreamSender(mockCtrl)
		str := newReceiveStream(42, mockSender, mockFC)
		strWithTimeout := &readerWithTimeout{Reader: str, Timeout: 2 * time.Second}
		readErrChan := make(chan error, 1)
		peekErrChan := make(chan error, 1)
		go func() {
			_, err := strWithTimeout.Read([]byte{0})
			readErrChan <- err
		}()
		go func() {
			_, err := (&peekerWithTimeout{Peeker: str, Timeout: 2 * time.Second}).Peek([]byte{0})
			peekErrChan <- err
		}()
		synctest.Wait()
		mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
		gomock.InOrder(
			mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true, gomock.Any()),
			mockFC.EXPECT().Abandon().MinTimes(1),
		)
		require.NoError(t, str.handleResetStreamFrame(
			&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1234, FinalSize: 42},
			monotime.Now(),
		))
		synctest.Wait()
		select {
		case err := <-readErrChan:
			require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: true})
		default:
			t.Fatal("Read was not unblocked")
		}
		select {
		case err := <-peekErrChan:
			require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: true})
		default:
			t.Fatal("Peek was not unblocked")
		}
		// further calls to Peek return the error
		n, err := (&peekerWithTimeout{Peeker: str, Timeout: 2 * time.Second}).Peek([]byte{0})
		require.Zero(t, n)
		require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: true})
		// further calls to Read return the error
		_, err = strWithTimeout.Read([]byte{0})
		require.Equal(t, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: true}, err)
		// further RESET_STREAM frames have no effect
		mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true, gomock.Any())
		require.NoError(t, str.handleResetStreamFrame(
			&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 4321, FinalSize: 42},
			monotime.Now(),
		))
		n, err = str.Read([]byte{0})
		require.Zero(t, n)
		// error code unchanged
		require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: true})
		// CancelRead after a RESET_STREAM frame has no effect
		str.CancelRead(100)
		n, err = str.Read([]byte{0})
		require.Zero(t, n)
		// error code and remote flag unchanged
		require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: true})
	})
}

// TestReceiveStreamResetAfterFINRead checks that a RESET_STREAM frame arriving
// after the application has already read the EOF does not complete the stream
// a second time.
func TestReceiveStreamResetAfterFINRead(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	mockFC := mocks.NewMockStreamFlowController(mockCtrl)
	mockSender := NewMockStreamSender(mockCtrl)
	str := newReceiveStream(42, mockSender, mockFC)
	mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any())
	mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
	require.NoError(t, str.handleStreamFrame(
		&wire.StreamFrame{StreamID: 42, Data: []byte("foobar"), Fin: true},
		monotime.Now(),
	))
	mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6))
	n, err := str.Read(make([]byte, 6))
	require.Equal(t, 6, n)
	require.ErrorIs(t, err, io.EOF)
	// make sure that onStreamCompleted was called due to the EOF
	require.True(t, mockCtrl.Satisfied())
	// Now receive a RESET_STREAM frame.
	// We don't expect any more calls to onStreamCompleted.
	mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any())
	mockFC.EXPECT().Abandon()
	require.NoError(t, str.handleResetStreamFrame(
		&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1234, FinalSize: 6},
		monotime.Now(),
	))
	// now read the error
	n, err = str.Read([]byte{0})
	require.Error(t, err)
	require.Zero(t, n)
}

// Calling Read concurrently doesn't make any sense (and is forbidden),
// but we still want to make sure that we don't complete the stream more than once
// if the user misuses our API.
// This would lead to an INTERNAL_ERROR ("tried to delete unknown outgoing stream"),
// which can be hard to debug.
// Note that even without the protection built into the receiveStream, this test
// is very timing-dependent, and would need to run a few hundred times to trigger the failure.
func TestReceiveStreamConcurrentReads(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newReceiveStream(42, mockSender, mockFC) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), gomock.Any(), gomock.Any()).AnyTimes() var bytesRead protocol.ByteCount mockFC.EXPECT().AddBytesRead(gomock.Any()).Do(func(n protocol.ByteCount) (bool, bool) { bytesRead += n return false, false }).AnyTimes() var numCompleted atomic.Int32 mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)).Do(func(protocol.StreamID) { numCompleted.Add(1) }).AnyTimes() const num = 3 errChan := make(chan error, num) for range num { go func() { _, err := str.Read(make([]byte, 8)) errChan <- err }() } require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, monotime.Now())) synctest.Wait() for range num { select { case err := <-errChan: require.ErrorIs(t, err, io.EOF) default: t.Fatal("read should have returned") } } require.Equal(t, protocol.ByteCount(6), bytesRead) require.Equal(t, int32(1), numCompleted.Load()) }) } func TestReceiveStreamResetStreamAtBeforeReadOffset(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newReceiveStream(42, mockSender, mockFC) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any()) require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, monotime.Now())) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3)) b := make([]byte, 3) n, err := str.Read(b) require.NoError(t, err) require.Equal(t, 3, n) require.Equal(t, []byte("foo"), b) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) mockFC.EXPECT().Abandon() str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, 
FinalSize: 10, ReliableSize: 3}, monotime.Now()) require.True(t, mockCtrl.Satisfied()) // Peek returns the error n, err = str.Peek([]byte{0}) require.Zero(t, n) require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true}) // Read returns the error mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) n, err = str.Read([]byte{0}) require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true}) require.Zero(t, n) } func TestReceiveStreamResetStreamAtAfterReadOffset(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newReceiveStream(42, mockSender, mockFC) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any()) require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, monotime.Now())) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) b := make([]byte, 2) n, err := str.Read(b) require.NoError(t, err) require.Equal(t, 2, n) require.Equal(t, []byte("fo"), b) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 6}, monotime.Now()) require.True(t, mockCtrl.Satisfied()) // Peek returns no error when peeking up to the reliable size... b = make([]byte, 4) n, err = (&peekerWithTimeout{Peeker: str, Timeout: time.Second}).Peek(b) require.NoError(t, err) require.Equal(t, 4, n) require.Equal(t, []byte("obar"), b) // ... 
but returns the error when peeking beyond the reliable size b = make([]byte, 5) n, err = (&peekerWithTimeout{Peeker: str, Timeout: time.Second}).Peek(b) require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true}) require.Equal(t, 4, n) require.Equal(t, []byte("obar"), b[:n]) // Read returns the error after reading up to the reliable size b = make([]byte, 2) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) n, err = str.Read(b) require.NoError(t, err) require.Equal(t, 2, n) require.Equal(t, []byte("ob"), b) require.True(t, mockCtrl.Satisfied()) gomock.InOrder( mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)), mockFC.EXPECT().Abandon(), ) mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) n, err = str.Read(b) require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true}) require.Equal(t, 2, n) require.Equal(t, []byte("ar"), b) } func TestReceiveStreamMultipleResetStreamAt(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newReceiveStream(42, mockSender, mockFC) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any()) require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, monotime.Now())) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3)) b := make([]byte, 3) n, err := str.Read(b) require.NoError(t, err) require.Equal(t, 3, n) require.Equal(t, []byte("foo"), b) require.True(t, mockCtrl.Satisfied()) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 6}, monotime.Now()) require.True(t, mockCtrl.Satisfied()) // receiving a reordered RESET_STREAM_AT frame has no effect mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, 
ErrorCode: 1337, FinalSize: 10, ReliableSize: 8}, monotime.Now()) require.True(t, mockCtrl.Satisfied()) // receiving a RESET_STREAM_AT frame with a smaller reliable size is valid mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) mockFC.EXPECT().Abandon() str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 3}, monotime.Now()) // Read returns the error mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) n, err = str.Read(b) require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true}) require.Zero(t, n) } func TestReceiveStreamResetStreamAtAfterResetStream(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newReceiveStream(42, mockSender, mockFC) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any()) require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, monotime.Now())) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3)) b := make([]byte, 3) n, err := str.Read(b) require.NoError(t, err) require.Equal(t, 3, n) require.Equal(t, []byte("foo"), b) require.True(t, mockCtrl.Satisfied()) mockFC.EXPECT().Abandon() mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10}, monotime.Now()) require.True(t, mockCtrl.Satisfied()) // receiving a reordered RESET_STREAM_AT frame has no effect mockFC.EXPECT().Abandon() mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 8}, monotime.Now()) require.True(t, mockCtrl.Satisfied()) // Read returns the error mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) n, err = str.Read(b) require.ErrorIs(t, 
err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true}) require.Zero(t, n) } quic-go-0.59.0/retransmission_queue.go000066400000000000000000000103471513066070600177750ustar00rootroot00000000000000package quic import ( "fmt" "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" ) type framesToRetransmit struct { crypto []*wire.CryptoFrame other []wire.Frame } type retransmissionQueue struct { initial *framesToRetransmit handshake *framesToRetransmit appData framesToRetransmit } func newRetransmissionQueue() *retransmissionQueue { return &retransmissionQueue{ initial: &framesToRetransmit{}, handshake: &framesToRetransmit{}, } } func (q *retransmissionQueue) addInitial(f wire.Frame) { if q.initial == nil { return } if cf, ok := f.(*wire.CryptoFrame); ok { q.initial.crypto = append(q.initial.crypto, cf) return } q.initial.other = append(q.initial.other, f) } func (q *retransmissionQueue) addHandshake(f wire.Frame) { if q.handshake == nil { return } if cf, ok := f.(*wire.CryptoFrame); ok { q.handshake.crypto = append(q.handshake.crypto, cf) return } q.handshake.other = append(q.handshake.other, f) } func (q *retransmissionQueue) addAppData(f wire.Frame) { switch f := f.(type) { case *wire.StreamFrame: panic("STREAM frames are handled with their respective streams.") case *wire.CryptoFrame: q.appData.crypto = append(q.appData.crypto, f) default: q.appData.other = append(q.appData.other, f) } } func (q *retransmissionQueue) HasData(encLevel protocol.EncryptionLevel) bool { //nolint:exhaustive // 0-RTT data is retransmitted in 1-RTT packets. 
switch encLevel { case protocol.EncryptionInitial: return q.initial != nil && (len(q.initial.crypto) > 0 || len(q.initial.other) > 0) case protocol.EncryptionHandshake: return q.handshake != nil && (len(q.handshake.crypto) > 0 || len(q.handshake.other) > 0) case protocol.Encryption1RTT: return len(q.appData.crypto) > 0 || len(q.appData.other) > 0 } return false } func (q *retransmissionQueue) GetFrame(encLevel protocol.EncryptionLevel, maxLen protocol.ByteCount, v protocol.Version) wire.Frame { var r *framesToRetransmit //nolint:exhaustive // 0-RTT data is retransmitted in 1-RTT packets. switch encLevel { case protocol.EncryptionInitial: r = q.initial case protocol.EncryptionHandshake: r = q.handshake case protocol.Encryption1RTT: r = &q.appData } if r == nil { return nil } if len(r.crypto) > 0 { f := r.crypto[0] newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v) if newFrame == nil && !needsSplit { // the whole frame fits r.crypto = r.crypto[1:] return f } if newFrame != nil { // frame was split. Leave the original frame in the queue. return newFrame } } if len(r.other) == 0 { return nil } f := r.other[0] if f.Length(v) > maxLen { return nil } r.other = r.other[1:] return f } func (q *retransmissionQueue) DropPackets(encLevel protocol.EncryptionLevel) { //nolint:exhaustive // Can only drop Initial and Handshake packet number space. 
switch encLevel { case protocol.EncryptionInitial: q.initial = nil case protocol.EncryptionHandshake: q.handshake = nil default: panic(fmt.Sprintf("unexpected encryption level: %s", encLevel)) } } func (q *retransmissionQueue) AckHandler(encLevel protocol.EncryptionLevel) ackhandler.FrameHandler { switch encLevel { case protocol.EncryptionInitial: return (*retransmissionQueueInitialAckHandler)(q) case protocol.EncryptionHandshake: return (*retransmissionQueueHandshakeAckHandler)(q) case protocol.Encryption0RTT, protocol.Encryption1RTT: return (*retransmissionQueueAppDataAckHandler)(q) } return nil } type retransmissionQueueInitialAckHandler retransmissionQueue func (q *retransmissionQueueInitialAckHandler) OnAcked(wire.Frame) {} func (q *retransmissionQueueInitialAckHandler) OnLost(f wire.Frame) { (*retransmissionQueue)(q).addInitial(f) } type retransmissionQueueHandshakeAckHandler retransmissionQueue func (q *retransmissionQueueHandshakeAckHandler) OnAcked(wire.Frame) {} func (q *retransmissionQueueHandshakeAckHandler) OnLost(f wire.Frame) { (*retransmissionQueue)(q).addHandshake(f) } type retransmissionQueueAppDataAckHandler retransmissionQueue func (q *retransmissionQueueAppDataAckHandler) OnAcked(wire.Frame) {} func (q *retransmissionQueueAppDataAckHandler) OnLost(f wire.Frame) { (*retransmissionQueue)(q).addAppData(f) } quic-go-0.59.0/retransmission_queue_test.go000066400000000000000000000110571513066070600210330ustar00rootroot00000000000000package quic import ( "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/require" ) func TestRetransmissionQueueFrames(t *testing.T) { t.Run("Initial", func(t *testing.T) { testRetransmissionQueueFrames(t, protocol.EncryptionInitial) }) t.Run("Handshake", func(t *testing.T) { testRetransmissionQueueFrames(t, protocol.EncryptionHandshake) }) t.Run("1-RTT", func(t *testing.T) { testRetransmissionQueueFrames(t, protocol.Encryption1RTT) }) } func 
testRetransmissionQueueFrames(t *testing.T, encLevel protocol.EncryptionLevel) { q := newRetransmissionQueue() require.False(t, q.HasData(encLevel)) require.Nil(t, q.GetFrame(encLevel, protocol.MaxByteCount, protocol.Version1)) ah := q.AckHandler(encLevel) require.NotNil(t, ah) ah.OnLost(&wire.PingFrame{}) require.True(t, q.HasData(encLevel)) require.Equal(t, &wire.PingFrame{}, q.GetFrame(encLevel, protocol.MaxByteCount, protocol.Version1)) require.False(t, q.HasData(encLevel)) require.Nil(t, q.GetFrame(encLevel, protocol.MaxByteCount, protocol.Version1)) f := &wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} ah.OnLost(f) require.True(t, q.HasData(encLevel)) require.Nil(t, q.GetFrame(encLevel, f.Length(protocol.Version1)-1, protocol.Version1)) require.Equal(t, f, q.GetFrame(encLevel, f.Length(protocol.Version1), protocol.Version1)) require.False(t, q.HasData(encLevel)) if encLevel == protocol.Encryption1RTT { require.Panics(t, func() { ah.OnLost(&wire.StreamFrame{}) }) } } func TestRetransmissionQueueCryptoFrames(t *testing.T) { t.Run("Initial", func(t *testing.T) { testRetransmissionQueueCryptoFrames(t, protocol.EncryptionInitial) }) t.Run("Handshake", func(t *testing.T) { testRetransmissionQueueCryptoFrames(t, protocol.EncryptionHandshake) }) t.Run("1-RTT", func(t *testing.T) { testRetransmissionQueueCryptoFrames(t, protocol.Encryption1RTT) }) } func testRetransmissionQueueCryptoFrames(t *testing.T, encLevel protocol.EncryptionLevel) { q := newRetransmissionQueue() var otherEncLevel protocol.EncryptionLevel switch encLevel { case protocol.EncryptionInitial: otherEncLevel = protocol.EncryptionHandshake case protocol.EncryptionHandshake: otherEncLevel = protocol.Encryption1RTT case protocol.Encryption1RTT: otherEncLevel = protocol.EncryptionInitial } ah := q.AckHandler(encLevel) require.NotNil(t, ah) ah.OnLost(&wire.CryptoFrame{Data: []byte("foobar")}) require.True(t, q.HasData(encLevel)) require.False(t, q.HasData(otherEncLevel)) require.Equal(t, 
&wire.CryptoFrame{Data: []byte("foobar")}, q.GetFrame(encLevel, protocol.MaxByteCount, protocol.Version1)) require.False(t, q.HasData(encLevel)) require.Nil(t, q.GetFrame(encLevel, protocol.MaxByteCount, protocol.Version1)) f := &wire.CryptoFrame{Offset: 100, Data: []byte("foobar")} ah.OnLost(f) ah.OnLost(&wire.PingFrame{}) require.True(t, q.HasData(encLevel)) require.False(t, q.HasData(otherEncLevel)) // the CRYPTO frame wouldn't fit, not even if it was split require.IsType(t, &wire.PingFrame{}, q.GetFrame(encLevel, 2, protocol.Version1)) f1 := q.GetFrame(encLevel, f.Length(protocol.Version1)-3, protocol.Version1) require.NotNil(t, f1) require.IsType(t, &wire.CryptoFrame{}, f1) require.Equal(t, &wire.CryptoFrame{Offset: 100, Data: []byte("foo")}, f1) f2 := q.GetFrame(encLevel, protocol.MaxByteCount, protocol.Version1) require.NotNil(t, f2) require.IsType(t, &wire.CryptoFrame{}, f2) require.Equal(t, &wire.CryptoFrame{Offset: 103, Data: []byte("bar")}, f2) } func TestRetransmissionQueueDropEncLevel(t *testing.T) { q := newRetransmissionQueue() require.Panics(t, func() { q.DropPackets(protocol.Encryption0RTT) }) require.Panics(t, func() { q.DropPackets(protocol.Encryption1RTT) }) t.Run("Initial", func(t *testing.T) { testRetransmissionQueueDropEncLevel(t, protocol.EncryptionInitial) }) t.Run("Handshake", func(t *testing.T) { testRetransmissionQueueDropEncLevel(t, protocol.EncryptionHandshake) }) } func testRetransmissionQueueDropEncLevel(t *testing.T, encLevel protocol.EncryptionLevel) { q := newRetransmissionQueue() ah := q.AckHandler(encLevel) require.NotNil(t, ah) ah.OnLost(&wire.PingFrame{}) ah.OnLost(&wire.CryptoFrame{Data: []byte("foobar")}) require.True(t, q.HasData(encLevel)) q.DropPackets(encLevel) require.False(t, q.HasData(encLevel)) require.Nil(t, q.GetFrame(encLevel, protocol.MaxByteCount, protocol.Version1)) // losing more frame is a no-op ah.OnLost(&wire.CryptoFrame{Data: []byte("foobar")}) ah.OnLost(&wire.PingFrame{}) require.False(t, 
q.HasData(encLevel)) } quic-go-0.59.0/send_conn.go000066400000000000000000000061071513066070600154560ustar00rootroot00000000000000package quic import ( "net" "sync/atomic" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" ) // A sendConn allows sending using a simple Write() on a non-connected packet conn. type sendConn interface { Write(b []byte, gsoSize uint16, ecn protocol.ECN) error WriteTo([]byte, net.Addr) error Close() error LocalAddr() net.Addr RemoteAddr() net.Addr ChangeRemoteAddr(addr net.Addr, info packetInfo) capabilities() connCapabilities } type remoteAddrInfo struct { addr net.Addr oob []byte } type sconn struct { rawConn localAddr net.Addr remoteAddrInfo atomic.Pointer[remoteAddrInfo] logger utils.Logger // If GSO enabled, and we receive a GSO error for this remote address, GSO is disabled. gotGSOError bool // Used to catch the error sometimes returned by the first sendmsg call on Linux, // see https://github.com/golang/go/issues/63322. 
wroteFirstPacket bool } var _ sendConn = &sconn{} func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logger) *sconn { localAddr := c.LocalAddr() if info.addr.IsValid() { if udpAddr, ok := localAddr.(*net.UDPAddr); ok { addrCopy := *udpAddr addrCopy.IP = info.addr.AsSlice() localAddr = &addrCopy } } oob := info.OOB() // increase oob slice capacity, so we can add the UDP_SEGMENT and ECN control messages without allocating l := len(oob) oob = append(oob, make([]byte, 64)...)[:l] sc := &sconn{ rawConn: c, localAddr: localAddr, logger: logger, } sc.remoteAddrInfo.Store(&remoteAddrInfo{ addr: remote, oob: oob, }) return sc } func (c *sconn) Write(p []byte, gsoSize uint16, ecn protocol.ECN) error { ai := c.remoteAddrInfo.Load() err := c.writePacket(p, ai.addr, ai.oob, gsoSize, ecn) if err != nil && isGSOError(err) { // disable GSO for future calls c.gotGSOError = true if c.logger.Debug() { c.logger.Debugf("GSO failed when sending to %s", ai.addr) } // send out the packets one by one for len(p) > 0 { l := len(p) if l > int(gsoSize) { l = int(gsoSize) } if err := c.writePacket(p[:l], ai.addr, ai.oob, 0, ecn); err != nil { return err } p = p[l:] } return nil } return err } func (c *sconn) writePacket(p []byte, addr net.Addr, oob []byte, gsoSize uint16, ecn protocol.ECN) error { _, err := c.WritePacket(p, addr, oob, gsoSize, ecn) if err != nil && !c.wroteFirstPacket && isPermissionError(err) { _, err = c.WritePacket(p, addr, oob, gsoSize, ecn) } c.wroteFirstPacket = true return err } func (c *sconn) WriteTo(b []byte, addr net.Addr) error { _, err := c.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported) return err } func (c *sconn) capabilities() connCapabilities { capabilities := c.rawConn.capabilities() if capabilities.GSO { capabilities.GSO = !c.gotGSOError } return capabilities } func (c *sconn) ChangeRemoteAddr(addr net.Addr, info packetInfo) { c.remoteAddrInfo.Store(&remoteAddrInfo{ addr: addr, oob: info.OOB(), }) } func (c *sconn) RemoteAddr() 
net.Addr { return c.remoteAddrInfo.Load().addr } func (c *sconn) LocalAddr() net.Addr { return c.localAddr } quic-go-0.59.0/send_conn_test.go000066400000000000000000000120331513066070600165100ustar00rootroot00000000000000package quic import ( "net" "net/netip" "runtime" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) // Only if appendUDPSegmentSizeMsg actually appends a message (and isn't only a stub implementation), // GSO is actually supported on this platform. var platformSupportsGSO = len(appendUDPSegmentSizeMsg([]byte{}, 1337)) > 0 func TestSendConnLocalAndRemoteAddress(t *testing.T) { remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} rawConn := NewMockRawConn(gomock.NewController(t)) rawConn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(10, 11, 12, 13), Port: 14}).Times(2) c := newSendConn( rawConn, remoteAddr, packetInfo{addr: netip.AddrFrom4([4]byte{127, 0, 0, 42})}, utils.DefaultLogger, ) require.Equal(t, "127.0.0.42:14", c.LocalAddr().String()) require.Equal(t, remoteAddr, c.RemoteAddr()) // the local raw conn's local address is only used if we don't an address from the packet info c = newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) require.Equal(t, "10.11.12.13:14", c.LocalAddr().String()) } func TestSendConnOOB(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("we don't OOB conn on windows, and no packet info will be available") } remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} rawConn := NewMockRawConn(gomock.NewController(t)) rawConn.EXPECT().LocalAddr() rawConn.EXPECT().capabilities().AnyTimes() pi := packetInfo{addr: netip.IPv6Loopback()} rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, pi.OOB(), uint16(0), protocol.ECT1) require.NotEmpty(t, pi.OOB()) c := newSendConn(rawConn, remoteAddr, pi, utils.DefaultLogger) require.NoError(t, 
c.Write([]byte("foobar"), 0, protocol.ECT1)) } func TestSendConnDetectGSOFailure(t *testing.T) { if !platformSupportsGSO { t.Skip("GSO is not supported on this platform") } remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} rawConn := NewMockRawConn(gomock.NewController(t)) rawConn.EXPECT().LocalAddr() rawConn.EXPECT().capabilities().Return(connCapabilities{GSO: true}).MinTimes(1) c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) gomock.InOrder( rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(4), protocol.ECNCE).Return(0, errGSO), rawConn.EXPECT().WritePacket([]byte("foob"), remoteAddr, gomock.Any(), uint16(0), protocol.ECNCE).Return(4, nil), rawConn.EXPECT().WritePacket([]byte("ar"), remoteAddr, gomock.Any(), uint16(0), protocol.ECNCE).Return(2, nil), ) require.NoError(t, c.Write([]byte("foobar"), 4, protocol.ECNCE)) require.False(t, c.capabilities().GSO) } func TestSendConnSendmsgFailures(t *testing.T) { if runtime.GOOS != "linux" { t.Skip("only Linux exhibits this bug, we don't need to work around it on other platforms") } remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} t.Run("first call to sendmsg fails", func(t *testing.T) { rawConn := NewMockRawConn(gomock.NewController(t)) rawConn.EXPECT().LocalAddr() rawConn.EXPECT().capabilities().AnyTimes() c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) gomock.InOrder( rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), gomock.Any(), protocol.ECNCE).Return(0, errNotPermitted), rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(0), protocol.ECNCE).Return(6, nil), ) require.NoError(t, c.Write([]byte("foobar"), 0, protocol.ECNCE)) }) t.Run("later call to sendmsg fails", func(t *testing.T) { rawConn := NewMockRawConn(gomock.NewController(t)) rawConn.EXPECT().LocalAddr() rawConn.EXPECT().capabilities().AnyTimes() c := newSendConn(rawConn, remoteAddr, 
packetInfo{}, utils.DefaultLogger) rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), gomock.Any(), protocol.ECNCE).Return(0, errNotPermitted).Times(2) require.Error(t, c.Write([]byte("foobar"), 0, protocol.ECNCE)) }) } func TestSendConnRemoteAddrChange(t *testing.T) { ln1 := newUDPConnLocalhost(t) ln2 := newUDPConnLocalhost(t) c := newSendConn( &basicConn{PacketConn: newUDPConnLocalhost(t)}, ln1.LocalAddr(), packetInfo{}, utils.DefaultLogger, ) require.NoError(t, c.Write([]byte("foobar"), 0, protocol.ECNUnsupported)) ln1.SetReadDeadline(time.Now().Add(time.Second)) b := make([]byte, 1024) n, err := ln1.Read(b) require.NoError(t, err) require.Equal(t, "foobar", string(b[:n])) require.NoError(t, c.WriteTo([]byte("foobaz"), ln2.LocalAddr())) ln2.SetReadDeadline(time.Now().Add(time.Second)) b = make([]byte, 1024) n, err = ln2.Read(b) require.NoError(t, err) require.Equal(t, "foobaz", string(b[:n])) c.ChangeRemoteAddr(ln2.LocalAddr(), packetInfo{}) require.NoError(t, c.Write([]byte("lorem ipsum"), 0, protocol.ECNUnsupported)) ln2.SetReadDeadline(time.Now().Add(time.Second)) b = make([]byte, 1024) n, err = ln2.Read(b) require.NoError(t, err) require.Equal(t, "lorem ipsum", string(b[:n])) } quic-go-0.59.0/send_queue.go000066400000000000000000000051031513066070600156400ustar00rootroot00000000000000package quic import ( "net" "github.com/quic-go/quic-go/internal/protocol" ) type sender interface { Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN) SendProbe(*packetBuffer, net.Addr) Run() error WouldBlock() bool Available() <-chan struct{} Close() } type queueEntry struct { buf *packetBuffer gsoSize uint16 ecn protocol.ECN } type sendQueue struct { queue chan queueEntry closeCalled chan struct{} // runStopped when Close() is called runStopped chan struct{} // runStopped when the run loop returns available chan struct{} conn sendConn } var _ sender = &sendQueue{} const sendQueueCapacity = 8 func newSendQueue(conn sendConn) sender { return 
&sendQueue{ conn: conn, runStopped: make(chan struct{}), closeCalled: make(chan struct{}), available: make(chan struct{}, 1), queue: make(chan queueEntry, sendQueueCapacity), } } // Send sends out a packet. It's guaranteed to not block. // Callers need to make sure that there's actually space in the send queue by calling WouldBlock. // Otherwise Send will panic. func (h *sendQueue) Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN) { select { case h.queue <- queueEntry{buf: p, gsoSize: gsoSize, ecn: ecn}: // clear available channel if we've reached capacity if len(h.queue) == sendQueueCapacity { select { case <-h.available: default: } } case <-h.runStopped: default: panic("sendQueue.Send would have blocked") } } func (h *sendQueue) SendProbe(p *packetBuffer, addr net.Addr) { h.conn.WriteTo(p.Data, addr) } func (h *sendQueue) WouldBlock() bool { return len(h.queue) == sendQueueCapacity } func (h *sendQueue) Available() <-chan struct{} { return h.available } func (h *sendQueue) Run() error { defer close(h.runStopped) var shouldClose bool for { if shouldClose && len(h.queue) == 0 { return nil } select { case <-h.closeCalled: h.closeCalled = nil // prevent this case from being selected again // make sure that all queued packets are actually sent out shouldClose = true case e := <-h.queue: if err := h.conn.Write(e.buf.Data, e.gsoSize, e.ecn); err != nil { // This additional check enables: // 1. Checking for "datagram too large" message from the kernel, as such, // 2. Path MTU discovery,and // 3. Eventual detection of loss PingFrame. 
if !isSendMsgSizeErr(err) { return err } } e.buf.Release() select { case h.available <- struct{}{}: default: } } } } func (h *sendQueue) Close() { close(h.closeCalled) // wait until the run loop returned <-h.runStopped } quic-go-0.59.0/send_queue_test.go000066400000000000000000000103011513066070600166730ustar00rootroot00000000000000package quic import ( "net" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/synctest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) func getPacketWithContents(b []byte) *packetBuffer { buf := getPacketBuffer() buf.Data = buf.Data[:len(b)] copy(buf.Data, b) return buf } func TestSendQueueSendOnePacket(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) c := NewMockSendConn(mockCtrl) q := newSendQueue(c) written := make(chan struct{}) c.EXPECT().Write([]byte("foobar"), uint16(10), protocol.ECT1).Do( func([]byte, uint16, protocol.ECN) error { close(written); return nil }, ) done := make(chan struct{}) go func() { q.Run() close(done) }() q.Send(getPacketWithContents([]byte("foobar")), 10, protocol.ECT1) synctest.Wait() select { case <-written: default: t.Fatal("write should have returned") } q.Close() synctest.Wait() select { case <-done: default: t.Fatal("Run should have returned") } }) } func TestSendQueueBlocking(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) c := NewMockSendConn(mockCtrl) q := newSendQueue(c) blockWrite := make(chan struct{}) written := make(chan struct{}, 1) c.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do( func([]byte, uint16, protocol.ECN) error { select { case written <- struct{}{}: default: } <-blockWrite return nil }, ).AnyTimes() done := make(chan struct{}) go func() { q.Run() close(done) }() // +1, since one packet will be queued in the Write call for i := range sendQueueCapacity + 1 { require.False(t, 
q.WouldBlock()) q.Send(getPacketWithContents([]byte("foobar")), 10, protocol.ECT1) // make sure that the first packet is actually enqueued in the Write call if i == 0 { select { case <-written: case <-time.After(time.Second): t.Fatal("timeout") } } } require.True(t, q.WouldBlock()) select { case <-q.Available(): t.Fatal("should not be available") default: } require.Panics(t, func() { q.Send(getPacketWithContents([]byte("foobar")), 10, protocol.ECT1) }) // allow one packet to be sent blockWrite <- struct{}{} select { case <-written: case <-time.After(time.Second): t.Fatal("timeout") } select { case <-q.Available(): require.False(t, q.WouldBlock()) case <-time.After(time.Second): t.Fatal("timeout") } // when calling Close, all packets are first sent out closed := make(chan struct{}) go func() { q.Close() close(closed) }() synctest.Wait() select { case <-closed: t.Fatal("Close should have blocked") default: } for range sendQueueCapacity { blockWrite <- struct{}{} } synctest.Wait() select { case <-closed: default: t.Fatal("Close should have returned") } select { case <-done: default: t.Fatal("Run should have returned") } }) } func TestSendQueueWriteError(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) c := NewMockSendConn(mockCtrl) q := newSendQueue(c) c.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Return(assert.AnError) q.Send(getPacketWithContents([]byte("foobar")), 6, protocol.ECNNon) errChan := make(chan error, 1) go func() { errChan <- q.Run() }() synctest.Wait() select { case err := <-errChan: require.ErrorIs(t, err, assert.AnError) default: t.Fatal("Run should have returned") } // further calls to Send should not block sent := make(chan struct{}) go func() { defer close(sent) for range 2 * sendQueueCapacity { q.Send(getPacketWithContents([]byte("raboof")), 6, protocol.ECNNon) } }() synctest.Wait() select { case <-sent: default: t.Fatal("Send should have returned") } }) } func TestSendQueueSendProbe(t 
*testing.T) { mockCtrl := gomock.NewController(t) c := NewMockSendConn(mockCtrl) q := newSendQueue(c) addr := &net.UDPAddr{IP: net.IPv4(42, 42, 42, 42), Port: 42} c.EXPECT().WriteTo([]byte("foobar"), addr) q.SendProbe(getPacketWithContents([]byte("foobar")), addr) } quic-go-0.59.0/send_stream.go000066400000000000000000000557471513066070600160320ustar00rootroot00000000000000package quic import ( "context" "fmt" "sync" "time" "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/flowcontrol" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" ) // A SendStream is a unidirectional Send Stream. type SendStream struct { mutex sync.Mutex numOutstandingFrames int64 // outstanding STREAM and RESET_STREAM frames retransmissionQueue []*wire.StreamFrame ctx context.Context ctxCancel context.CancelCauseFunc streamID protocol.StreamID sender streamSender // reliableSize is the portion of the stream that needs to be transmitted reliably, // even if the stream is cancelled. // This requires the peer to support RESET_STREAM_AT. // This value should not be accessed directly, but only through the reliableOffset method. // This method returns 0 if the peer doesn't support the RESET_STREAM_AT extension. reliableSize protocol.ByteCount writeOffset protocol.ByteCount shutdownErr error resetErr *StreamError queuedResetStreamFrame *wire.ResetStreamFrame supportsResetStreamAt bool finishedWriting bool // set once Close() is called finSent bool // set when a STREAM_FRAME with FIN bit has been sent // Set when the application knows about the cancellation. // This can happen because the application called CancelWrite, // or because Write returned the error (for remote cancellations). 
cancellationFlagged bool completed bool // set when this stream has been reported to the streamSender as completed dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out nextFrame *wire.StreamFrame writeChan chan struct{} writeOnce chan struct{} deadline monotime.Time flowController flowcontrol.StreamFlowController } var ( _ streamControlFrameGetter = &SendStream{} _ outgoingStream = &SendStream{} _ sendStreamFrameHandler = &SendStream{} ) func newSendStream( ctx context.Context, streamID protocol.StreamID, sender streamSender, flowController flowcontrol.StreamFlowController, supportsResetStreamAt bool, ) *SendStream { s := &SendStream{ streamID: streamID, sender: sender, flowController: flowController, writeChan: make(chan struct{}, 1), writeOnce: make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write supportsResetStreamAt: supportsResetStreamAt, } s.ctx, s.ctxCancel = context.WithCancelCause(ctx) return s } // StreamID returns the stream ID. func (s *SendStream) StreamID() StreamID { return s.streamID // same for receiveStream and sendStream } // Write writes data to the stream. // Write can be made to time out using [SendStream.SetWriteDeadline]. // If the stream was canceled, the error is a [StreamError]. func (s *SendStream) Write(p []byte) (int, error) { // Concurrent use of Write is not permitted (and doesn't make any sense), // but sometimes people do it anyway. // Make sure that we only execute one call at any given time to avoid hard to debug failures. 
s.writeOnce <- struct{}{} defer func() { <-s.writeOnce }() isNewlyCompleted, n, err := s.write(p) if isNewlyCompleted { s.sender.onStreamCompleted(s.streamID) } return n, err } func (s *SendStream) write(p []byte) (bool /* is newly completed */, int, error) { s.mutex.Lock() defer s.mutex.Unlock() if s.resetErr != nil { s.cancellationFlagged = true return s.isNewlyCompleted(), 0, s.resetErr } if s.shutdownErr != nil { return false, 0, s.shutdownErr } if s.finishedWriting { return false, 0, fmt.Errorf("write on closed stream %d", s.streamID) } if !s.deadline.IsZero() && !monotime.Now().Before(s.deadline) { return false, 0, errDeadline } if len(p) == 0 { return false, 0, nil } s.dataForWriting = p var ( deadlineTimer *time.Timer bytesWritten int notifiedSender bool ) for { var copied bool var deadline monotime.Time // As soon as dataForWriting becomes smaller than a certain size x, we copy all the data to a STREAM frame (s.nextFrame), // which can then be popped the next time we assemble a packet. // This allows us to return Write() when all data but x bytes have been sent out. // When the user now calls Close(), this is much more likely to happen before we popped that last STREAM frame, // allowing us to set the FIN bit on that frame (instead of sending an empty STREAM frame with FIN). 
if s.canBufferStreamFrame() && len(s.dataForWriting) > 0 { if s.nextFrame == nil { f := wire.GetStreamFrame() f.Offset = s.writeOffset f.StreamID = s.streamID f.DataLenPresent = true f.Data = f.Data[:len(s.dataForWriting)] copy(f.Data, s.dataForWriting) s.nextFrame = f } else { l := len(s.nextFrame.Data) s.nextFrame.Data = s.nextFrame.Data[:l+len(s.dataForWriting)] copy(s.nextFrame.Data[l:], s.dataForWriting) } s.dataForWriting = nil bytesWritten = len(p) copied = true } else { bytesWritten = len(p) - len(s.dataForWriting) deadline = s.deadline if !deadline.IsZero() { if !monotime.Now().Before(deadline) { s.dataForWriting = nil return false, bytesWritten, errDeadline } if deadlineTimer == nil { deadlineTimer = time.NewTimer(monotime.Until(deadline)) defer deadlineTimer.Stop() } else { deadlineTimer.Reset(monotime.Until(deadline)) } } if s.dataForWriting == nil || s.shutdownErr != nil || s.resetErr != nil { break } } s.mutex.Unlock() if !notifiedSender { s.sender.onHasStreamData(s.streamID, s) // must be called without holding the mutex notifiedSender = true } if copied { s.mutex.Lock() break } if deadline.IsZero() { <-s.writeChan } else { select { case <-s.writeChan: case <-deadlineTimer.C: } } s.mutex.Lock() } if bytesWritten == len(p) { return false, bytesWritten, nil } if s.shutdownErr != nil { return false, bytesWritten, s.shutdownErr } if s.resetErr != nil { s.cancellationFlagged = true return s.isNewlyCompleted(), bytesWritten, s.resetErr } return false, bytesWritten, nil } func (s *SendStream) canBufferStreamFrame() bool { var l protocol.ByteCount if s.nextFrame != nil { l = s.nextFrame.DataLen() } return l+protocol.ByteCount(len(s.dataForWriting)) <= protocol.MaxPacketBufferSize } // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream // maxBytes is the maximum length this frame (including frame header) will have. 
func (s *SendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ ackhandler.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMore bool) { s.mutex.Lock() f, blocked, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v) if f != nil { s.numOutstandingFrames++ } s.mutex.Unlock() if f == nil { return ackhandler.StreamFrame{}, blocked, hasMoreData } return ackhandler.StreamFrame{ Frame: f, Handler: (*sendStreamAckHandler)(s), }, blocked, hasMoreData } func (s *SendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ *wire.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMoreData bool) { if s.shutdownErr != nil { return nil, nil, false } if s.resetErr != nil { reliableOffset := s.reliableOffset() if reliableOffset == 0 || (s.writeOffset >= reliableOffset && len(s.retransmissionQueue) == 0) { return nil, nil, false } } if len(s.retransmissionQueue) > 0 { f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes, v) if f != nil || hasMoreRetransmissions { if f == nil { return nil, nil, true } // We always claim that we have more data to send. // This might be incorrect, in which case there'll be a spurious call to popStreamFrame in the future. 
return f, nil, true } } if len(s.dataForWriting) == 0 && s.nextFrame == nil { if s.finishedWriting && !s.finSent { s.finSent = true return &wire.StreamFrame{ StreamID: s.streamID, Offset: s.writeOffset, DataLenPresent: true, Fin: true, }, nil, false } return nil, nil, false } maxDataLen := s.flowController.SendWindowSize() if maxDataLen == 0 { return nil, nil, true } // if the stream is canceled, only data up to the reliable size needs to be sent reliableOffset := s.reliableOffset() if s.resetErr != nil && reliableOffset > 0 { maxDataLen = min(maxDataLen, reliableOffset-s.writeOffset) } f, hasMoreData := s.popNewStreamFrame(maxBytes, maxDataLen, v) if f == nil { return nil, nil, hasMoreData } if f.DataLen() > 0 { s.writeOffset += f.DataLen() s.flowController.AddBytesSent(f.DataLen()) } if s.resetErr != nil && s.writeOffset >= reliableOffset { hasMoreData = false } var blocked *wire.StreamDataBlockedFrame // If the entire send window is used, the stream might have become blocked on stream-level flow control. // This is not guaranteed though, because the stream might also have been blocked on connection-level flow control. 
if f.DataLen() == maxDataLen && s.flowController.IsNewlyBlocked() { blocked = &wire.StreamDataBlockedFrame{StreamID: s.streamID, MaximumStreamData: s.writeOffset} } f.Fin = s.finishedWriting && s.dataForWriting == nil && s.nextFrame == nil && !s.finSent if f.Fin { s.finSent = true } return f, blocked, hasMoreData } // popNewStreamFrame returns a new STREAM frame to send for this stream // hasMoreData says if there's more data to send, *not* taking into account the reliable size func (s *SendStream) popNewStreamFrame(maxBytes, maxDataLen protocol.ByteCount, v protocol.Version) (_ *wire.StreamFrame, hasMoreData bool) { if s.nextFrame != nil { maxDataLen := min(maxDataLen, s.nextFrame.MaxDataLen(maxBytes, v)) if maxDataLen == 0 { return nil, true } nextFrame := s.nextFrame s.nextFrame = nil if nextFrame.DataLen() > maxDataLen { s.nextFrame = wire.GetStreamFrame() s.nextFrame.StreamID = s.streamID s.nextFrame.Offset = s.writeOffset + maxDataLen s.nextFrame.Data = s.nextFrame.Data[:nextFrame.DataLen()-maxDataLen] s.nextFrame.DataLenPresent = true copy(s.nextFrame.Data, nextFrame.Data[maxDataLen:]) nextFrame.Data = nextFrame.Data[:maxDataLen] } else { s.signalWrite() } return nextFrame, s.nextFrame != nil || s.dataForWriting != nil } f := wire.GetStreamFrame() f.Fin = false f.StreamID = s.streamID f.Offset = s.writeOffset f.DataLenPresent = true f.Data = f.Data[:0] hasMoreData = s.popNewStreamFrameWithoutBuffer(f, maxBytes, maxDataLen, v) if len(f.Data) == 0 && !f.Fin { f.PutBack() return nil, hasMoreData } return f, hasMoreData } func (s *SendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.Version) bool { maxDataLen := f.MaxDataLen(maxBytes, v) if maxDataLen == 0 { // a STREAM frame must have at least one byte of data return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting } s.getDataForWriting(f, min(maxDataLen, sendWindow)) return s.dataForWriting != nil || s.nextFrame != nil || 
s.finishedWriting } func (s *SendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more retransmissions */) { f := s.retransmissionQueue[0] newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, v) if needsSplit { return newFrame, true } s.retransmissionQueue = s.retransmissionQueue[1:] return f, len(s.retransmissionQueue) > 0 } func (s *SendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.ByteCount) { if protocol.ByteCount(len(s.dataForWriting)) <= maxBytes { f.Data = f.Data[:len(s.dataForWriting)] copy(f.Data, s.dataForWriting) s.dataForWriting = nil s.signalWrite() return } f.Data = f.Data[:maxBytes] copy(f.Data, s.dataForWriting) s.dataForWriting = s.dataForWriting[maxBytes:] if s.canBufferStreamFrame() { s.signalWrite() } } func (s *SendStream) isNewlyCompleted() bool { if s.completed { return false } if s.nextFrame != nil && s.nextFrame.DataLen() > 0 { return false } // We need to keep the stream around until all frames have been sent and acknowledged. if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 || s.queuedResetStreamFrame != nil { return false } // The stream is completed if we sent the FIN. if s.finSent { s.completed = true return true } // The stream is also completed if: // 1. the application called CancelWrite, or // 2. we received a STOP_SENDING, and // * the application consumed the error via Write, or // * the application called Close if s.resetErr != nil && (s.cancellationFlagged || s.finishedWriting) { s.completed = true return true } return false } // Close closes the write-direction of the stream. // Future calls to Write are not permitted after calling Close. // It must not be called concurrently with Write. // It must not be called after calling CancelWrite. 
func (s *SendStream) Close() error { s.mutex.Lock() if s.shutdownErr != nil || s.finishedWriting { s.mutex.Unlock() return nil } s.finishedWriting = true cancelled := s.resetErr != nil if cancelled { s.cancellationFlagged = true } completed := s.isNewlyCompleted() s.mutex.Unlock() if completed { s.sender.onStreamCompleted(s.streamID) } if cancelled { return fmt.Errorf("close called for canceled stream %d", s.streamID) } s.sender.onHasStreamData(s.streamID, s) // need to send the FIN, must be called without holding the mutex s.ctxCancel(nil) return nil } // SetReliableBoundary marks the data written to this stream so far as reliable. // It is valid to call this function multiple times, thereby increasing the reliable size. // It only has an effect if the peer enabled support for the RESET_STREAM_AT extension, // otherwise, it is a no-op. func (s *SendStream) SetReliableBoundary() { s.mutex.Lock() defer s.mutex.Unlock() s.reliableSize = s.writeOffset if s.nextFrame != nil { s.reliableSize += s.nextFrame.DataLen() } } // returnFramesToPool returns all queued frames to the sync.Pool func (s *SendStream) returnFramesToPool() { for _, f := range s.retransmissionQueue { f.PutBack() } clear(s.retransmissionQueue) s.retransmissionQueue = nil if s.nextFrame != nil { s.nextFrame.PutBack() s.nextFrame = nil } } // CancelWrite aborts sending on this stream. // Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably. // Write will unblock immediately, and future calls to Write will fail. // When called multiple times it is a no-op. // When called after Close, it aborts reliable delivery of outstanding stream data. // Note that there is no guarantee if the peer will receive the FIN or the cancellation error first. 
func (s *SendStream) CancelWrite(errorCode StreamErrorCode) { s.mutex.Lock() if s.shutdownErr != nil { s.mutex.Unlock() return } s.cancellationFlagged = true if s.resetErr != nil { completed := s.isNewlyCompleted() s.mutex.Unlock() // The user has called CancelWrite. If the previous cancellation was because of a // STOP_SENDING, we don't need to flag the error to the user anymore. if completed { s.sender.onStreamCompleted(s.streamID) } return } s.resetErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false} s.ctxCancel(s.resetErr) reliableOffset := s.reliableOffset() if reliableOffset == 0 { s.numOutstandingFrames = 0 s.returnFramesToPool() } s.queuedResetStreamFrame = &wire.ResetStreamFrame{ StreamID: s.streamID, FinalSize: max(s.writeOffset, reliableOffset), ErrorCode: errorCode, // if the peer doesn't support the extension, the reliable offset will always be 0 ReliableSize: reliableOffset, } if reliableOffset > 0 { if s.nextFrame != nil { if s.nextFrame.Offset >= reliableOffset { s.nextFrame.PutBack() s.nextFrame = nil } else if s.nextFrame.Offset+s.nextFrame.DataLen() > reliableOffset { s.nextFrame.Data = s.nextFrame.Data[:reliableOffset-s.nextFrame.Offset] } } if len(s.retransmissionQueue) > 0 { retransmissionQueue := make([]*wire.StreamFrame, 0, len(s.retransmissionQueue)) for _, f := range s.retransmissionQueue { if f.Offset >= reliableOffset { f.PutBack() continue } if f.Offset+f.DataLen() <= reliableOffset { retransmissionQueue = append(retransmissionQueue, f) } else { f.Data = f.Data[:reliableOffset-f.Offset] retransmissionQueue = append(retransmissionQueue, f) } } s.retransmissionQueue = retransmissionQueue } } s.mutex.Unlock() s.signalWrite() s.sender.onHasStreamControlFrame(s.streamID, s) } func (s *SendStream) enableResetStreamAt() { s.mutex.Lock() s.supportsResetStreamAt = true s.mutex.Unlock() } func (s *SendStream) updateSendWindow(limit protocol.ByteCount) { updated := s.flowController.UpdateSendWindow(limit) if !updated { // 
duplicate or reordered MAX_STREAM_DATA frame return } s.mutex.Lock() hasStreamData := s.dataForWriting != nil || s.nextFrame != nil s.mutex.Unlock() if hasStreamData { s.sender.onHasStreamData(s.streamID, s) } } func (s *SendStream) handleStopSendingFrame(f *wire.StopSendingFrame) { s.mutex.Lock() if s.shutdownErr != nil { s.mutex.Unlock() return } // If the stream was already cancelled (either locally, or due to a previous STOP_SENDING frame), // there's nothing else to do. if s.resetErr != nil && s.reliableOffset() == 0 { s.mutex.Unlock() return } // if the peer stopped reading from the stream, there's no need to transmit any data reliably s.reliableSize = 0 s.numOutstandingFrames = 0 s.returnFramesToPool() if s.resetErr == nil { s.resetErr = &StreamError{StreamID: s.streamID, ErrorCode: f.ErrorCode, Remote: true} s.ctxCancel(s.resetErr) } s.queuedResetStreamFrame = &wire.ResetStreamFrame{ StreamID: s.streamID, FinalSize: s.writeOffset, ErrorCode: s.resetErr.ErrorCode, } s.mutex.Unlock() s.signalWrite() s.sender.onHasStreamControlFrame(s.streamID, s) } func (s *SendStream) getControlFrame(monotime.Time) (_ ackhandler.Frame, ok, hasMore bool) { s.mutex.Lock() defer s.mutex.Unlock() if s.queuedResetStreamFrame == nil { return ackhandler.Frame{}, false, false } s.numOutstandingFrames++ f := ackhandler.Frame{ Frame: s.queuedResetStreamFrame, Handler: (*sendStreamResetStreamHandler)(s), } s.queuedResetStreamFrame = nil return f, true, false } func (s *SendStream) reliableOffset() protocol.ByteCount { if !s.supportsResetStreamAt { return 0 } return s.reliableSize } // The Context is canceled as soon as the write-side of the stream is closed. // This happens when Close() or CancelWrite() is called, or when the peer // cancels the read-side of their stream. // The cancellation cause is set to the error that caused the stream to // close, or `context.Canceled` in case the stream is closed without error. 
func (s *SendStream) Context() context.Context { return s.ctx } // SetWriteDeadline sets the deadline for future Write calls // and any currently-blocked Write call. // Even if write times out, it may return n > 0, indicating that // some data was successfully written. // A zero value for t means Write will not time out. func (s *SendStream) SetWriteDeadline(t time.Time) error { s.mutex.Lock() s.deadline = monotime.FromTime(t) s.mutex.Unlock() s.signalWrite() return nil } // CloseForShutdown closes a stream abruptly. // It makes Write unblock (and return the error) immediately. // The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. func (s *SendStream) closeForShutdown(err error) { s.mutex.Lock() if s.shutdownErr == nil && !s.finishedWriting { s.shutdownErr = err s.returnFramesToPool() } s.mutex.Unlock() s.signalWrite() } // signalWrite performs a non-blocking send on the writeChan func (s *SendStream) signalWrite() { select { case s.writeChan <- struct{}{}: default: } } type sendStreamAckHandler SendStream var _ ackhandler.FrameHandler = &sendStreamAckHandler{} func (s *sendStreamAckHandler) OnAcked(f wire.Frame) { sf := f.(*wire.StreamFrame) sf.PutBack() s.mutex.Lock() if s.resetErr != nil && (*SendStream)(s).reliableOffset() == 0 { s.mutex.Unlock() return } s.numOutstandingFrames-- if s.numOutstandingFrames < 0 { panic("numOutStandingFrames negative") } completed := (*SendStream)(s).isNewlyCompleted() s.mutex.Unlock() if completed { s.sender.onStreamCompleted(s.streamID) } } func (s *sendStreamAckHandler) OnLost(f wire.Frame) { sf := f.(*wire.StreamFrame) s.mutex.Lock() // If the reliable size was 0 when the stream was cancelled, // the number of outstanding frames was immediately set to 0, and the retransmission queue was dropped. 
if s.resetErr != nil && (*SendStream)(s).reliableOffset() == 0 { // Return the frame to pool since it won't be retransmitted sf.PutBack() s.mutex.Unlock() return } s.numOutstandingFrames-- if s.numOutstandingFrames < 0 { panic("numOutStandingFrames negative") } if s.resetErr != nil && (*SendStream)(s).reliableOffset() > 0 { // If the stream was reset, and this frame is beyond the reliable offset, // it doesn't need to be retransmitted. if sf.Offset >= (*SendStream)(s).reliableOffset() { sf.PutBack() // If this frame was the last one tracked, losing it might cause the stream to be completed. completed := (*SendStream)(s).isNewlyCompleted() s.mutex.Unlock() if completed { s.sender.onStreamCompleted(s.streamID) } return } // If the payload of the frame extends beyond the reliable size, // truncate the frame to the reliable size. if sf.Offset+sf.DataLen() > (*SendStream)(s).reliableOffset() { sf.Data = sf.Data[:(*SendStream)(s).reliableOffset()-sf.Offset] } } sf.DataLenPresent = true s.retransmissionQueue = append(s.retransmissionQueue, sf) s.mutex.Unlock() s.sender.onHasStreamData(s.streamID, (*SendStream)(s)) } type sendStreamResetStreamHandler SendStream var _ ackhandler.FrameHandler = &sendStreamResetStreamHandler{} func (s *sendStreamResetStreamHandler) OnAcked(f wire.Frame) { rsf := f.(*wire.ResetStreamFrame) s.mutex.Lock() // If the peer sent a STOP_SENDING after we sent a RESET_STREAM_AT frame, // we sent 1. reduced the reliable size to 0 and 2. sent a RESET_STREAM frame. // In this case, we don't care about the acknowledgment of this frame. 
if rsf.ReliableSize != (*SendStream)(s).reliableOffset() { s.mutex.Unlock() return } s.numOutstandingFrames-- if s.numOutstandingFrames < 0 { panic("numOutStandingFrames negative") } completed := (*SendStream)(s).isNewlyCompleted() s.mutex.Unlock() if completed { s.sender.onStreamCompleted(s.streamID) } } func (s *sendStreamResetStreamHandler) OnLost(f wire.Frame) { rsf := f.(*wire.ResetStreamFrame) s.mutex.Lock() // If the peer sent a STOP_SENDING after we sent a RESET_STREAM_AT frame, // we sent 1. reduced the reliable size to 0 and 2. sent a RESET_STREAM frame. // In this case, the loss of the RESET_STREAM_AT frame can be ignored. if rsf.ReliableSize != (*SendStream)(s).reliableOffset() { s.mutex.Unlock() return } s.queuedResetStreamFrame = rsf s.numOutstandingFrames-- s.mutex.Unlock() s.sender.onHasStreamControlFrame(s.streamID, (*SendStream)(s)) } quic-go-0.59.0/send_stream_test.go000066400000000000000000001637021513066070600170600ustar00rootroot00000000000000package quic import ( "bytes" "context" "crypto/rand" "errors" "fmt" "io" mrand "math/rand/v2" "net" "os" "runtime" "slices" "testing" "time" "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/mocks" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) type writerWithTimeout struct { io.Writer Timeout time.Duration } func (w *writerWithTimeout) Write(p []byte) (n int, err error) { done := make(chan struct{}) go func() { defer close(done) n, err = w.Writer.Write(p) }() select { case <-done: return n, err case <-time.After(w.Timeout): return 0, fmt.Errorf("write timeout after %s", w.Timeout) } } func expectedFrameHeaderLen(strID protocol.StreamID, offset protocol.ByteCount) protocol.ByteCount { return (&wire.StreamFrame{StreamID: strID, 
Offset: offset, DataLenPresent: true}).Length(protocol.Version1) } func TestSendStreamSetup(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) ctx := context.WithValue(context.Background(), "foo", "bar") str := newSendStream(ctx, 1337, nil, mockFC, false) require.NotNil(t, str.Context()) require.Equal(t, "bar", str.Context().Value("foo")) require.Equal(t, protocol.StreamID(1337), str.StreamID()) } func TestSendStreamWriteData(t *testing.T) { const streamID protocol.StreamID = 42 mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second} mockSender.EXPECT().onHasStreamData(streamID, str) n, err := strWithTimeout.Write([]byte("foobar")) require.NoError(t, err) require.Equal(t, 6, n) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) frame, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.False(t, hasMore) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: streamID, Data: []byte("foobar"), DataLenPresent: true}, frame.Frame, ) require.True(t, mockCtrl.Satisfied()) // nothing more to send at this point _, _, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.False(t, hasMore) require.True(t, mockCtrl.Satisfied()) // nil writes don't do anything n, err = strWithTimeout.Write(nil) require.NoError(t, err) require.Zero(t, n) require.True(t, mockCtrl.Satisfied()) // empty slices writes don't do anything n, err = strWithTimeout.Write([]byte{}) require.NoError(t, err) require.Zero(t, n) require.True(t, mockCtrl.Satisfied()) // multiple writes are bundled into a single frame mockSender.EXPECT().onHasStreamData(streamID, str).Times(2) n, err = 
strWithTimeout.Write([]byte{0xde, 0xad}) require.NoError(t, err) require.Equal(t, 2, n) n, err = strWithTimeout.Write([]byte{0xbe, 0xef}) require.NoError(t, err) require.Equal(t, 2, n) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(4)) frame, _, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.False(t, hasMore) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: 42, Offset: 6, Data: []byte{0xde, 0xad, 0xbe, 0xef}, DataLenPresent: true}, frame.Frame, ) // a single write is split up into smaller frames mockSender.EXPECT().onHasStreamData(streamID, str) n, err = strWithTimeout.Write([]byte("foobaz")) require.NoError(t, err) require.Equal(t, 6, n) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(3) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)).Times(2) frame, _, hasMore = str.popStreamFrame(expectedFrameHeaderLen(streamID, 10), protocol.Version1) require.Nil(t, frame.Frame) require.True(t, hasMore) frame, _, hasMore = str.popStreamFrame(expectedFrameHeaderLen(streamID, 10)+3, protocol.Version1) require.True(t, hasMore) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: streamID, Offset: 10, Data: []byte("foo"), DataLenPresent: true}, frame.Frame, ) frame, _, hasMore = str.popStreamFrame(expectedFrameHeaderLen(streamID, 13)+3, protocol.Version1) require.False(t, hasMore) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: streamID, Offset: 13, Data: []byte("baz"), DataLenPresent: true}, frame.Frame, ) } func TestSendStreamLargeWrites(t *testing.T) { synctest.Test(t, func(t *testing.T) { const streamID protocol.StreamID = 1337 mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamData(streamID, str) data := make([]byte, 5000) rand.Read(data) 
errChan := make(chan error, 1) go func() { _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write(data) str.Close() errChan <- err }() synctest.Wait() mockFC.EXPECT().SendWindowSize().Return(protocol.MaxPacketBufferSize).AnyTimes() mockFC.EXPECT().AddBytesSent(gomock.Any()).AnyTimes() var offset protocol.ByteCount const size = 40 for offset+size < protocol.ByteCount(len(data))-protocol.MaxPacketBufferSize { frame, _, hasMore := str.popStreamFrame(size+expectedFrameHeaderLen(streamID, offset), protocol.Version1) require.NotNil(t, frame.Frame) require.True(t, hasMore) require.Equal(t, offset, frame.Frame.Offset) require.Equal(t, data[offset:offset+size], frame.Frame.Data) offset += size require.True(t, mockCtrl.Satisfied()) } // Write should still be blocked, since there's more than protocol.MaxPacketBufferSize left to send select { case err := <-errChan: require.NoError(t, err) default: } // empty frames are not sent frame, _, hasMore := str.popStreamFrame(expectedFrameHeaderLen(streamID, offset), protocol.Version1) require.Nil(t, frame.Frame) require.True(t, hasMore) mockSender.EXPECT().onHasStreamData(streamID, str) // from the Close call frame, _, hasMore = str.popStreamFrame(size+expectedFrameHeaderLen(streamID, offset), protocol.Version1) require.NotNil(t, frame.Frame) require.True(t, hasMore) require.Equal(t, data[offset:offset+size], frame.Frame.Data) require.Equal(t, offset, frame.Frame.Offset) offset += size synctest.Wait() select { case err := <-errChan: require.NoError(t, err) default: t.Fatal("write should have returned") } frame, _, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.NotNil(t, frame.Frame) require.False(t, hasMore) require.Equal(t, data[offset:], frame.Frame.Data) require.True(t, frame.Frame.Fin) }) } func TestSendStreamLargeWriteBlocking(t *testing.T) { synctest.Test(t, func(t *testing.T) { const streamID protocol.StreamID = 1337 mockCtrl := gomock.NewController(t) mockFC := 
mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamData(streamID, str).Times(2) _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) require.NoError(t, err) errChan := make(chan error, 1) go func() { _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write(make([]byte, protocol.MaxPacketBufferSize)) errChan <- err }() synctest.Wait() mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) frame, _, hasMoreData := str.popStreamFrame(expectedFrameHeaderLen(streamID, 0)+3, protocol.Version1) require.NotNil(t, frame.Frame) require.True(t, hasMoreData) require.Equal(t, []byte("foo"), frame.Frame.Data) synctest.Wait() select { case err := <-errChan: t.Fatalf("write should not have returned yet: %v", err) default: } mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) frame, _, hasMoreData = str.popStreamFrame(expectedFrameHeaderLen(streamID, 3)+3, protocol.Version1) require.NotNil(t, frame.Frame) require.True(t, hasMoreData) require.Equal(t, []byte("bar"), frame.Frame.Data) synctest.Wait() select { case err := <-errChan: require.NoError(t, err) default: t.Fatal("timeout") } }) } func TestSendStreamCopyData(t *testing.T) { const streamID protocol.StreamID = 42 mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second} // for small writes data := []byte("foobar") mockSender.EXPECT().onHasStreamData(streamID, str) _, err := strWithTimeout.Write(data) require.NoError(t, err) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(gomock.Any()) frame, _, _ 
:= str.popStreamFrame(protocol.MaxPacketBufferSize, protocol.Version1) data[1] = 'e' // modify the data after it has been written require.EqualExportedValues(t, &wire.StreamFrame{StreamID: streamID, Data: []byte("foobar"), DataLenPresent: true}, frame.Frame, ) } func TestSendStreamDeadlineInThePast(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), 42, mockSender, mockFC, false) // no data is written when the deadline is in the past require.NoError(t, str.SetWriteDeadline(time.Now().Add(-time.Second))) n, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) require.ErrorIs(t, err, os.ErrDeadlineExceeded) require.Zero(t, n) var nerr net.Error require.ErrorAs(t, err, &nerr) require.True(t, nerr.Timeout()) // data is written when the deadline is in the future mockSender.EXPECT().onHasStreamData(gomock.Any(), str) require.NoError(t, str.SetWriteDeadline(time.Now().Add(time.Second))) n, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) require.NoError(t, err) require.Equal(t, 6, n) } func TestSendStreamDeadlineRemoval(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), 42, mockSender, mockFC, false) deadline := time.Second require.NoError(t, str.SetWriteDeadline(time.Now().Add(deadline))) mockSender.EXPECT().onHasStreamData(gomock.Any(), str).Times(2) // small writes are written immediately _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) require.NoError(t, err) // large writes might block, and therefore subject to the deadline errChan := make(chan error, 1) go func() { _, err := (&writerWithTimeout{Writer: str, Timeout: 5 * 
time.Second}).Write(make([]byte, 2000)) errChan <- err }() synctest.Wait() select { case err := <-errChan: t.Fatalf("write should not have returned yet: %v", err) case <-time.After(deadline / 2): } // remove the deadline after a while (but before it expires) require.NoError(t, str.SetWriteDeadline(time.Time{})) select { case err := <-errChan: t.Fatalf("write should not have returned yet: %v", err) case <-time.After(deadline): } // now set the deadline to the past to make Write return immediately require.NoError(t, str.SetWriteDeadline(time.Now().Add(-time.Second))) synctest.Wait() select { case err := <-errChan: require.ErrorIs(t, err, os.ErrDeadlineExceeded) default: } mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(gomock.Any()) frame, _, hasMoreData := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.NotNil(t, frame.Frame) require.False(t, hasMoreData) require.Equal(t, []byte("foobar"), frame.Frame.Data) }) } func TestSendStreamDeadlineExtension(t *testing.T) { synctest.Test(t, func(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), 42, mockSender, mockFC, false) deadline := time.Minute require.NoError(t, str.SetWriteDeadline(time.Now().Add(deadline))) mockSender.EXPECT().onHasStreamData(gomock.Any(), str) errChan := make(chan error, 1) go func() { _, err := str.Write(make([]byte, 2000)) errChan <- err }() synctest.Wait() select { case err := <-errChan: t.Fatalf("write should not have returned yet: %v", err) case <-time.After(deadline / 2): } // extend the deadline start := time.Now() require.NoError(t, str.SetWriteDeadline(start.Add(deadline))) synctest.Wait() select { case err := <-errChan: require.ErrorIs(t, err, os.ErrDeadlineExceeded) require.Equal(t, deadline, time.Since(start)) case <-time.After(deadline + time.Nanosecond): t.Fatal("timeout") } frame, 
_, hasMoreData := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.Nil(t, frame.Frame) require.False(t, hasMoreData) }) } func TestSendStreamClose(t *testing.T) { const streamID protocol.StreamID = 1234 mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second} mockSender.EXPECT().onHasStreamData(streamID, str).Times(2) _, err := strWithTimeout.Write([]byte("foobar")) require.NoError(t, err) require.NoError(t, str.Close()) select { case <-str.Context().Done(): default: t.Fatal("stream context should have been canceled") } mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)).Times(2) frame, _, hasMore := str.popStreamFrame(expectedFrameHeaderLen(streamID, 0)+3, protocol.Version1) require.NotNil(t, frame.Frame) require.True(t, hasMore) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: streamID, Offset: 0, Data: []byte("foo"), DataLenPresent: true}, // no FIN yet frame.Frame, ) frame, _, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.False(t, hasMore) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: streamID, Offset: 3, Fin: true, Data: []byte("bar"), DataLenPresent: true}, frame.Frame, ) require.True(t, mockCtrl.Satisfied()) // further calls to Write return an error _, err = strWithTimeout.Write([]byte("foobar")) require.ErrorContains(t, err, "write on closed stream 1234") frame, _, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.Nil(t, frame.Frame) require.False(t, hasMore) // further calls to Close don't do anything require.NoError(t, str.Close()) frame, _, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.Nil(t, frame.Frame) require.False(t, 
hasMore) require.True(t, mockCtrl.Satisfied()) // shutting down has no effect str.closeForShutdown(errors.New("goodbye")) _, err = strWithTimeout.Write([]byte("foobar")) require.ErrorContains(t, err, "write on closed stream 1234") } func TestSendStreamImmediateClose(t *testing.T) { const streamID protocol.StreamID = 1337 mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamData(streamID, str) require.NoError(t, str.Close()) frame, _, hasMore := str.popStreamFrame(expectedFrameHeaderLen(streamID, 13)+3, protocol.Version1) require.False(t, hasMore) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: streamID, Fin: true, DataLenPresent: true}, frame.Frame, ) } func TestSendStreamFlowControlBlocked(t *testing.T) { const streamID protocol.StreamID = 42 mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamData(streamID, str) _, err := str.Write([]byte("foobar")) require.NoError(t, err) mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(3)) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(0)) mockFC.EXPECT().IsNewlyBlocked().Return(true) frame, blocked, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.True(t, hasMore) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: streamID, Data: []byte("foo"), DataLenPresent: true}, frame.Frame, ) require.Equal(t, &wire.StreamDataBlockedFrame{StreamID: streamID, MaximumStreamData: 3}, blocked) frame, blocked, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.Nil(t, frame.Frame) require.Nil(t, blocked) 
require.True(t, hasMore) _, ok, hasMore := str.getControlFrame(monotime.Now()) require.False(t, ok) require.False(t, hasMore) } func TestSendStreamCloseForShutdown(t *testing.T) { synctest.Test(t, func(t *testing.T) { const streamID protocol.StreamID = 1337 mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second} mockSender.EXPECT().onHasStreamData(streamID, str) errChan := make(chan error, 1) go func() { _, err := strWithTimeout.Write(bytes.Repeat([]byte("foobar"), 1000)) errChan <- err }() synctest.Wait() str.closeForShutdown(assert.AnError) synctest.Wait() require.True(t, mockCtrl.Satisfied()) select { case err := <-errChan: require.ErrorIs(t, err, assert.AnError) default: } // STOP_SENDING frames are ignored str.handleStopSendingFrame(&wire.StopSendingFrame{StreamID: streamID, ErrorCode: 1337}) _, ok, hasMore := str.getControlFrame(monotime.Now()) require.False(t, ok) require.False(t, hasMore) // future calls to Write should return the error _, err := strWithTimeout.Write([]byte("foobar")) require.ErrorIs(t, err, assert.AnError) // closing the stream doesn't do anything require.NoError(t, str.Close()) // no STREAM frames popped frame, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.Nil(t, frame.Frame) require.False(t, hasMore) // canceling the stream doesn't do anything str.CancelWrite(1234) _, err = strWithTimeout.Write([]byte("foobar")) require.ErrorIs(t, err, assert.AnError) // error unchanged }) } func TestSendStreamUpdateSendWindow(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), 42, mockSender, mockFC, false) 
mockSender.EXPECT().onHasStreamData(gomock.Any(), str) _, err := str.Write([]byte("foobar")) require.NoError(t, err) require.True(t, mockCtrl.Satisfied()) // no calls to onHasStreamData if the window size wasn't increased mockFC.EXPECT().UpdateSendWindow(protocol.ByteCount(41)).Return(false) str.updateSendWindow(41) gomock.InOrder( mockFC.EXPECT().UpdateSendWindow(protocol.ByteCount(123)).Return(true), mockSender.EXPECT().onHasStreamData(protocol.StreamID(42), str), ) str.updateSendWindow(123) } func TestSendStreamCancellation(t *testing.T) { synctest.Test(t, func(t *testing.T) { const streamID protocol.StreamID = 42 mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second} mockSender.EXPECT().onHasStreamData(streamID, str) _, err := strWithTimeout.Write([]byte("foobar")) require.NoError(t, err) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) frame, _, hasMore := str.popStreamFrame(3+expectedFrameHeaderLen(streamID, 0), protocol.Version1) require.NotNil(t, frame.Frame) require.True(t, hasMore) require.Equal(t, []byte("foo"), frame.Frame.Data) require.True(t, mockCtrl.Satisfied()) // The stream doesn't support RESET_STREAM_AT. // Setting the reliable boundary has no effect. 
str.SetReliableBoundary() wrote := make(chan struct{}) mockSender.EXPECT().onHasStreamData(streamID, str).Do(func(protocol.StreamID, *SendStream) { close(wrote) }) errChan := make(chan error, 1) go func() { _, err := strWithTimeout.Write(make([]byte, 2000)) errChan <- err }() synctest.Wait() // cancel the stream mockSender.EXPECT().onHasStreamControlFrame(streamID, str) str.CancelWrite(1234) require.True(t, mockCtrl.Satisfied()) cf, ok, hasMore := str.getControlFrame(monotime.Now()) require.True(t, ok) // only the "foo" was sent out, so the final size is 3 require.Equal(t, &wire.ResetStreamFrame{StreamID: streamID, FinalSize: 3, ErrorCode: 1234}, cf.Frame) require.False(t, hasMore) // the context was canceled select { case <-str.Context().Done(): default: t.Fatal("stream context should have been canceled") } require.ErrorIs(t, context.Cause(str.Context()), &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false}) // duplicate calls to CancelWrite don't do anything str.CancelWrite(1234) _, ok, _ = str.getControlFrame(monotime.Now()) require.False(t, ok) synctest.Wait() // the Write call should return an error select { case err := <-errChan: require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false}) default: t.Fatal("write should have returned") } // no data to send frame, _, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.Nil(t, frame.Frame) require.False(t, hasMore) // future calls to Write should return an error _, err = strWithTimeout.Write([]byte("foo")) require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false}) frame, _, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.Nil(t, frame.Frame) require.False(t, hasMore) // Close has no effect require.ErrorContains(t, str.Close(), "close called for canceled stream") frame, _, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.Nil(t, frame.Frame) _, err = 
strWithTimeout.Write([]byte("foobar")) require.Error(t, err) require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false}) // shutting down has no effect str.closeForShutdown(errors.New("goodbyte")) _, err = strWithTimeout.Write([]byte("foobar")) require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false}) }) } // It is possible to cancel a stream after it has been closed. // This is useful if the applications wants to prevent the retransmission of outstanding stream data. func TestSendStreamCancellationAfterClose(t *testing.T) { const streamID protocol.StreamID = 1234 mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second} mockSender.EXPECT().onHasStreamData(streamID, str).Times(2) _, err := strWithTimeout.Write([]byte("foobar")) require.NoError(t, err) require.NoError(t, str.Close()) mockSender.EXPECT().onHasStreamControlFrame(streamID, str) str.CancelWrite(1337) frame, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.Nil(t, frame.Frame) require.False(t, hasMore) cf, ok, hasMore := str.getControlFrame(monotime.Now()) require.True(t, ok) require.Equal(t, &wire.ResetStreamFrame{StreamID: streamID, FinalSize: 0, ErrorCode: 1337}, cf.Frame) require.False(t, hasMore) _, err = strWithTimeout.Write([]byte("foobar")) require.Error(t, err) require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: false}) } func TestSendStreamCancellationStreamRetransmission(t *testing.T) { t.Run("local", func(t *testing.T) { testSendStreamCancellationStreamRetransmission(t, false) }) t.Run("remote", func(t *testing.T) { testSendStreamCancellationStreamRetransmission(t, true) }) } func testSendStreamCancellationStreamRetransmission(t *testing.T, 
remote bool) { const streamID protocol.StreamID = 1000 mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamData(streamID, str) _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) require.NoError(t, err) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)).Times(2) f1, _, hasMore := str.popStreamFrame(3+expectedFrameHeaderLen(streamID, 0), protocol.Version1) require.NotNil(t, f1.Frame) require.True(t, hasMore) f2, _, hasMore := str.popStreamFrame(3+expectedFrameHeaderLen(streamID, 3), protocol.Version1) require.NotNil(t, f2.Frame) require.False(t, hasMore) mockSender.EXPECT().onHasStreamControlFrame(streamID, str) if remote { str.handleStopSendingFrame(&wire.StopSendingFrame{StreamID: streamID, ErrorCode: 1337}) } else { str.CancelWrite(1337) } cf, ok, hasMore := str.getControlFrame(monotime.Now()) require.True(t, ok) require.IsType(t, &wire.ResetStreamFrame{}, cf.Frame) require.False(t, hasMore) // it doesn't matter if the STREAM frames are acked or lost f1.Handler.OnAcked(f1.Frame) f2.Handler.OnLost(f2.Frame) frame, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.Nil(t, frame.Frame) require.False(t, hasMore) // if CancelWrite was called, the stream is completed as soon as the RESET_STREAM frame is acked if !remote { mockSender.EXPECT().onStreamCompleted(streamID) } cf.Handler.OnAcked(cf.Frame) // but if it's a remote cancellation, the application has to consume the error first if remote { mockSender.EXPECT().onStreamCompleted(streamID) _, err := str.Write([]byte("foobar")) require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true}) } } func TestSendStreamCancellationResetStreamRetransmission(t *testing.T) 
{ const streamID protocol.StreamID = 1000 mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamControlFrame(streamID, str) str.CancelWrite(1337) f1, ok, hasMore := str.getControlFrame(monotime.Now()) require.True(t, ok) require.Equal(t, &wire.ResetStreamFrame{StreamID: streamID, FinalSize: 0, ErrorCode: 1337}, f1.Frame) require.False(t, hasMore) require.True(t, mockCtrl.Satisfied()) // lose the RESET_STREAM frame mockSender.EXPECT().onHasStreamControlFrame(streamID, str) f1.Handler.OnLost(f1.Frame) // get the retransmission f2, ok, hasMore := str.getControlFrame(monotime.Now()) require.True(t, ok) require.Equal(t, &wire.ResetStreamFrame{StreamID: streamID, FinalSize: 0, ErrorCode: 1337}, f2.Frame) require.False(t, hasMore) require.True(t, mockCtrl.Satisfied()) // acknowledging the RESET_STREAM frame completes the stream mockSender.EXPECT().onStreamCompleted(streamID) f2.Handler.OnAcked(f2.Frame) } func TestSendStreamStopSendingAfterWrite(t *testing.T) { t.Run("complete by Write", func(t *testing.T) { testSendStreamStopSendingAfterWrite(t, "write") }) t.Run("complete by Close", func(t *testing.T) { testSendStreamStopSendingAfterWrite(t, "close") }) t.Run("complete by CancelWrite", func(t *testing.T) { testSendStreamStopSendingAfterWrite(t, "cancelwrite") }) } func testSendStreamStopSendingAfterWrite(t *testing.T, completeBy string) { const streamID protocol.StreamID = 1000 mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamData(streamID, str).MaxTimes(2) _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) require.NoError(t, err) 
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(gomock.Any()) frame, _, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.NotNil(t, frame.Frame) require.True(t, mockCtrl.Satisfied()) mockSender.EXPECT().onHasStreamControlFrame(streamID, str) str.handleStopSendingFrame(&wire.StopSendingFrame{StreamID: streamID, ErrorCode: 1337}) cf, ok, hasMore := str.getControlFrame(monotime.Now()) require.True(t, ok) require.Equal(t, &wire.ResetStreamFrame{StreamID: streamID, FinalSize: 6, ErrorCode: 1337}, cf.Frame) require.False(t, hasMore) // acknowledging the RESET_STREAM frame doesn't complete the stream, // since it was neither cancelled nor closed cf.Handler.OnAcked(cf.Frame) require.True(t, mockCtrl.Satisfied()) mockSender.EXPECT().onStreamCompleted(streamID) switch completeBy { case "write": // calls to Write should return an error _, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true}) case "close": require.ErrorContains(t, str.Close(), "close called for canceled stream") case "cancelwrite": str.CancelWrite(1234) } // error code and remote flag are unchanged _, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true}) frame, _, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.Nil(t, frame.Frame) _, ok, _ = str.getControlFrame(monotime.Now()) require.False(t, ok) } func TestSendStreamStopSendingDuringWrite(t *testing.T) { synctest.Test(t, func(t *testing.T) { const streamID protocol.StreamID = 1000 mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) 
mockSender.EXPECT().onHasStreamData(streamID, str).MaxTimes(2) _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) require.NoError(t, err) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(gomock.Any()) frame, _, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.NotNil(t, frame.Frame) require.True(t, mockCtrl.Satisfied()) errChan := make(chan error, 1) go func() { _, err := str.Write(make([]byte, 2000)) errChan <- err }() mockSender.EXPECT().onHasStreamControlFrame(streamID, str) str.handleStopSendingFrame(&wire.StopSendingFrame{StreamID: streamID, ErrorCode: 1337}) synctest.Wait() select { case err := <-errChan: require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true}) default: t.Fatal("write should have returned") } cf, ok, hasMore := str.getControlFrame(monotime.Now()) require.True(t, ok) require.Equal(t, &wire.ResetStreamFrame{StreamID: streamID, FinalSize: 6, ErrorCode: 1337}, cf.Frame) require.False(t, hasMore) // receiving another STOP_SENDING frame has no effect str.handleStopSendingFrame(&wire.StopSendingFrame{StreamID: streamID, ErrorCode: 1234}) _, ok, hasMore = str.getControlFrame(monotime.Now()) require.False(t, ok) require.False(t, hasMore) // acknowledging the RESET_STREAM frame completes the stream mockSender.EXPECT().onStreamCompleted(streamID) cf.Handler.OnAcked(cf.Frame) require.True(t, mockCtrl.Satisfied()) // calls to Write should return an error _, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true}) frame, _, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.Nil(t, frame.Frame) // calls to CancelWrite have no effect str.CancelWrite(1234) _, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) // error code and remote flag are unchanged 
require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true}) _, ok, _ = str.getControlFrame(monotime.Now()) require.False(t, ok) // Close has no effect require.ErrorContains(t, str.Close(), "close called for canceled stream") frame, _, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.Nil(t, frame.Frame) _, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) require.Error(t, err) require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true}) }) } // This test is inherently racy, as it tests a concurrent call to Write() and CancelRead(). // A single successful run of this test therefore doesn't mean a lot, // for reliable results it has to be run many times. func TestSendStreamConcurrentWriteAndCancel(t *testing.T) { synctest.Test(t, func(t *testing.T) { const streamID protocol.StreamID = 1000 mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()).MaxTimes(1) mockSender.EXPECT().onHasStreamData(streamID, str).MaxTimes(1) mockSender.EXPECT().onStreamCompleted(streamID).MaxTimes(1) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).MaxTimes(1) mockFC.EXPECT().AddBytesSent(gomock.Any()).MaxTimes(1) errChan := make(chan error, 1) go func() { n, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write(make([]byte, 100)) if n == 0 { errChan <- nil return } errChan <- err }() done := make(chan struct{}, 2) go func() { str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) done <- struct{}{} }() go func() { str.CancelWrite(1234) done <- struct{}{} }() synctest.Wait() select { case err := <-errChan: require.NoError(t, err) default: t.Fatal("write should have returned") } for range 2 { select { case <-done: 
default: t.Fatal("timeout waiting for cancel to complete") } } }) } func TestSendStreamRetransmissions(t *testing.T) { const streamID protocol.StreamID = 1000 mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamData(streamID, str) _, err := str.Write([]byte("foo")) require.NoError(t, err) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) f1, _, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: streamID, Data: []byte("foo"), DataLenPresent: true}, f1.Frame, ) require.True(t, mockCtrl.Satisfied()) // write some more data mockSender.EXPECT().onHasStreamData(streamID, str).Times(2) _, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("bar")) require.NoError(t, err) require.NoError(t, str.Close()) require.True(t, mockCtrl.Satisfied()) // lose the frame mockSender.EXPECT().onHasStreamData(streamID, str) f1.Handler.OnLost(f1.Frame) require.True(t, mockCtrl.Satisfied()) // when popping a new frame, we first get the retransmission... f2, _, hasMoreData := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: streamID, Data: []byte("foo"), DataLenPresent: true}, f2.Frame) require.True(t, hasMoreData) require.True(t, mockCtrl.Satisfied()) // ... 
then we get the new data mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) f3, _, hasMoreData := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: streamID, Offset: 3, Fin: true, Data: []byte("bar"), DataLenPresent: true}, f3.Frame) require.False(t, hasMoreData) require.True(t, mockCtrl.Satisfied()) // acknowledge the retransmission... f2.Handler.OnAcked(f2.Frame) // ... and the last frame, which concludes this stream mockSender.EXPECT().onStreamCompleted(streamID) f3.Handler.OnAcked(f3.Frame) } func TestSendStreamRetransmissionFraming(t *testing.T) { const streamID protocol.StreamID = 1000 mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamData(streamID, str) _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) require.NoError(t, err) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) f, _, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.NotNil(t, f.Frame) // lose the frame mockSender.EXPECT().onHasStreamData(streamID, str) f.Handler.OnLost(f.Frame) // retransmission doesn't fit f, _, hasMore := str.popStreamFrame(expectedFrameHeaderLen(streamID, 0), protocol.Version1) require.Nil(t, f.Frame) require.True(t, hasMore) // split the retransmission r1, _, hasMore := str.popStreamFrame(expectedFrameHeaderLen(streamID, 0)+3, protocol.Version1) require.True(t, hasMore) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: streamID, Data: []byte("foo"), DataLenPresent: true}, r1.Frame, ) r2, _, hasMore := str.popStreamFrame(expectedFrameHeaderLen(streamID, 3)+3, protocol.Version1) require.True(t, hasMore) // When 
popping a retransmission, we always claim that there's more data to send. // We accept that this might be incorrect. require.True(t, hasMore) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: streamID, Offset: 3, Data: []byte("bar"), DataLenPresent: true}, r2.Frame, ) _, _, hasMore = str.popStreamFrame(expectedFrameHeaderLen(streamID, 3)+3, protocol.Version1) require.False(t, hasMore) } // This test is kind of an integration test. // It writes 4 MB of data, and pops STREAM frames that sometimes are and sometimes aren't limited by flow control. // Half of these STREAM frames are then received and their content saved, while the other half is reported lost // and has to be retransmitted. func TestSendStreamRetransmitDataUntilAcknowledged(t *testing.T) { const streamID protocol.StreamID = 123456 const dataLen = 1 << 22 // 4 MB mockCtrl := gomock.NewController(t) mockSender := NewMockStreamSender(mockCtrl) mockFC := mocks.NewMockStreamFlowController(mockCtrl) str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamData(streamID, str).AnyTimes() mockFC.EXPECT().SendWindowSize().DoAndReturn(func() protocol.ByteCount { return protocol.ByteCount(mrand.IntN(500)) + 50 }).AnyTimes() mockFC.EXPECT().IsNewlyBlocked().Return(false).AnyTimes() mockFC.EXPECT().AddBytesSent(gomock.Any()).AnyTimes() data := make([]byte, dataLen) _, err := rand.Read(data) require.NoError(t, err) done := make(chan struct{}) go func() { defer close(done) _, err := str.Write(data) require.NoError(t, err) str.Close() }() var completed bool mockSender.EXPECT().onStreamCompleted(streamID).Do(func(protocol.StreamID) { completed = true }) received := make([]byte, dataLen) var counter int frameQueue := make([]ackhandler.StreamFrame, 0, 32) for !completed || len(frameQueue) > 0 { counter++ if counter > 1e6 { t.Fatal("stream should have completed") } f, _, _ := str.popStreamFrame(protocol.ByteCount(mrand.IntN(300)+100), protocol.Version1) var 
dequeuedFrame bool if f.Frame != nil { frameQueue = append(frameQueue, f) dequeuedFrame = true } // Process one of the queued frames at random. // This simulates potential reordering. if len(frameQueue) > 0 && (!dequeuedFrame || len(frameQueue) == cap(frameQueue)) { idx := mrand.IntN(len(frameQueue)) f := frameQueue[idx] // 50%: acknowledge the frame and save the data // 50%: lose the frame if mrand.Int()%2 == 0 { copy(received[f.Frame.Offset:f.Frame.Offset+f.Frame.DataLen()], f.Frame.Data) f.Handler.OnAcked(f.Frame) } else { f.Handler.OnLost(f.Frame) } frameQueue = slices.Delete(frameQueue, idx, idx+1) } runtime.Gosched() } require.Equal(t, data, received) } func TestSendStreamResetStreamAtCancelBeforeSend(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), 1337, mockSender, mockFC, true) mockSender.EXPECT().onHasStreamData(protocol.StreamID(1337), str).Times(2) _, err := str.Write([]byte("foobar")) require.NoError(t, err) str.SetReliableBoundary() _, err = str.Write([]byte("baz")) require.NoError(t, err) mockSender.EXPECT().onHasStreamControlFrame(protocol.StreamID(1337), str) str.CancelWrite(1337) cf, ok, hasMore := str.getControlFrame(monotime.Now()) require.True(t, ok) require.Equal(t, &wire.ResetStreamFrame{StreamID: 1337, FinalSize: 6, ErrorCode: 1337, ReliableSize: 6}, cf.Frame) require.False(t, hasMore) require.True(t, mockCtrl.Satisfied()) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) mockFC.EXPECT().IsNewlyBlocked() f, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: 1337, Data: []byte("foobar"), DataLenPresent: true}, f.Frame, ) require.False(t, hasMore) require.True(t, mockCtrl.Satisfied()) // Lose the frame. 
// Since it's before the reliable size, we should get a retransmission. mockSender.EXPECT().onHasStreamData(protocol.StreamID(1337), str) f.Handler.OnLost(f.Frame) require.True(t, mockCtrl.Satisfied()) retransmission, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: 1337, Data: []byte("foobar"), DataLenPresent: true}, retransmission.Frame, ) require.True(t, hasMore) // hasMore is always true when dequeuing a retransmission require.True(t, mockCtrl.Satisfied()) f, _, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.Nil(t, f.Frame) require.False(t, hasMore) require.True(t, mockCtrl.Satisfied()) // acknowledging the RESET_STREAM_AT and the retransmission completes the stream cf.Handler.OnAcked(cf.Frame) mockSender.EXPECT().onStreamCompleted(protocol.StreamID(1337)) retransmission.Handler.OnAcked(retransmission.Frame) } func TestSendStreamResetStreamAtCancelAfterSend(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), 1337, mockSender, mockFC, true) mockSender.EXPECT().onHasStreamData(protocol.StreamID(1337), str).Times(2) _, err := str.Write([]byte("foobar")) require.NoError(t, err) str.SetReliableBoundary() _, err = str.Write([]byte("baz")) require.NoError(t, err) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(9)) f, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: 1337, Data: []byte("foobarbaz"), DataLenPresent: true}, f.Frame, ) require.False(t, hasMore) require.True(t, mockCtrl.Satisfied()) mockSender.EXPECT().onHasStreamControlFrame(protocol.StreamID(1337), str) str.CancelWrite(42) cf, ok, hasMore := str.getControlFrame(monotime.Now()) require.True(t, ok) 
require.Equal(t, &wire.ResetStreamFrame{StreamID: 1337, FinalSize: 9, ErrorCode: 42, ReliableSize: 6}, cf.Frame) require.False(t, hasMore) require.True(t, mockCtrl.Satisfied()) cf.Handler.OnAcked(cf.Frame) // lose the STREAM frame mockSender.EXPECT().onHasStreamData(protocol.StreamID(1337), str) f.Handler.OnLost(f.Frame) // only the first 6 bytes need to be retransmitted retransmission1, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: 1337, Data: []byte("foobar"), DataLenPresent: true}, retransmission1.Frame, ) require.True(t, hasMore) // hasMore is always true when dequeuing a retransmission require.True(t, mockCtrl.Satisfied()) f, _, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.Nil(t, f.Frame) require.False(t, hasMore) require.True(t, mockCtrl.Satisfied()) // lose the retransmission as well mockSender.EXPECT().onHasStreamData(protocol.StreamID(1337), str) retransmission1.Handler.OnLost(retransmission1.Frame) retransmission2, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: 1337, Data: []byte("foobar"), DataLenPresent: true}, retransmission2.Frame, ) require.True(t, hasMore) // hasMore is always true when dequeuing a retransmission require.True(t, mockCtrl.Satisfied()) f, _, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.Nil(t, f.Frame) require.False(t, hasMore) require.True(t, mockCtrl.Satisfied()) // acknowledge the 2nd retransmission mockSender.EXPECT().onStreamCompleted(protocol.StreamID(1337)) retransmission2.Handler.OnAcked(retransmission2.Frame) } func TestSendStreamResetStreamAtRetransmissions(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), 1337, mockSender, mockFC, true) // f1: 
lorem // f2: ipsumdolor (reliable offset: right after the "ipsum") // f3: sit // f4: amet // sitting in the write buffer: consectetur (but not popped) mockSender.EXPECT().onHasStreamData(protocol.StreamID(1337), str).AnyTimes() mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).AnyTimes() mockFC.EXPECT().AddBytesSent(gomock.Any()).AnyTimes() _, err := str.Write([]byte("lorem")) require.NoError(t, err) f1, _, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: 1337, Data: []byte("lorem"), DataLenPresent: true}, f1.Frame, ) _, err = str.Write([]byte("ipsum")) require.NoError(t, err) str.SetReliableBoundary() _, err = str.Write([]byte("dolor")) require.NoError(t, err) f2, _, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: 1337, Offset: 5, Data: []byte("ipsumdolor"), DataLenPresent: true}, f2.Frame, ) _, err = str.Write([]byte("sit")) require.NoError(t, err) f3, _, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: 1337, Offset: 15, Data: []byte("sit"), DataLenPresent: true}, f3.Frame, ) _, err = str.Write([]byte("amet")) require.NoError(t, err) f4, _, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: 1337, Offset: 18, Data: []byte("amet"), DataLenPresent: true}, f4.Frame, ) _, err = str.Write([]byte("consectetur")) require.NoError(t, err) // lose the frames, in no particular order f2.Handler.OnLost(f2.Frame) f1.Handler.OnLost(f1.Frame) f3.Handler.OnLost(f3.Frame) // f4 is lost at a later point // Now cancel the stream. // We expect f1 and the first half of f2 to be retransmitted, // but f3 and the data in the buffer should not. 
mockSender.EXPECT().onHasStreamControlFrame(protocol.StreamID(1337), str) str.CancelWrite(42) cf, ok, hasMore := str.getControlFrame(monotime.Now()) require.True(t, ok) require.Equal(t, &wire.ResetStreamFrame{StreamID: 1337, FinalSize: 22, ErrorCode: 42, ReliableSize: 10}, cf.Frame) require.False(t, hasMore) require.True(t, mockCtrl.Satisfied()) cf.Handler.OnAcked(cf.Frame) // // the retransmission of f1 should be truncated to 6 bytes r1, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: 1337, Offset: 5, Data: []byte("ipsum"), DataLenPresent: true}, r1.Frame, ) require.True(t, hasMore) r2, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.EqualExportedValues(t, &wire.StreamFrame{StreamID: 1337, Data: []byte("lorem"), DataLenPresent: true}, r2.Frame, ) require.True(t, hasMore) // hasMore is always true when dequeuing a retransmission require.True(t, mockCtrl.Satisfied()) r3, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.Nil(t, r3.Frame) require.False(t, hasMore) require.True(t, mockCtrl.Satisfied()) r1.Handler.OnAcked(r1.Frame) r2.Handler.OnAcked(r2.Frame) require.True(t, mockCtrl.Satisfied()) // the stream is only completed once f4 is lost // it's beyond the reliable size, so it's not retransmitted mockSender.EXPECT().onStreamCompleted(protocol.StreamID(1337)) f4.Handler.OnLost(f4.Frame) } func TestSendStreamResetStreamAtStopSendingBeforeCancelation(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), 1337, mockSender, mockFC, true) mockSender.EXPECT().onHasStreamData(protocol.StreamID(1337), str).Times(2) _, err := str.Write([]byte("foobar")) require.NoError(t, err) str.SetReliableBoundary() _, err = str.Write([]byte("baz")) require.NoError(t, err) // send out a STREAM frame 
with all the data written so far mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(9)) f, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.Equal(t, protocol.ByteCount(9), f.Frame.DataLen()) require.False(t, hasMore) require.True(t, mockCtrl.Satisfied()) mockSender.EXPECT().onHasStreamControlFrame(protocol.StreamID(1337), str) str.handleStopSendingFrame(&wire.StopSendingFrame{StreamID: 1337, ErrorCode: 42}) cf, ok, hasMore := str.getControlFrame(monotime.Now()) require.True(t, ok) // Since the peer reset the stream, the resulting RESET_STREAM frame has a reliable size of 0 require.Equal(t, &wire.ResetStreamFrame{StreamID: 1337, FinalSize: 9, ErrorCode: 42, ReliableSize: 0}, cf.Frame) require.False(t, hasMore) require.True(t, mockCtrl.Satisfied()) // calling CancelWrite doesn't cause any more frames to be enqueued str.CancelWrite(1234) mockSender.EXPECT().onStreamCompleted(protocol.StreamID(1337)) cf.Handler.OnAcked(cf.Frame) } func TestSendStreamResetStreamAtStopSendingAfterCancelation(t *testing.T) { t.Run("RESET_STREAM_AT lost", func(t *testing.T) { testSendStreamResetStreamAtStopSendingAfterCancelation(t, true) }) t.Run("RESET_STREAM_AT acknowledged", func(t *testing.T) { testSendStreamResetStreamAtStopSendingAfterCancelation(t, false) }) } func testSendStreamResetStreamAtStopSendingAfterCancelation(t *testing.T, loseResetStreamAt bool) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) str := newSendStream(context.Background(), 1337, mockSender, mockFC, true) mockSender.EXPECT().onHasStreamData(protocol.StreamID(1337), str).Times(2) _, err := str.Write([]byte("foobar")) require.NoError(t, err) str.SetReliableBoundary() _, err = str.Write([]byte("baz")) require.NoError(t, err) // send out a STREAM frame with all the data written so far 
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(9)) f, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) require.Equal(t, protocol.ByteCount(9), f.Frame.DataLen()) require.False(t, hasMore) require.True(t, mockCtrl.Satisfied()) // Canceling the stream results in a RESET_STREAM_AT frame. mockSender.EXPECT().onHasStreamControlFrame(protocol.StreamID(1337), str) str.CancelWrite(42) cf1, ok, hasMore := str.getControlFrame(monotime.Now()) require.True(t, ok) require.Equal(t, &wire.ResetStreamFrame{StreamID: 1337, FinalSize: 9, ErrorCode: 42, ReliableSize: 6}, cf1.Frame) require.False(t, hasMore) // Receiving a STOP_SENDING frame results in a RESET_STREAM frame, // effectively reducing the reliable size to 0. mockSender.EXPECT().onHasStreamControlFrame(protocol.StreamID(1337), str) str.handleStopSendingFrame(&wire.StopSendingFrame{StreamID: 1337, ErrorCode: 1234}) cf2, ok, hasMore := str.getControlFrame(monotime.Now()) require.True(t, ok) // Since the peer reset the stream, the resulting RESET_STREAM frame has a reliable size of 0. // The error code is still the one used for the CancelWrite call. 
require.Equal(t, &wire.ResetStreamFrame{StreamID: 1337, FinalSize: 9, ErrorCode: 42, ReliableSize: 0}, cf2.Frame) require.False(t, hasMore) require.True(t, mockCtrl.Satisfied()) if loseResetStreamAt { // losing the RESET_STREAM_AT frame does nothing cf1.Handler.OnLost(cf1.Frame) } else { // receiving an acknowledgment for the RESET_STREAM_AT frame does nothing either: // the RESET_STREAM frame still needs to be transmitted reliably cf1.Handler.OnAcked(cf1.Frame) } _, ok, _ = str.getControlFrame(monotime.Now()) require.False(t, ok) // but when the RESET_STREAM frame is lost, it needs to be retransmitted mockSender.EXPECT().onHasStreamControlFrame(protocol.StreamID(1337), str) cf2.Handler.OnLost(cf2.Frame) cf3, ok, _ := str.getControlFrame(monotime.Now()) require.True(t, ok) require.Equal(t, cf2, cf3) mockSender.EXPECT().onStreamCompleted(protocol.StreamID(1337)) cf3.Handler.OnAcked(cf3.Frame) } func TestSendStreamResetStreamAtRandomized(t *testing.T) { const streamID protocol.StreamID = 123456 const dataLen = 8 << 10 reliableOffset := 1 + mrand.IntN(dataLen*3/4) t.Logf("reliable offset: %d", reliableOffset) mockCtrl := gomock.NewController(t) mockSender := NewMockStreamSender(mockCtrl) mockFC := mocks.NewMockStreamFlowController(mockCtrl) str := newSendStream(context.Background(), streamID, mockSender, mockFC, true) mockSender.EXPECT().onHasStreamData(streamID, str).AnyTimes() mockSender.EXPECT().onHasStreamControlFrame(streamID, str).AnyTimes() mockFC.EXPECT().SendWindowSize().DoAndReturn(func() protocol.ByteCount { return protocol.ByteCount(mrand.IntN(500)) + 50 }).AnyTimes() mockFC.EXPECT().IsNewlyBlocked().Return(false).AnyTimes() mockFC.EXPECT().AddBytesSent(gomock.Any()).AnyTimes() data := make([]byte, dataLen) _, err := rand.Read(data) require.NoError(t, err) errChan := make(chan error, 1) go func() { b := data var offset int for len(b) > 0 { m := mrand.IntN(1024) if offset < reliableOffset { m = min(m, reliableOffset-offset) } n, err := str.Write(b[:min(m, 
len(b))]) if err != nil { errChan <- err return } offset += n if offset <= reliableOffset { str.SetReliableBoundary() } b = b[n:] } str.CancelWrite(1234) errChan <- nil }() var completed bool mockSender.EXPECT().onStreamCompleted(streamID).Do(func(protocol.StreamID) { completed = true }) received := make([]byte, dataLen) var highestOffset int var receivedResetStreamAt bool var counter int frameQueue := make([]any, 0, 10) for !completed || len(frameQueue) > 0 { counter++ if counter > 1e6 { t.Fatal("stream should have completed") } var dequeuedFrame bool cf, ok, _ := str.getControlFrame(monotime.Now()) if ok { dequeuedFrame = true frameQueue = append(frameQueue, cf) receivedResetStreamAt = true require.Equal(t, protocol.ByteCount(reliableOffset), cf.Frame.(*wire.ResetStreamFrame).ReliableSize) } else { f, _, _ := str.popStreamFrame(protocol.ByteCount(mrand.IntN(300)+100), protocol.Version1) if f.Frame != nil { // make sure that only retransmissions are sent once the RESET_STREAM_AT frame is sent if receivedResetStreamAt { require.LessOrEqualf(t, f.Frame.Offset+f.Frame.DataLen(), protocol.ByteCount(reliableOffset), "STREAM frame past reliable offset after RESET_STREAM_AT (offset: %d, data length: %d)", f.Frame.Offset, f.Frame.DataLen(), ) } dequeuedFrame = true frameQueue = append(frameQueue, f) } } if len(frameQueue) > 0 && (!dequeuedFrame || len(frameQueue) == cap(frameQueue)) { idx := mrand.IntN(len(frameQueue)) switch f := frameQueue[idx].(type) { case ackhandler.Frame: // 50%: acknowledge the frame // 50%: lose the frame if mrand.Int()%2 == 0 { f.Handler.OnLost(f.Frame) } else { f.Handler.OnAcked(f.Frame) } case ackhandler.StreamFrame: sf := f.Frame // 50%: acknowledge the frame and save the data // 50%: lose the frame if mrand.Int()%2 == 0 { f.Handler.OnLost(f.Frame) } else { highestOffset = max(highestOffset, int(sf.Offset+sf.DataLen())) copy(received[sf.Offset:sf.Offset+sf.DataLen()], sf.Data) f.Handler.OnAcked(f.Frame) } default: t.Fatalf("unexpected frame 
type: %T", f) } frameQueue = slices.Delete(frameQueue, idx, idx+1) } runtime.Gosched() } t.Logf("highest received offset: %d", highestOffset) require.GreaterOrEqual(t, highestOffset, reliableOffset) require.Equal(t, data[:reliableOffset], received[:reliableOffset]) } quic-go-0.59.0/server.go000066400000000000000000001035041513066070600150150ustar00rootroot00000000000000package quic import ( "context" "crypto/tls" "errors" "fmt" "net" "sync" "time" "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/qlogwriter" ) // ErrServerClosed is returned by the [Listener] or [EarlyListener]'s Accept method after a call to Close. var ErrServerClosed = errServerClosed{} type errServerClosed struct{} func (errServerClosed) Error() string { return "quic: server closed" } func (errServerClosed) Unwrap() error { return net.ErrClosed } // packetHandler handles packets type packetHandler interface { handlePacket(receivedPacket) destroy(error) closeWithTransportError(qerr.TransportErrorCode) } type zeroRTTQueue struct { packets []receivedPacket expiration monotime.Time } type rejectedPacket struct { receivedPacket hdr *wire.Header } // A Listener of QUIC type baseServer struct { tr *packetHandlerMap disableVersionNegotiation bool acceptEarlyConns bool tlsConf *tls.Config config *Config conn rawConn tokenGenerator *handshake.TokenGenerator maxTokenAge time.Duration connIDGenerator ConnectionIDGenerator statelessResetter *statelessResetter onClose func() receivedPackets chan receivedPacket nextZeroRTTCleanup monotime.Time zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true connContext func(context.Context, *ClientInfo) (context.Context, error) // set as a 
member, so they can be set in the tests newConn func( context.Context, context.CancelCauseFunc, sendConn, connRunner, protocol.ConnectionID, /* original dest connection ID */ *protocol.ConnectionID, /* retry src connection ID */ protocol.ConnectionID, /* client dest connection ID */ protocol.ConnectionID, /* destination connection ID */ protocol.ConnectionID, /* source connection ID */ ConnectionIDGenerator, *statelessResetter, *Config, *tls.Config, *handshake.TokenGenerator, bool, /* client address validated by an address validation token */ time.Duration, qlogwriter.Trace, utils.Logger, protocol.Version, ) *wrappedConn closeMx sync.Mutex // errorChan is closed when Close is called. This has two effects: // 1. it cancels handshakes that are still in flight (using CONNECTION_REFUSED) errors // 2. it stops handling of packets passed to this server errorChan chan struct{} // acceptChan is closed when Close returns. // This only happens once all handshake in flight have either completed and canceled. // Calls to Accept will first drain the queue of connections that have completed the handshake, // and then return ErrServerClosed. stopAccepting chan struct{} closeErr error running chan struct{} // closed as soon as run() returns versionNegotiationQueue chan receivedPacket invalidTokenQueue chan rejectedPacket connectionRefusedQueue chan rejectedPacket retryQueue chan rejectedPacket handshakingCount sync.WaitGroup verifySourceAddress func(net.Addr) bool connQueue chan *Conn qlogger qlogwriter.Recorder logger utils.Logger } // A Listener listens for incoming QUIC connections. // It returns connections once the handshake has completed. type Listener struct { baseServer *baseServer } // Accept returns new connections. It should be called in a loop. func (l *Listener) Accept(ctx context.Context) (*Conn, error) { return l.baseServer.Accept(ctx) } // Close closes the listener. 
// Accept will return [ErrServerClosed] as soon as all connections in the accept queue have been accepted. // QUIC handshakes that are still in flight will be rejected with a CONNECTION_REFUSED error. // Already established (accepted) connections will be unaffected. func (l *Listener) Close() error { return l.baseServer.Close() } // Addr returns the local network address that the server is listening on. func (l *Listener) Addr() net.Addr { return l.baseServer.Addr() } // An EarlyListener listens for incoming QUIC connections, and returns them before the handshake completes. // For connections that don't use 0-RTT, this allows the server to send 0.5-RTT data. // This data is encrypted with forward-secure keys, however, the client's identity has not yet been verified. // For connection using 0-RTT, this allows the server to accept and respond to streams that the client opened in the // 0-RTT data it sent. Note that at this point during the handshake, the live-ness of the // client has not yet been confirmed, and the 0-RTT data could have been replayed by an attacker. type EarlyListener struct { baseServer *baseServer } // Accept returns a new connections. It should be called in a loop. func (l *EarlyListener) Accept(ctx context.Context) (*Conn, error) { conn, err := l.baseServer.accept(ctx) if err != nil { return nil, err } return conn, nil } // Close closes the listener. // Accept will return [ErrServerClosed] as soon as all connections in the accept queue have been accepted. // Early connections that are still in flight will be rejected with a CONNECTION_REFUSED error. // Already established (accepted) connections will be unaffected. func (l *EarlyListener) Close() error { return l.baseServer.Close() } // Addr returns the local network addr that the server is listening on. func (l *EarlyListener) Addr() net.Addr { return l.baseServer.Addr() } // ListenAddr creates a QUIC server listening on a given address. // See [Listen] for more details. 
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (*Listener, error) { conn, err := listenUDP(addr) if err != nil { return nil, err } return (&Transport{ Conn: conn, createdConn: true, isSingleUse: true, }).Listen(tlsConf, config) } // ListenAddrEarly works like [ListenAddr], but it returns connections before the handshake completes. func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (*EarlyListener, error) { conn, err := listenUDP(addr) if err != nil { return nil, err } return (&Transport{ Conn: conn, createdConn: true, isSingleUse: true, }).ListenEarly(tlsConf, config) } func listenUDP(addr string) (*net.UDPConn, error) { udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err } return net.ListenUDP("udp", udpAddr) } // Listen listens for QUIC connections on a given net.PacketConn. // If the PacketConn satisfies the [OOBCapablePacketConn] interface (as a [net.UDPConn] does), // ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP // will be used instead of ReadFrom and WriteTo to read/write packets. // A single net.PacketConn can only be used for a single call to Listen. // // The tls.Config must not be nil and must contain a certificate configuration. // Furthermore, it must define an application control (using [NextProtos]). // The quic.Config may be nil, in that case the default values will be used. // // This is a convenience function. More advanced use cases should instantiate a [Transport], // which offers configuration options for a more fine-grained control of the connection establishment, // including reusing the underlying UDP socket for outgoing QUIC connections. // When closing a listener created with Listen, all established QUIC connections will be closed immediately. 
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Listener, error) { tr := &Transport{Conn: conn, isSingleUse: true} return tr.Listen(tlsConf, config) } // ListenEarly works like [Listen], but it returns connections before the handshake completes. func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*EarlyListener, error) { tr := &Transport{Conn: conn, isSingleUse: true} return tr.ListenEarly(tlsConf, config) } func newServer( conn rawConn, tr *packetHandlerMap, connIDGenerator ConnectionIDGenerator, statelessResetter *statelessResetter, connContext func(context.Context, *ClientInfo) (context.Context, error), tlsConf *tls.Config, config *Config, qlogger qlogwriter.Recorder, onClose func(), tokenGeneratorKey TokenGeneratorKey, maxTokenAge time.Duration, verifySourceAddress func(net.Addr) bool, disableVersionNegotiation bool, acceptEarly bool, ) *baseServer { s := &baseServer{ conn: conn, connContext: connContext, tr: tr, tlsConf: tlsConf, config: config, tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey), maxTokenAge: maxTokenAge, verifySourceAddress: verifySourceAddress, connIDGenerator: connIDGenerator, statelessResetter: statelessResetter, connQueue: make(chan *Conn, protocol.MaxAcceptQueueSize), errorChan: make(chan struct{}), stopAccepting: make(chan struct{}), running: make(chan struct{}), receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets), versionNegotiationQueue: make(chan receivedPacket, 4), invalidTokenQueue: make(chan rejectedPacket, 4), connectionRefusedQueue: make(chan rejectedPacket, 4), retryQueue: make(chan rejectedPacket, 8), newConn: newConnection, qlogger: qlogger, logger: utils.DefaultLogger.WithPrefix("server"), acceptEarlyConns: acceptEarly, disableVersionNegotiation: disableVersionNegotiation, onClose: onClose, } if acceptEarly { s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{} } go s.run() go s.runSendQueue() s.logger.Debugf("Listening for %s 
connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) return s } func (s *baseServer) run() { defer close(s.running) for { select { case <-s.errorChan: return default: } select { case <-s.errorChan: return case p := <-s.receivedPackets: if bufferStillInUse := s.handlePacketImpl(p); !bufferStillInUse { p.buffer.Release() } } } } func (s *baseServer) runSendQueue() { for { select { case <-s.running: return case p := <-s.versionNegotiationQueue: s.maybeSendVersionNegotiationPacket(p) case p := <-s.invalidTokenQueue: s.maybeSendInvalidToken(p) case p := <-s.connectionRefusedQueue: s.sendConnectionRefused(p) case p := <-s.retryQueue: s.sendRetry(p) } } } // Accept returns connections that already completed the handshake. // It is only valid if acceptEarlyConns is false. func (s *baseServer) Accept(ctx context.Context) (*Conn, error) { return s.accept(ctx) } func (s *baseServer) accept(ctx context.Context) (*Conn, error) { select { case <-ctx.Done(): return nil, ctx.Err() case conn := <-s.connQueue: return conn, nil case <-s.stopAccepting: // first drain the queue select { case conn := <-s.connQueue: return conn, nil default: } return nil, s.closeErr } } func (s *baseServer) Close() error { s.close(ErrServerClosed, false) return nil } // close closes the server. The Transport mutex must not be held while calling this method. // This method closes any handshaking connections which requires the tranpsort mutex. func (s *baseServer) close(e error, transportClose bool) { s.closeMx.Lock() if s.closeErr != nil { s.closeMx.Unlock() return } s.closeErr = e close(s.errorChan) <-s.running s.closeMx.Unlock() if !transportClose { s.onClose() } // wait until all handshakes in flight have terminated s.handshakingCount.Wait() close(s.stopAccepting) if transportClose { // if the transport is closing, drain the connQueue. All connections in the queue // will be closed by the transport. 
for { select { case <-s.connQueue: default: return } } } } // Addr returns the server's network address func (s *baseServer) Addr() net.Addr { return s.conn.LocalAddr() } func (s *baseServer) handlePacket(p receivedPacket) { select { case s.receivedPackets <- p: case <-s.errorChan: return default: s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size()) if s.qlogger != nil { s.qlogger.RecordEvent(qlog.PacketDropped{ Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropDOSPrevention, }) } } } func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer still in use? */ { if !s.nextZeroRTTCleanup.IsZero() && p.rcvTime.After(s.nextZeroRTTCleanup) { defer s.cleanupZeroRTTQueues(p.rcvTime) } if wire.IsVersionNegotiationPacket(p.data) { s.logger.Debugf("Dropping Version Negotiation packet.") if s.qlogger != nil { s.qlogger.RecordEvent(qlog.PacketDropped{ Header: qlog.PacketHeader{PacketType: qlog.PacketTypeVersionNegotiation}, Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropUnexpectedPacket, }) } return false } // Short header packets should never end up here in the first place if !wire.IsLongHeaderPacket(p.data[0]) { panic(fmt.Sprintf("misrouted packet: %#v", p.data)) } v, err := wire.ParseVersion(p.data) // drop the packet if we failed to parse the protocol version if err != nil { s.logger.Debugf("Dropping a packet with an unknown version") if s.qlogger != nil { s.qlogger.RecordEvent(qlog.PacketDropped{ Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropUnexpectedPacket, }) } return false } // send a Version Negotiation Packet if the client is speaking a different protocol version if !protocol.IsSupportedVersion(s.config.Versions, v) { if s.disableVersionNegotiation { if s.qlogger != nil { s.qlogger.RecordEvent(qlog.PacketDropped{ Header: qlog.PacketHeader{Version: v}, Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropUnexpectedVersion, }) } 
return false } if p.Size() < protocol.MinUnknownVersionPacketSize { s.logger.Debugf("Dropping a packet with an unsupported version number %d that is too small (%d bytes)", v, p.Size()) if s.qlogger != nil { s.qlogger.RecordEvent(qlog.PacketDropped{ Header: qlog.PacketHeader{Version: v}, Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropUnexpectedPacket, }) } return false } return s.enqueueVersionNegotiationPacket(p) } if wire.Is0RTTPacket(p.data) { if !s.acceptEarlyConns { if s.qlogger != nil { s.qlogger.RecordEvent(qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: qlog.PacketType0RTT, PacketNumber: protocol.InvalidPacketNumber, }, Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropUnexpectedPacket, }) } return false } return s.handle0RTTPacket(p) } // If we're creating a new connection, the packet will be passed to the connection. // The header will then be parsed again. hdr, _, _, err := wire.ParsePacket(p.data) if err != nil { if s.qlogger != nil { s.qlogger.RecordEvent(qlog.PacketDropped{ Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropHeaderParseError, }) } s.logger.Debugf("Error parsing packet: %s", err) return false } if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize { s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size()) if s.qlogger != nil { s.qlogger.RecordEvent(qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeInitial, PacketNumber: protocol.InvalidPacketNumber, Version: v, }, Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropUnexpectedPacket, }) } return false } if hdr.Type != protocol.PacketTypeInitial { // Drop long header packets. // There's little point in sending a Stateless Reset, since the client // might not have received the token yet. 
s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data)) if s.qlogger != nil { var pt qlog.PacketType switch hdr.Type { case protocol.PacketTypeInitial: pt = qlog.PacketTypeInitial case protocol.PacketTypeHandshake: pt = qlog.PacketTypeHandshake case protocol.PacketType0RTT: pt = qlog.PacketType0RTT case protocol.PacketTypeRetry: pt = qlog.PacketTypeRetry } s.qlogger.RecordEvent(qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: pt, PacketNumber: protocol.InvalidPacketNumber, Version: v, }, Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropUnexpectedPacket, }) } return false } s.logger.Debugf("<- Received Initial packet.") if err := s.handleInitialImpl(p, hdr); err != nil { s.logger.Errorf("Error occurred handling initial packet: %s", err) } // Don't put the packet buffer back. // handleInitialImpl deals with the buffer. return true } func (s *baseServer) handle0RTTPacket(p receivedPacket) bool { connID, err := wire.ParseConnectionID(p.data, 0) if err != nil { if s.qlogger != nil { v, _ := wire.ParseVersion(p.data) s.qlogger.RecordEvent(qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: qlog.PacketType0RTT, PacketNumber: protocol.InvalidPacketNumber, Version: v, }, Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropHeaderParseError, }) } return false } // check again if we might have a connection now if handler, ok := s.tr.Get(connID); ok { handler.handlePacket(p) return true } if q, ok := s.zeroRTTQueues[connID]; ok { if len(q.packets) >= protocol.Max0RTTQueueLen { if s.qlogger != nil { v, _ := wire.ParseVersion(p.data) s.qlogger.RecordEvent(qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: qlog.PacketType0RTT, PacketNumber: protocol.InvalidPacketNumber, Version: v, }, Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropDOSPrevention, }) } return false } q.packets = append(q.packets, p) return true } if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues { if 
s.qlogger != nil { v, _ := wire.ParseVersion(p.data) s.qlogger.RecordEvent(qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: qlog.PacketType0RTT, PacketNumber: protocol.InvalidPacketNumber, Version: v, }, Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropDOSPrevention, }) } return false } queue := &zeroRTTQueue{packets: make([]receivedPacket, 1, 8)} queue.packets[0] = p expiration := p.rcvTime.Add(protocol.Max0RTTQueueingDuration) queue.expiration = expiration if s.nextZeroRTTCleanup.IsZero() || s.nextZeroRTTCleanup.After(expiration) { s.nextZeroRTTCleanup = expiration } s.zeroRTTQueues[connID] = queue return true } func (s *baseServer) cleanupZeroRTTQueues(now monotime.Time) { // Iterate over all queues to find those that are expired. // This is ok since we're placing a pretty low limit on the number of queues. var nextCleanup monotime.Time for connID, q := range s.zeroRTTQueues { if q.expiration.After(now) { if nextCleanup.IsZero() || nextCleanup.After(q.expiration) { nextCleanup = q.expiration } continue } for _, p := range q.packets { if s.qlogger != nil { v, _ := wire.ParseVersion(p.data) s.qlogger.RecordEvent(qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: qlog.PacketType0RTT, PacketNumber: protocol.InvalidPacketNumber, Version: v, }, Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropDOSPrevention, }) } p.buffer.Release() } delete(s.zeroRTTQueues, connID) if s.logger.Debug() { s.logger.Debugf("Removing 0-RTT queue for %s.", connID) } } s.nextZeroRTTCleanup = nextCleanup } // validateToken returns false if: // - address is invalid // - token is expired // - token is null func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool { if token == nil { return false } if !token.ValidateRemoteAddr(addr) { return false } if !token.IsRetryToken && time.Since(token.SentTime) > s.maxTokenAge { return false } if token.IsRetryToken && time.Since(token.SentTime) > s.config.maxRetryTokenAge() { return 
false } return true } func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { if s.qlogger != nil { s.qlogger.RecordEvent(qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeInitial, PacketNumber: protocol.InvalidPacketNumber, Version: hdr.Version, }, Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropUnexpectedPacket, }) } p.buffer.Release() return errors.New("too short connection ID") } // The server queues packets for a while, and we might already have established a connection by now. // This results in a second check in the connection map. // That's ok since it's not the hot path (it's only taken by some Initial and 0-RTT packets). if handler, ok := s.tr.Get(hdr.DestConnectionID); ok { handler.handlePacket(p) return nil } var ( token *handshake.Token retrySrcConnID *protocol.ConnectionID clientAddrVerified bool ) origDestConnID := hdr.DestConnectionID if len(hdr.Token) > 0 { tok, err := s.tokenGenerator.DecodeToken(hdr.Token) if err == nil { if tok.IsRetryToken { origDestConnID = tok.OriginalDestConnectionID retrySrcConnID = &tok.RetrySrcConnectionID } token = tok } } if token != nil { clientAddrVerified = s.validateToken(token, p.remoteAddr) if !clientAddrVerified { // For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error. // We just ignore them, and act as if there was no token on this packet at all. // This also means we might send a Retry later. if !token.IsRetryToken { token = nil } else { // For Retry tokens, we send an INVALID_ERROR if // * the token is too old, or // * the token is invalid, in case of a retry token. 
select { case s.invalidTokenQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}: default: // drop packet if we can't send out the INVALID_TOKEN packets fast enough p.buffer.Release() } return nil } } } if token == nil && s.verifySourceAddress != nil && s.verifySourceAddress(p.remoteAddr) { // Retry invalidates all 0-RTT packets sent. delete(s.zeroRTTQueues, hdr.DestConnectionID) select { case s.retryQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}: default: // drop packet if we can't send out Retry packets fast enough p.buffer.Release() } return nil } // restore RTT from token var rtt time.Duration if token != nil && !token.IsRetryToken { rtt = token.RTT } config := s.config clientInfo := &ClientInfo{ RemoteAddr: p.remoteAddr, AddrVerified: clientAddrVerified, } if s.config.GetConfigForClient != nil { conf, err := s.config.GetConfigForClient(clientInfo) if err != nil { s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback") s.refuseNewConn(p, hdr) return nil } config = populateConfig(conf) } var conn *wrappedConn var cancel context.CancelCauseFunc ctx, cancel1 := context.WithCancelCause(context.Background()) if s.connContext != nil { var err error ctx, err = s.connContext(ctx, clientInfo) if err != nil { cancel1(err) s.logger.Debugf("Rejecting new connection due to ConnContext callback: %s", err) s.refuseNewConn(p, hdr) return nil } if ctx == nil { panic("quic: ConnContext returned nil") } // There's no guarantee that the application returns a context // that's derived from the context we passed into ConnContext. // We need to make sure that both contexts are cancelled. var cancel2 context.CancelCauseFunc ctx, cancel2 = context.WithCancelCause(ctx) cancel = func(cause error) { cancel1(cause) cancel2(cause) } } else { cancel = cancel1 } var qlogTrace qlogwriter.Trace if config.Tracer != nil { // Use the same connection ID that is passed to the client's GetLogWriter callback. 
connID := hdr.DestConnectionID if origDestConnID.Len() > 0 { connID = origDestConnID } qlogTrace = config.Tracer(ctx, false, connID) } connID, err := s.connIDGenerator.GenerateConnectionID() if err != nil { return err } s.logger.Debugf("Changing connection ID to %s.", connID) conn = s.newConn( ctx, cancel, newSendConn(s.conn, p.remoteAddr, p.info, s.logger), s.tr, origDestConnID, retrySrcConnID, hdr.DestConnectionID, hdr.SrcConnectionID, connID, s.connIDGenerator, s.statelessResetter, config, s.tlsConf, s.tokenGenerator, clientAddrVerified, rtt, qlogTrace, s.logger, hdr.Version, ) conn.handlePacket(p) // Adding the connection will fail if the client's chosen Destination Connection ID is already in use. // This is very unlikely: Even if an attacker chooses a connection ID that's already in use, // under normal circumstances the packet would just be routed to that connection. // The only time this collision will occur if we receive the two Initial packets at the same time. if added := s.tr.AddWithConnID(hdr.DestConnectionID, connID, conn); !added { delete(s.zeroRTTQueues, hdr.DestConnectionID) conn.closeWithTransportError(ConnectionRefused) return nil } // Pass queued 0-RTT to the newly established connection. 
if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok { for _, p := range q.packets { conn.handlePacket(p) } delete(s.zeroRTTQueues, hdr.DestConnectionID) } s.handshakingCount.Add(1) go func() { defer s.handshakingCount.Done() s.handleNewConn(conn) }() go conn.run() return nil } func (s *baseServer) refuseNewConn(p receivedPacket, hdr *wire.Header) { delete(s.zeroRTTQueues, hdr.DestConnectionID) select { case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}: default: // drop packet if we can't send out the CONNECTION_REFUSED fast enough p.buffer.Release() } } func (s *baseServer) handleNewConn(conn *wrappedConn) { if s.acceptEarlyConns { // wait until the early connection is ready, the handshake fails, or the server is closed select { case <-s.errorChan: conn.closeWithTransportError(ConnectionRefused) return case <-conn.Context().Done(): return case <-conn.earlyConnReady(): } } else { // wait until the handshake completes, fails, or the server is closed select { case <-s.errorChan: conn.closeWithTransportError(ConnectionRefused) return case <-conn.Context().Done(): return case <-conn.HandshakeComplete(): } } select { case s.connQueue <- conn.Conn: default: conn.closeWithTransportError(ConnectionRefused) } } func (s *baseServer) sendRetry(p rejectedPacket) { if err := s.sendRetryPacket(p); err != nil { s.logger.Debugf("Error sending Retry packet: %s", err) } } func (s *baseServer) sendRetryPacket(p rejectedPacket) error { hdr := p.hdr // Log the Initial packet now. // If no Retry is sent, the packet will be logged by the connection. 
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) srcConnID, err := s.connIDGenerator.GenerateConnectionID() if err != nil { return err } token, err := s.tokenGenerator.NewRetryToken(p.remoteAddr, hdr.DestConnectionID, srcConnID) if err != nil { return err } replyHdr := &wire.ExtendedHeader{} replyHdr.Type = protocol.PacketTypeRetry replyHdr.Version = hdr.Version replyHdr.SrcConnectionID = srcConnID replyHdr.DestConnectionID = hdr.SrcConnectionID replyHdr.Token = token if s.logger.Debug() { s.logger.Debugf("Changing connection ID to %s.", srcConnID) s.logger.Debugf("-> Sending Retry") replyHdr.Log(s.logger) } buf := getPacketBuffer() defer buf.Release() buf.Data, err = replyHdr.Append(buf.Data, hdr.Version) if err != nil { return err } // append the Retry integrity tag tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version) buf.Data = append(buf.Data, tag[:]...) if s.qlogger != nil { s.qlogger.RecordEvent(qlog.PacketSent{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeRetry, SrcConnectionID: replyHdr.SrcConnectionID, DestConnectionID: replyHdr.DestConnectionID, Version: replyHdr.Version, Token: &qlog.Token{Raw: token}, }, Raw: qlog.RawInfo{ Length: len(buf.Data), PayloadLength: int(replyHdr.Length), }, }) } _, err = s.conn.WritePacket(buf.Data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported) return err } func (s *baseServer) maybeSendInvalidToken(p rejectedPacket) { defer p.buffer.Release() // Only send INVALID_TOKEN if we can unprotect the packet. // This makes sure that we won't send it for packets that were corrupted. hdr := p.hdr sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) data := p.data[:hdr.ParsedLen()+hdr.Length] extHdr, err := unpackLongHeader(opener, hdr, data) // Only send INVALID_TOKEN if we can unprotect the packet. // This makes sure that we won't send it for packets that were corrupted. 
if err != nil { if s.qlogger != nil { s.qlogger.RecordEvent(qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeInitial, PacketNumber: protocol.InvalidPacketNumber, }, Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropHeaderParseError, }) } return } hdrLen := extHdr.ParsedLen() if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil { if s.qlogger != nil { s.qlogger.RecordEvent(qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeInitial, PacketNumber: protocol.InvalidPacketNumber, Version: hdr.Version, }, Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropPayloadDecryptError, }) } return } if s.logger.Debug() { s.logger.Debugf("Client sent an invalid retry token. Sending INVALID_TOKEN to %s.", p.remoteAddr) } if err := s.sendError(p.remoteAddr, hdr, sealer, InvalidToken, p.info); err != nil { s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err) } } func (s *baseServer) sendConnectionRefused(p rejectedPacket) { defer p.buffer.Release() sealer, _ := handshake.NewInitialAEAD(p.hdr.DestConnectionID, protocol.PerspectiveServer, p.hdr.Version) if err := s.sendError(p.remoteAddr, p.hdr, sealer, ConnectionRefused, p.info); err != nil { s.logger.Debugf("Error sending CONNECTION_REFUSED error: %s", err) } } // sendError sends the error as a response to the packet received with header hdr func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info packetInfo) error { b := getPacketBuffer() defer b.Release() ccf := &wire.ConnectionCloseFrame{ErrorCode: uint64(errorCode)} replyHdr := &wire.ExtendedHeader{} replyHdr.Type = protocol.PacketTypeInitial replyHdr.Version = hdr.Version replyHdr.SrcConnectionID = hdr.DestConnectionID replyHdr.DestConnectionID = hdr.SrcConnectionID replyHdr.PacketNumberLen = protocol.PacketNumberLen4 replyHdr.Length = 4 /* packet 
number len */ + ccf.Length(hdr.Version) + protocol.ByteCount(sealer.Overhead()) var err error b.Data, err = replyHdr.Append(b.Data, hdr.Version) if err != nil { return err } payloadOffset := len(b.Data) b.Data, err = ccf.Append(b.Data, hdr.Version) if err != nil { return err } _ = sealer.Seal(b.Data[payloadOffset:payloadOffset], b.Data[payloadOffset:], replyHdr.PacketNumber, b.Data[:payloadOffset]) b.Data = b.Data[0 : len(b.Data)+sealer.Overhead()] pnOffset := payloadOffset - int(replyHdr.PacketNumberLen) sealer.EncryptHeader( b.Data[pnOffset+4:pnOffset+4+16], &b.Data[0], b.Data[pnOffset:payloadOffset], ) replyHdr.Log(s.logger) wire.LogFrame(s.logger, ccf, true) if s.qlogger != nil { s.qlogger.RecordEvent(qlog.PacketSent{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeInitial, SrcConnectionID: replyHdr.SrcConnectionID, DestConnectionID: replyHdr.DestConnectionID, PacketNumber: replyHdr.PacketNumber, Version: replyHdr.Version, }, Raw: qlog.RawInfo{ Length: len(b.Data), PayloadLength: int(replyHdr.Length), }, Frames: []qlog.Frame{{Frame: ccf}}, }) } _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported) return err } func (s *baseServer) enqueueVersionNegotiationPacket(p receivedPacket) (bufferInUse bool) { select { case s.versionNegotiationQueue <- p: return true default: // it's fine to not send version negotiation packets when we are busy } return false } func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) { defer p.buffer.Release() v, err := wire.ParseVersion(p.data) if err != nil { s.logger.Debugf("failed to parse version for sending version negotiation packet: %s", err) return } _, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data) if err != nil { // should never happen s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs") if s.qlogger != nil { s.qlogger.RecordEvent(qlog.PacketDropped{ Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: 
qlog.PacketDropUnexpectedPacket, }) } return } s.logger.Debugf("Client offered version %s, sending Version Negotiation", v) data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions) if s.qlogger != nil { s.qlogger.RecordEvent(qlog.VersionNegotiationSent{ Header: qlog.PacketHeaderVersionNegotiation{ SrcConnectionID: src, DestConnectionID: dest, }, SupportedVersions: s.config.Versions, }) } if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil { s.logger.Debugf("Error sending Version Negotiation: %s", err) } } quic-go-0.59.0/server_test.go000066400000000000000000001202201513066070600160460ustar00rootroot00000000000000package quic import ( "context" "crypto/rand" "crypto/tls" "errors" "net" "slices" "testing" "time" "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/quic-go/testutils/events" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type testServer struct{ *baseServer } type serverOpts struct { eventRecorder *events.Recorder config *Config tokenGeneratorKey TokenGeneratorKey maxTokenAge time.Duration useRetry bool disableVersionNegotiation bool acceptEarly bool newConn func( context.Context, context.CancelCauseFunc, sendConn, connRunner, protocol.ConnectionID, // original dest connection ID *protocol.ConnectionID, // retry src connection ID protocol.ConnectionID, // client dest connection ID protocol.ConnectionID, // destination connection ID protocol.ConnectionID, // source connection ID ConnectionIDGenerator, *statelessResetter, *Config, *tls.Config, *handshake.TokenGenerator, bool, /* client address validated by an address validation token */ time.Duration, 
qlogwriter.Trace, utils.Logger, protocol.Version, ) *wrappedConn } func newTestServer(t *testing.T, serverOpts *serverOpts) *testServer { t.Helper() c, err := wrapConn(newUDPConnLocalhost(t)) require.NoError(t, err) verifySourceAddress := func(net.Addr) bool { return serverOpts.useRetry } config := populateConfig(serverOpts.config) tr := &Transport{Conn: newUDPConnLocalhost(t)} tr.init(true) s := newServer( c, (*packetHandlerMap)(tr), &protocol.DefaultConnectionIDGenerator{}, &statelessResetter{}, func(ctx context.Context, _ *ClientInfo) (context.Context, error) { return ctx, nil }, &tls.Config{}, config, serverOpts.eventRecorder, func() {}, serverOpts.tokenGeneratorKey, serverOpts.maxTokenAge, verifySourceAddress, serverOpts.disableVersionNegotiation, serverOpts.acceptEarly, ) s.newConn = serverOpts.newConn t.Cleanup(func() { s.Close() }) return &testServer{s} } func getLongHeaderPacketEncrypted(t *testing.T, remoteAddr net.Addr, extHdr *wire.ExtendedHeader, data []byte) receivedPacket { t.Helper() hdr := extHdr.Header if hdr.Type != protocol.PacketTypeInitial { t.Fatal("can only encrypt Initial packets") } p := getLongHeaderPacket(t, remoteAddr, extHdr, data) sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version) n := len(p.data) - len(data) // length of the header p.data = slices.Grow(p.data, 16) _ = sealer.Seal(p.data[n:n], p.data[n:], extHdr.PacketNumber, p.data[:n]) p.data = p.data[:len(p.data)+16] sealer.EncryptHeader(p.data[n:n+16], &p.data[0], p.data[n-int(extHdr.PacketNumberLen):n]) return p } func randConnID(l int) protocol.ConnectionID { b := make([]byte, l) rand.Read(b) return protocol.ParseConnectionID(b) } func getValidInitialPacket(t *testing.T, raddr net.Addr, srcConnID, destConnID protocol.ConnectionID) receivedPacket { t.Helper() return getLongHeaderPacket(t, raddr, &wire.ExtendedHeader{ Header: wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: srcConnID, DestConnectionID: destConnID, 
Length: protocol.MinInitialPacketSize, Version: protocol.Version1, }, PacketNumberLen: protocol.PacketNumberLen4, }, make([]byte, protocol.MinInitialPacketSize), ) } // checkConnectionClose checks // 1. the arguments of the SentPacket tracer call, and // 2. reads and parses the packet sent by the server func checkConnectionClose( t *testing.T, conn *net.UDPConn, eventRecorder *events.Recorder, expectedSrcConnID protocol.ConnectionID, expectedDestConnID protocol.ConnectionID, expectedErrorCode qerr.TransportErrorCode, ) { t.Helper() conn.SetReadDeadline(time.Now().Add(time.Second)) b := make([]byte, 1500) n, _, err := conn.ReadFromUDP(b) require.NoError(t, err) parsedHdr, _, _, err := wire.ParsePacket(b[:n]) require.NoError(t, err) require.Equal(t, protocol.PacketTypeInitial, parsedHdr.Type) require.Equal(t, expectedSrcConnID, parsedHdr.SrcConnectionID) require.Equal(t, expectedDestConnID, parsedHdr.DestConnectionID) require.Equal(t, []qlogwriter.Event{ qlog.PacketSent{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeInitial, SrcConnectionID: expectedSrcConnID, DestConnectionID: expectedDestConnID, Version: protocol.Version1, }, Raw: qlog.RawInfo{Length: n, PayloadLength: int(parsedHdr.Length)}, Frames: []qlog.Frame{ {Frame: &qlog.ConnectionCloseFrame{ErrorCode: uint64(expectedErrorCode)}}, }, }, }, eventRecorder.Events(qlog.PacketSent{}), ) } func checkRetry(t *testing.T, conn *net.UDPConn, eventRecorder *events.Recorder, expectedDestConnID protocol.ConnectionID, ) { t.Helper() conn.SetReadDeadline(time.Now().Add(time.Second)) b := make([]byte, 1500) n, _, err := conn.ReadFromUDP(b) require.NoError(t, err) parsedHdr, _, _, err := wire.ParsePacket(b[:n]) require.NoError(t, err) require.Equal(t, protocol.PacketTypeRetry, parsedHdr.Type) require.Equal(t, expectedDestConnID, parsedHdr.DestConnectionID) require.NotNil(t, parsedHdr.Token) require.Equal(t, []qlogwriter.Event{ qlog.PacketSent{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeRetry, 
DestConnectionID: expectedDestConnID, SrcConnectionID: parsedHdr.SrcConnectionID, Version: parsedHdr.Version, Token: &qlog.Token{Raw: parsedHdr.Token}, }, Raw: qlog.RawInfo{Length: n}, }, }, eventRecorder.Events(qlog.PacketSent{}), ) } func TestListen(t *testing.T) { _, err := ListenAddr("localhost:0", nil, nil) require.Error(t, err) require.Contains(t, err.Error(), "quic: tls.Config not set") _, err = Listen(nil, &tls.Config{}, &Config{Versions: []protocol.Version{0x1234}}) require.Error(t, err) require.Contains(t, err.Error(), "invalid QUIC version: 0x1234") } func TestListenAddr(t *testing.T) { _, err := ListenAddr("127.0.0.1", &tls.Config{}, &Config{}) require.Error(t, err) require.IsType(t, &net.AddrError{}, err) _, err = ListenAddr("1.1.1.1:1111", &tls.Config{}, &Config{}) require.Error(t, err) require.IsType(t, &net.OpError{}, err) ln, err := ListenAddr("127.0.0.1:0", &tls.Config{}, &Config{}) require.NoError(t, err) defer ln.Close() } func TestServerPacketDropping(t *testing.T) { t.Run("destination connection ID too short", func(t *testing.T) { conn := newUDPConnLocalhost(t) testServerDroppedPacket(t, conn, getValidInitialPacket(t, conn.LocalAddr(), randConnID(5), randConnID(7)), protocol.Version1, qlog.PacketTypeInitial, qlog.PacketDropUnexpectedPacket, ) }) t.Run("Initial packet too small", func(t *testing.T) { conn := newUDPConnLocalhost(t) p := getLongHeaderPacket(t, conn.LocalAddr(), &wire.ExtendedHeader{ Header: wire.Header{ Type: protocol.PacketTypeInitial, DestConnectionID: randConnID(8), Version: protocol.Version1, }, PacketNumberLen: 2, }, make([]byte, protocol.MinInitialPacketSize-100), ) require.Greater(t, len(p.data), protocol.MinInitialPacketSize-100) require.Less(t, len(p.data), protocol.MinInitialPacketSize) testServerDroppedPacket(t, conn, p, protocol.Version1, qlog.PacketTypeInitial, qlog.PacketDropUnexpectedPacket, ) }) // we should not send a Version Negotiation packet if the packet is smaller than 1200 bytes t.Run("packet of unknown 
version, too small", func(t *testing.T) { conn := newUDPConnLocalhost(t) p := getLongHeaderPacket(t, conn.LocalAddr(), &wire.ExtendedHeader{ Header: wire.Header{ Type: protocol.PacketTypeInitial, DestConnectionID: randConnID(8), Version: 0x42, }, PacketNumberLen: 2, }, make([]byte, protocol.MinUnknownVersionPacketSize-100), ) require.Greater(t, len(p.data), protocol.MinUnknownVersionPacketSize-100) require.Less(t, len(p.data), protocol.MinUnknownVersionPacketSize) testServerDroppedPacket(t, conn, p, 0x42, "", qlog.PacketDropUnexpectedPacket, ) }) t.Run("not an Initial packet", func(t *testing.T) { conn := newUDPConnLocalhost(t) testServerDroppedPacket(t, conn, getLongHeaderPacket(t, conn.LocalAddr(), &wire.ExtendedHeader{ Header: wire.Header{ Type: protocol.PacketTypeHandshake, DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1, }, PacketNumberLen: 2, }, nil, ), protocol.Version1, qlog.PacketTypeHandshake, qlog.PacketDropUnexpectedPacket, ) }) // as a server, we should never receive a Version Negotiation packet t.Run("Version Negotiation packet", func(t *testing.T) { conn := newUDPConnLocalhost(t) data := wire.ComposeVersionNegotiation( protocol.ArbitraryLenConnectionID{1, 2, 3, 4}, protocol.ArbitraryLenConnectionID{4, 3, 2, 1}, []protocol.Version{1, 2, 3}, ) testServerDroppedPacket(t, conn, receivedPacket{ remoteAddr: conn.LocalAddr(), data: data, buffer: getPacketBuffer(), }, 0, // version negotiation packets don't have a version qlog.PacketTypeVersionNegotiation, qlog.PacketDropUnexpectedPacket, ) }) } func testServerDroppedPacket(t *testing.T, conn *net.UDPConn, p receivedPacket, expectedVersion qlog.Version, expectedPacketType qlog.PacketType, expectedDropReason qlog.PacketDropReason, ) { readChan := make(chan struct{}) go func() { defer close(readChan) conn.ReadFrom(make([]byte, 1000)) }() var eventRecorder events.Recorder server := newTestServer(t, &serverOpts{eventRecorder: &eventRecorder}) 
server.handlePacket(p) select { case <-readChan: t.Fatal("didn't expect to receive a packet") case <-time.After(scaleDuration(5 * time.Millisecond)): } var expectedPacketNumber protocol.PacketNumber if expectedPacketType != qlog.PacketTypeVersionNegotiation && expectedPacketType != "" { expectedPacketNumber = protocol.InvalidPacketNumber } require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: expectedPacketType, PacketNumber: expectedPacketNumber, Version: expectedVersion, }, Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: expectedDropReason, }, }, eventRecorder.Events(qlog.PacketDropped{}), ) } func TestServerVersionNegotiation(t *testing.T) { t.Run("enabled", func(t *testing.T) { testServerVersionNegotiation(t, true) }) t.Run("disabled", func(t *testing.T) { testServerVersionNegotiation(t, false) }) } func testServerVersionNegotiation(t *testing.T, enabled bool) { conn := newUDPConnLocalhost(t) var eventRecorder events.Recorder server := newTestServer(t, &serverOpts{ eventRecorder: &eventRecorder, disableVersionNegotiation: !enabled, }) srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) packet := getLongHeaderPacket(t, conn.LocalAddr(), &wire.ExtendedHeader{ Header: wire.Header{ Type: protocol.PacketTypeHandshake, SrcConnectionID: srcConnID, DestConnectionID: destConnID, Version: 0x42, }, PacketNumberLen: protocol.PacketNumberLen4, }, make([]byte, protocol.MinUnknownVersionPacketSize), ) written := make(chan []byte, 1) go func() { b := make([]byte, 1500) n, _, _ := conn.ReadFrom(b) written <- b[:n] }() server.handlePacket(packet) switch enabled { case true: select { case b := <-written: require.True(t, wire.IsVersionNegotiationPacket(b)) dest, src, versions, err := wire.ParseVersionNegotiationPacket(b) require.NoError(t, err) require.Equal(t, protocol.ArbitraryLenConnectionID(srcConnID.Bytes()), dest) require.Equal(t, 
protocol.ArbitraryLenConnectionID(destConnID.Bytes()), src) require.NotContains(t, versions, protocol.Version(0x42)) require.Equal(t, []qlogwriter.Event{ qlog.VersionNegotiationSent{ Header: qlog.PacketHeaderVersionNegotiation{ SrcConnectionID: src, DestConnectionID: dest, }, SupportedVersions: server.config.Versions, }, }, eventRecorder.Events(qlog.VersionNegotiationSent{}), ) case <-time.After(time.Second): t.Fatal("timeout") } case false: select { case <-written: t.Fatal("expected no version negotiation packet") case <-time.After(scaleDuration(10 * time.Millisecond)): require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Header: qlog.PacketHeader{Version: 0x42}, Raw: qlog.RawInfo{Length: int(packet.Size())}, Trigger: qlog.PacketDropUnexpectedVersion, }, }, eventRecorder.Events(qlog.PacketDropped{}), ) } } } func TestServerRetry(t *testing.T) { var eventRecorder events.Recorder server := newTestServer(t, &serverOpts{eventRecorder: &eventRecorder, useRetry: true}) conn := newUDPConnLocalhost(t) packet := getLongHeaderPacket(t, conn.LocalAddr(), &wire.ExtendedHeader{ Header: wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}), DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1, }, PacketNumberLen: protocol.PacketNumberLen4, }, make([]byte, protocol.MinUnknownVersionPacketSize), ) server.handlePacket(packet) checkRetry(t, conn, &eventRecorder, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})) } func TestServerTokenValidation(t *testing.T) { var tokenGeneratorKey handshake.TokenProtectorKey rand.Read(tokenGeneratorKey[:]) tg := handshake.NewTokenGenerator(tokenGeneratorKey) t.Run("retry token with invalid address", func(t *testing.T) { token, err := tg.NewRetryToken( &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}, protocol.ConnectionID{}, protocol.ConnectionID{}, ) require.NoError(t, err) var eventRecorder events.Recorder server := 
newTestServer(t, &serverOpts{ useRetry: true, eventRecorder: &eventRecorder, tokenGeneratorKey: tokenGeneratorKey, }) testServerTokenValidation(t, server, &eventRecorder, newUDPConnLocalhost(t), token, false, true, false) }) t.Run("expired retry token", func(t *testing.T) { conn := newUDPConnLocalhost(t) var eventRecorder events.Recorder server := newTestServer(t, &serverOpts{ useRetry: true, eventRecorder: &eventRecorder, config: &Config{HandshakeIdleTimeout: time.Millisecond / 2}, tokenGeneratorKey: tokenGeneratorKey, }) token, err := tg.NewRetryToken(conn.LocalAddr(), protocol.ConnectionID{}, protocol.ConnectionID{}) require.NoError(t, err) // the maximum retry token age is equivalent to the handshake timeout time.Sleep(time.Millisecond) // make sure the token is expired testServerTokenValidation(t, server, &eventRecorder, conn, token, false, true, false) }) // if the packet is corrupted, it will just be dropped (no INVALID_TOKEN nor Retry is sent) t.Run("corrupted packet", func(t *testing.T) { var eventRecorder events.Recorder server := newTestServer(t, &serverOpts{ useRetry: true, eventRecorder: &eventRecorder, config: &Config{HandshakeIdleTimeout: time.Millisecond / 2}, tokenGeneratorKey: tokenGeneratorKey, }) conn := newUDPConnLocalhost(t) token, err := tg.NewRetryToken(conn.LocalAddr(), protocol.ConnectionID{}, protocol.ConnectionID{}) require.NoError(t, err) time.Sleep(time.Millisecond) // make sure the token is expired testServerTokenValidation(t, server, &eventRecorder, conn, token, true, false, true) }) t.Run("invalid non-retry token", func(t *testing.T) { var tokenGeneratorKey2 handshake.TokenProtectorKey rand.Read(tokenGeneratorKey2[:]) var eventRecorder events.Recorder server := newTestServer(t, &serverOpts{ tokenGeneratorKey: tokenGeneratorKey2, // use a different key useRetry: true, eventRecorder: &eventRecorder, maxTokenAge: time.Millisecond, }) conn := newUDPConnLocalhost(t) token, err := tg.NewToken(conn.LocalAddr(), 10*time.Millisecond) 
require.NoError(t, err) time.Sleep(3 * time.Millisecond) // make sure the token is expired testServerTokenValidation(t, server, &eventRecorder, conn, token, false, false, true) }) t.Run("expired non-retry token", func(t *testing.T) { var eventRecorder events.Recorder server := newTestServer(t, &serverOpts{ tokenGeneratorKey: tokenGeneratorKey, useRetry: true, eventRecorder: &eventRecorder, maxTokenAge: time.Millisecond, }) conn := newUDPConnLocalhost(t) token, err := tg.NewToken(conn.LocalAddr(), 100*time.Millisecond) require.NoError(t, err) time.Sleep(3 * time.Millisecond) // make sure the token is expired testServerTokenValidation(t, server, &eventRecorder, conn, token, false, false, true) }) } func testServerTokenValidation( t *testing.T, server *testServer, eventRecorder *events.Recorder, conn *net.UDPConn, token []byte, corruptedPacket bool, expectInvalidTokenConnectionClose bool, expectRetry bool, ) { hdr := wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Token: token, Length: protocol.MinInitialPacketSize + protocol.ByteCount(protocol.PacketNumberLen4) + 16, Version: protocol.Version1, } packet := getLongHeaderPacketEncrypted(t, conn.LocalAddr(), &wire.ExtendedHeader{Header: hdr, PacketNumberLen: protocol.PacketNumberLen4}, make([]byte, protocol.MinInitialPacketSize), ) if corruptedPacket { packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet server.handlePacket(packet) require.Eventually(t, func() bool { return len(eventRecorder.Events(qlog.PacketDropped{})) > 0 }, time.Second, 10*time.Millisecond, ) require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: qlog.PacketTypeInitial, PacketNumber: protocol.InvalidPacketNumber, Version: hdr.Version, }, Raw: qlog.RawInfo{Length: int(packet.Size())}, Trigger: qlog.PacketDropPayloadDecryptError, }, }, 
eventRecorder.Events(qlog.PacketDropped{}), ) return } server.handlePacket(packet) if expectInvalidTokenConnectionClose { checkConnectionClose(t, conn, eventRecorder, hdr.DestConnectionID, hdr.SrcConnectionID, qerr.InvalidToken) } if expectRetry { checkRetry(t, conn, eventRecorder, hdr.SrcConnectionID) } } type connConstructorArgs struct { ctx context.Context connRunner connRunner config *Config origDestConnID protocol.ConnectionID retrySrcConnID *protocol.ConnectionID clientDestConnID protocol.ConnectionID destConnID protocol.ConnectionID srcConnID protocol.ConnectionID } type connConstructorRecorder struct { ch chan connConstructorArgs hooks []*connTestHooks } func newConnConstructorRecorder(hooks ...*connTestHooks) *connConstructorRecorder { return &connConstructorRecorder{ ch: make(chan connConstructorArgs, len(hooks)), hooks: hooks, } } func (r *connConstructorRecorder) Args() <-chan connConstructorArgs { return r.ch } func (r *connConstructorRecorder) NewConn( ctx context.Context, _ context.CancelCauseFunc, _ sendConn, connRunner connRunner, origDestConnID protocol.ConnectionID, retrySrcConnID *protocol.ConnectionID, clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, _ ConnectionIDGenerator, _ *statelessResetter, config *Config, _ *tls.Config, _ *handshake.TokenGenerator, _ bool, _ time.Duration, _ qlogwriter.Trace, _ utils.Logger, _ protocol.Version, ) *wrappedConn { r.ch <- connConstructorArgs{ ctx: ctx, connRunner: connRunner, config: config, origDestConnID: origDestConnID, retrySrcConnID: retrySrcConnID, clientDestConnID: clientDestConnID, destConnID: destConnID, srcConnID: srcConnID, } hooks := r.hooks[0] r.hooks = r.hooks[1:] return &wrappedConn{testHooks: hooks} } func TestServerCreateConnection(t *testing.T) { t.Run("without retry", func(t *testing.T) { testServerCreateConnection(t, false) }) t.Run("with retry", func(t *testing.T) { testServerCreateConnection(t, true) }) } func 
testServerCreateConnection(t *testing.T, useRetry bool) { tokenGeneratorKey := TokenGeneratorKey{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11} tg := handshake.NewTokenGenerator(tokenGeneratorKey) server := newTestServer(t, &serverOpts{ useRetry: useRetry, tokenGeneratorKey: tokenGeneratorKey, }) done := make(chan struct{}, 3) handledPackets := make(chan receivedPacket, 1) recorder := newConnConstructorRecorder(&connTestHooks{ run: func() error { done <- struct{}{}; return nil }, context: func() context.Context { done <- struct{}{}; return context.Background() }, handshakeComplete: func() <-chan struct{} { done <- struct{}{}; return make(chan struct{}) }, handlePacket: func(p receivedPacket) { handledPackets <- p }, }) server.newConn = recorder.NewConn conn := newUDPConnLocalhost(t) var token []byte if useRetry { var err error token, err = tg.NewRetryToken( conn.LocalAddr(), protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), ) require.NoError(t, err) } hdr := wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Length: protocol.MinInitialPacketSize + protocol.ByteCount(protocol.PacketNumberLen4) + 16, Token: token, Version: protocol.Version1, } packet := getLongHeaderPacketEncrypted(t, conn.LocalAddr(), &wire.ExtendedHeader{Header: hdr, PacketNumberLen: protocol.PacketNumberLen4}, make([]byte, protocol.MinInitialPacketSize), ) server.handlePacket(packet) select { case p := <-handledPackets: require.Equal(t, packet, p) case <-time.After(time.Second): t.Fatal("timeout") } var args connConstructorArgs select { case args = <-recorder.Args(): case <-time.After(time.Second): t.Fatal("timeout") } assert.Equal(t, protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), args.destConnID) assert.NotEqual(t, args.origDestConnID, args.srcConnID) if useRetry { assert.Equal(t, 
protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), args.destConnID) assert.Equal(t, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), args.origDestConnID) assert.Equal(t, protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), *args.retrySrcConnID) } else { assert.Equal(t, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), args.origDestConnID) assert.Zero(t, args.retrySrcConnID) } for range 3 { select { case <-done: case <-time.After(time.Second): t.Fatal("timeout") } } } func TestServerClose(t *testing.T) { var hooks []*connTestHooks const numConns = 3 done := make(chan struct{}, numConns) for range numConns { hooks = append(hooks, &connTestHooks{ closeWithTransportError: func(TransportErrorCode) { done <- struct{}{} }, }) } recorder := newConnConstructorRecorder(hooks...) server := newTestServer(t, &serverOpts{newConn: recorder.NewConn}) for range numConns { b := make([]byte, 10) rand.Read(b) connID := protocol.ParseConnectionID(b) server.handlePacket(getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, randConnID(6), connID, )) select { case <-recorder.Args(): case <-time.After(time.Second): t.Fatal("timeout") } } server.Close() // closing closes all handshaking connections with CONNECTION_REFUSED for range numConns { select { case <-done: case <-time.After(time.Second): t.Fatal("timeout") } } // Accept returns ErrServerClosed after closing for range 5 { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() _, err := server.Accept(ctx) require.ErrorIs(t, err, ErrServerClosed) require.ErrorIs(t, err, net.ErrClosed) } } func TestServerGetConfigForClientAccept(t *testing.T) { recorder := newConnConstructorRecorder(&connTestHooks{}) server := newTestServer(t, &serverOpts{ config: &Config{ GetConfigForClient: func(*ClientInfo) (*Config, error) { return &Config{MaxIncomingStreams: 1234}, nil }, }, newConn: recorder.NewConn, }) conn := newUDPConnLocalhost(t) packet := 
getValidInitialPacket(t, conn.LocalAddr(), protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), ) server.handlePacket(packet) var args connConstructorArgs select { case args = <-recorder.Args(): require.EqualValues(t, 1234, args.config.MaxIncomingStreams) case <-time.After(time.Second): t.Fatal("timeout") } assert.Equal(t, protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), args.destConnID) assert.NotEqual(t, args.origDestConnID, args.srcConnID) } func TestServerGetConfigForClientReject(t *testing.T) { var eventRecorder events.Recorder server := newTestServer(t, &serverOpts{ eventRecorder: &eventRecorder, config: &Config{ GetConfigForClient: func(*ClientInfo) (*Config, error) { return nil, errors.New("rejected") }, }, }) conn := newUDPConnLocalhost(t) srcConnID := randConnID(6) destConnID := randConnID(8) server.handlePacket(getValidInitialPacket(t, conn.LocalAddr(), srcConnID, destConnID)) checkConnectionClose(t, conn, &eventRecorder, destConnID, srcConnID, qerr.ConnectionRefused) } func TestServerReceiveQueue(t *testing.T) { var eventRecorder events.Recorder acceptConn := make(chan struct{}) defer close(acceptConn) newConnChan := make(chan struct{}, protocol.MaxServerUnprocessedPackets+2) server := newTestServer(t, &serverOpts{ eventRecorder: &eventRecorder, newConn: func( _ context.Context, _ context.CancelCauseFunc, _ sendConn, _ connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, _ *statelessResetter, _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, _ bool, _ time.Duration, _ qlogwriter.Trace, _ utils.Logger, _ protocol.Version, ) *wrappedConn { newConnChan <- struct{}{} <-acceptConn return &wrappedConn{testHooks: &connTestHooks{handlePacket: func(receivedPacket) {}}} }, }) conn := newUDPConnLocalhost(t) for i := range protocol.MaxServerUnprocessedPackets + 1 { 
server.handlePacket(getValidInitialPacket(t, conn.LocalAddr(), randConnID(6), randConnID(8))) // newConn blocks on the acceptConn channel, so this blocks the server's run loop if i == 0 { select { case <-newConnChan: case <-time.After(time.Second): t.Fatal("timeout") } } } p := getValidInitialPacket(t, conn.LocalAddr(), randConnID(6), randConnID(8)) server.handlePacket(p) require.Eventually(t, func() bool { return len(eventRecorder.Events(qlog.PacketDropped{})) > 0 }, time.Second, 10*time.Millisecond, ) require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropDOSPrevention, }, }, eventRecorder.Events(qlog.PacketDropped{}), ) } func TestServerAccept(t *testing.T) { t.Run("without accept early", func(t *testing.T) { testServerAccept(t, false) }) t.Run("with accept early", func(t *testing.T) { testServerAccept(t, true) }) } func testServerAccept(t *testing.T, acceptEarly bool) { ready := make(chan struct{}) hooks := &connTestHooks{} if acceptEarly { hooks.earlyConnReady = func() <-chan struct{} { return ready } } else { hooks.handshakeComplete = func() <-chan struct{} { return ready } } recorder := newConnConstructorRecorder(hooks) server := newTestServer(t, &serverOpts{ acceptEarly: acceptEarly, newConn: recorder.NewConn, }) // Accept should respect the context ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := server.Accept(ctx) require.ErrorIs(t, err, context.Canceled) // establish a new connection, which then starts handshaking server.handlePacket(getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, randConnID(6), randConnID(8), )) accepted := make(chan error, 1) go func() { _, err := server.Accept(context.Background()) accepted <- err }() select { case <-accepted: t.Fatal("server accepted the connection too early") case <-time.After(scaleDuration(5 * time.Millisecond)): } // now complete the handshake close(ready) select { case err := <-accepted: 
require.NoError(t, err) case <-time.After(time.Second): t.Fatal("timeout") } } func TestServerAcceptHandshakeFailure(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) recorder := newConnConstructorRecorder(&connTestHooks{ context: func() context.Context { return ctx }, handshakeComplete: func() <-chan struct{} { return make(chan struct{}) }, }) server := newTestServer(t, &serverOpts{newConn: recorder.NewConn}) // establish a new connection, which then starts handshaking server.handlePacket(getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, randConnID(6), randConnID(8), )) accepted := make(chan error, 1) go func() { _, err := server.Accept(context.Background()) accepted <- err }() cancel() select { case <-accepted: t.Fatal("server should not have accepted the connection") case <-time.After(scaleDuration(5 * time.Millisecond)): } } func TestServerAcceptQueue(t *testing.T) { var conns []*connTestHooks rejectedCloseError := make(chan TransportErrorCode, 1) for i := range protocol.MaxAcceptQueueSize + 2 { conn := &connTestHooks{ handshakeComplete: func() <-chan struct{} { c := make(chan struct{}) close(c) return c }, } conns = append(conns, conn) if i == protocol.MaxAcceptQueueSize { conn.closeWithTransportError = func(code TransportErrorCode) { rejectedCloseError <- code } continue } } recorder := newConnConstructorRecorder(conns...) 
server := newTestServer(t, &serverOpts{newConn: recorder.NewConn}) for range protocol.MaxAcceptQueueSize { b := make([]byte, 16) rand.Read(b) connID := protocol.ParseConnectionID(b) server.handlePacket( getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, randConnID(6), connID), ) select { case args := <-recorder.Args(): require.Equal(t, connID, args.origDestConnID) case <-time.After(time.Second): t.Fatal("timeout") } } // wait for the connection to be enqueued time.Sleep(scaleDuration(10 * time.Millisecond)) server.handlePacket( getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, randConnID(6), randConnID(8)), ) select { case <-recorder.Args(): case <-time.After(time.Second): t.Fatal("timeout") } select { case code := <-rejectedCloseError: require.Equal(t, ConnectionRefused, code) case <-time.After(time.Second): t.Fatal("timeout") } // accept one connection, freeing up one slot in the accept queue ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() _, err := server.Accept(ctx) require.NoError(t, err) // it's now possible to enqueue a new connection server.handlePacket( getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, randConnID(6), protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}), ), ) select { case args := <-recorder.Args(): require.Equal(t, protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}), args.origDestConnID) case <-time.After(time.Second): t.Fatal("timeout") } } func TestServer0RTTReordering(t *testing.T) { var eventRecorder events.Recorder packets := make(chan receivedPacket, protocol.Max0RTTQueueLen+1) done := make(chan struct{}) recorder := newConnConstructorRecorder(&connTestHooks{ handlePacket: func(p receivedPacket) { packets <- p }, earlyConnReady: func() <-chan struct{} { return make(chan struct{}) }, run: func() error { close(done); return nil }, }) server := newTestServer(t, &serverOpts{ acceptEarly: true, eventRecorder: 
&eventRecorder, newConn: recorder.NewConn, }) connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) var zeroRTTPackets []receivedPacket for range protocol.Max0RTTQueueLen { p := getLongHeaderPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, &wire.ExtendedHeader{ Header: wire.Header{ Type: protocol.PacketType0RTT, SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), DestConnectionID: connID, Length: 100, Version: protocol.Version1, }, PacketNumberLen: protocol.PacketNumberLen4, }, make([]byte, 100), ) server.handlePacket(p) zeroRTTPackets = append(zeroRTTPackets, p) } // send one more packet, this one should be dropped p := getLongHeaderPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, &wire.ExtendedHeader{ Header: wire.Header{ Type: protocol.PacketType0RTT, SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), DestConnectionID: connID, Length: 100, Version: protocol.Version1, }, PacketNumberLen: protocol.PacketNumberLen4, }, make([]byte, 100), ) server.handlePacket(p) require.Eventually(t, func() bool { return len(eventRecorder.Events(qlog.PacketDropped{})) > 0 }, time.Second, 10*time.Millisecond, ) require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Header: qlog.PacketHeader{ PacketType: qlog.PacketType0RTT, PacketNumber: protocol.InvalidPacketNumber, Version: protocol.Version1, }, Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropDOSPrevention, }, }, eventRecorder.Events(qlog.PacketDropped{}), ) // now receive the Initial initial := getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, randConnID(5), connID) server.handlePacket(initial) for i := range protocol.Max0RTTQueueLen + 1 { select { case p := <-packets: if i == 0 { require.Equal(t, initial.data, p.data) } else { require.Equal(t, zeroRTTPackets[i-1], p) } case <-time.After(time.Second): t.Fatal("timeout") } } select { case <-done: case <-time.After(time.Second): t.Fatal("timeout") } } func 
TestServer0RTTQueueing(t *testing.T) {
	// Checks that 0-RTT packets for unknown connection IDs are queued in up to
	// protocol.Max0RTTQueues queues, that further packets are dropped, and that
	// queues are dropped once protocol.Max0RTTQueueingDuration has elapsed.
	var eventRecorder events.Recorder
	server := newTestServer(t, &serverOpts{
		acceptEarly:   true,
		eventRecorder: &eventRecorder,
	})

	// The first queue is created at firstRcvTime, all other queues half a queueing
	// duration later, so that they expire at two distinct points in time.
	firstRcvTime := monotime.Now()
	otherRcvTime := firstRcvTime.Add(protocol.Max0RTTQueueingDuration / 2)

	// sizes records the size of the packet put into each queue,
	// so the PacketDropped events can be matched once the queues expire.
	var sizes []protocol.ByteCount
	for i := range protocol.Max0RTTQueues {
		// a random connection ID per packet, so that each packet gets its own 0-RTT queue
		b := make([]byte, 16)
		rand.Read(b)
		connID := protocol.ParseConnectionID(b)
		size := protocol.ByteCount(500 + i)
		p := getLongHeaderPacket(t,
			&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
			&wire.ExtendedHeader{
				Header: wire.Header{
					Type:             protocol.PacketType0RTT,
					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
					DestConnectionID: connID,
					Length:           size,
					Version:          protocol.Version1,
				},
				PacketNumberLen: protocol.PacketNumberLen4,
			},
			make([]byte, size),
		)
		if i == 0 {
			p.rcvTime = firstRcvTime
		} else {
			p.rcvTime = otherRcvTime
		}
		sizes = append(sizes, p.Size())
		server.handlePacket(p)
	}

	// maximum number of 0-RTT queues is reached, further packets are dropped
	p := getLongHeaderPacket(t,
		&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
		&wire.ExtendedHeader{
			Header: wire.Header{
				Type:             protocol.PacketType0RTT,
				SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
				DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
				Length:           123,
				Version:          protocol.Version1,
			},
			PacketNumberLen: protocol.PacketNumberLen4,
		},
		make([]byte, 123),
	)
	server.handlePacket(p)
	require.Eventually(t,
		func() bool { return len(eventRecorder.Events(qlog.PacketDropped{})) > 0 },
		time.Second,
		10*time.Millisecond,
	)
	require.Equal(t,
		[]qlogwriter.Event{
			qlog.PacketDropped{
				Header: qlog.PacketHeader{
					PacketType:   qlog.PacketType0RTT,
					PacketNumber: protocol.InvalidPacketNumber,
					Version:      protocol.Version1,
				},
				Raw:     qlog.RawInfo{Length: int(p.Size())},
				Trigger: qlog.PacketDropDOSPrevention,
			},
		},
		eventRecorder.Events(qlog.PacketDropped{}),
	)
	eventRecorder.Clear()

	// There's no cleanup Go routine.
	// Cleanup is triggered when new packets are received.

	// 1. Receive one handshake packet, which triggers the cleanup of the first 0-RTT queue
	triggerPacket := getLongHeaderPacket(t,
		&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
		&wire.ExtendedHeader{
			Header: wire.Header{
				Type:             protocol.PacketTypeHandshake,
				SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
				DestConnectionID: protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}),
				Length:           123,
				Version:          protocol.Version1,
			},
			PacketNumberLen: protocol.PacketNumberLen4,
		},
		make([]byte, 123),
	)
	// received just after the first queue's deadline, but before the others expire
	triggerPacket.rcvTime = firstRcvTime.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond)
	server.handlePacket(triggerPacket)
	// expect two events: the trigger packet itself (logged as an unexpected packet)
	// and the packet from the expired first 0-RTT queue
	require.Eventually(t,
		func() bool { return len(eventRecorder.Events(qlog.PacketDropped{})) == 2 },
		time.Second,
		10*time.Millisecond,
	)
	require.Equal(t,
		[]qlogwriter.Event{
			qlog.PacketDropped{
				Header: qlog.PacketHeader{
					PacketType:   qlog.PacketTypeHandshake,
					PacketNumber: protocol.InvalidPacketNumber,
					Version:      protocol.Version1,
				},
				Raw:     qlog.RawInfo{Length: int(triggerPacket.Size())},
				Trigger: qlog.PacketDropUnexpectedPacket,
			},
			qlog.PacketDropped{
				Header: qlog.PacketHeader{
					PacketType:   qlog.PacketType0RTT,
					PacketNumber: protocol.InvalidPacketNumber,
					Version:      protocol.Version1,
				},
				Raw:     qlog.RawInfo{Length: int(sizes[0])},
				Trigger: qlog.PacketDropDOSPrevention,
			},
		},
		eventRecorder.Events(qlog.PacketDropped{}),
	)
	eventRecorder.Clear()

	// 2. Receive another handshake packet, which triggers the cleanup of the other 0-RTT queues
	triggerPacket = getLongHeaderPacket(t,
		&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
		&wire.ExtendedHeader{
			Header: wire.Header{
				Type:             protocol.PacketTypeHandshake,
				SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
				DestConnectionID: protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}),
				Length:           124,
				Version:          protocol.Version1,
			},
			PacketNumberLen: protocol.PacketNumberLen4,
		},
		make([]byte, 124),
	)
	triggerPacket.rcvTime = otherRcvTime.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond)
	server.handlePacket(triggerPacket)

	// the trigger packet, followed by one event for each of the remaining queues
	expectedEvents := []qlogwriter.Event{
		qlog.PacketDropped{
			Header: qlog.PacketHeader{
				PacketType:   qlog.PacketTypeHandshake,
				PacketNumber: protocol.InvalidPacketNumber,
				Version:      protocol.Version1,
			},
			Raw:     qlog.RawInfo{Length: int(triggerPacket.Size())},
			Trigger: qlog.PacketDropUnexpectedPacket,
		},
	}
	for i := range protocol.Max0RTTQueues - 1 {
		expectedEvents = append(expectedEvents, qlog.PacketDropped{
			Header: qlog.PacketHeader{
				PacketType:   qlog.PacketType0RTT,
				PacketNumber: protocol.InvalidPacketNumber,
				Version:      protocol.Version1,
			},
			Raw:     qlog.RawInfo{Length: int(sizes[i+1])},
			Trigger: qlog.PacketDropDOSPrevention,
		})
	}
	require.Eventually(t,
		func() bool { return len(eventRecorder.Events(qlog.PacketDropped{})) == len(expectedEvents) },
		time.Second,
		10*time.Millisecond,
	)
	// queues are dropped in random order
	for _, event := range expectedEvents {
		require.Contains(t, eventRecorder.Events(qlog.PacketDropped{}), event)
	}
}

quic-go-0.59.0/sni.go000066400000000000000000000070711513066070600143020ustar00rootroot00000000000000package quic

import (
	"encoding/binary"
	"errors"
	"io"
)

// TLS extension type code points searched for by findSNIAndECH.
const (
	extTypeSNI = 0
	extTypeECH = 0xfe0d
)

// findSNIAndECH parses the given byte slice as a ClientHello handshake message
// (including the 4-byte handshake header), and locates:
//   - the position and length of the host_name inside the Server Name Indication (SNI) extension,
//   - the position of the Encrypted Client Hello (ECH) extension, i.e. the offset of its extension type.
// If no SNI extension is found, it returns -1 for the SNI position. // If no ECH extension is found, it returns -1 for the ECH position. func findSNIAndECH(data []byte) (sniPos, sniLen, echPos int, err error) { if len(data) < 4 { return 0, 0, 0, io.ErrUnexpectedEOF } if data[0] != 1 { return 0, 0, 0, errors.New("not a ClientHello") } handshakeLen := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) if len(data) != 4+handshakeLen { return 0, 0, 0, io.ErrUnexpectedEOF } parsePos := 4 // Skip protocol version (2 bytes) if parsePos+2 > len(data) { return 0, 0, 0, io.ErrUnexpectedEOF } parsePos += 2 // skip random (32 bytes) if parsePos+32 > len(data) { return 0, 0, 0, io.ErrUnexpectedEOF } parsePos += 32 // session ID if parsePos+1 > len(data) { return 0, 0, 0, io.ErrUnexpectedEOF } sessionIDLen := int(data[parsePos]) parsePos++ if parsePos+sessionIDLen > len(data) { return 0, 0, 0, io.ErrUnexpectedEOF } parsePos += sessionIDLen // cipher suites if parsePos+2 > len(data) { return 0, 0, 0, io.ErrUnexpectedEOF } cipherSuitesLen := int(binary.BigEndian.Uint16(data[parsePos:])) parsePos += 2 if parsePos+cipherSuitesLen > len(data) { return 0, 0, 0, io.ErrUnexpectedEOF } parsePos += cipherSuitesLen // compression methods if parsePos+1 > len(data) { return 0, 0, 0, io.ErrUnexpectedEOF } compressionMethodsLen := int(data[parsePos]) parsePos++ if parsePos+compressionMethodsLen > len(data) { return 0, 0, 0, io.ErrUnexpectedEOF } parsePos += compressionMethodsLen // extensions if parsePos+2 > len(data) { return 0, 0, 0, io.ErrUnexpectedEOF } extensionsLen := int(binary.BigEndian.Uint16(data[parsePos:])) parsePos += 2 if parsePos+extensionsLen > len(data) { return 0, 0, 0, io.ErrUnexpectedEOF } extensionsStart := parsePos extensions := data[extensionsStart : extensionsStart+extensionsLen] // parse extensions var extPos int sniPos = -1 echPos = -1 for extPos+4 <= extensionsLen { extType := binary.BigEndian.Uint16(extensions[extPos:]) extLen := 
int(binary.BigEndian.Uint16(extensions[extPos+2:])) if extPos+4+extLen > extensionsLen { return 0, 0, 0, io.ErrUnexpectedEOF } switch extType { case extTypeSNI: if sniPos != -1 { return 0, 0, 0, errors.New("multiple SNI extensions") } sniData := extensions[extPos+4 : extPos+4+extLen] if len(sniData) < 2 { return 0, 0, 0, io.ErrUnexpectedEOF } nameListLen := int(binary.BigEndian.Uint16(sniData)) if len(sniData) != 2+nameListLen { return 0, 0, 0, io.ErrUnexpectedEOF } listPos := 2 for listPos+3 <= nameListLen+2 { nameType := sniData[listPos] sniLen = int(binary.BigEndian.Uint16(sniData[listPos+1:])) if listPos+3+sniLen > len(sniData) { return 0, 0, 0, io.ErrUnexpectedEOF } if nameType == 0 { // host_name sniPos = extensionsStart + extPos + 4 + listPos + 3 break // stop after first host_name } listPos += 3 + sniLen } if sniPos == 0 { return 0, 0, 0, errors.New("SNI host_name not found") } case extTypeECH: if echPos != -1 { return 0, 0, 0, errors.New("multiple ECH extensions") } echPos = extensionsStart + extPos } extPos += 4 + extLen if sniPos != -1 && echPos != -1 { break } } return sniPos, sniLen, echPos, nil } quic-go-0.59.0/sni_go124_test.go000066400000000000000000000163371513066070600162620ustar00rootroot00000000000000//go:build go1.24 package quic import ( "context" "crypto/ecdh" "crypto/rand" "crypto/tls" "encoding/binary" "io" mrand "math/rand/v2" "testing" "golang.org/x/crypto/cryptobyte" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func getClientHelloWithECH(t testing.TB, serverName string) []byte { t.Helper() // various constants from the standard library's (internal) hpke package const ( DHKEM_X25519_HKDF_SHA256 = 0x20 KDF_HKDF_SHA256 = 1 AEAD_AES_128_GCM = 1 ) marshalECHConfig := func(id uint8, pubKey []byte, publicName string, maxNameLen uint8) []byte { builder := cryptobyte.NewBuilder(nil) builder.AddUint16(extTypeECH) builder.AddUint16LengthPrefixed(func(builder 
*cryptobyte.Builder) { builder.AddUint8(id) builder.AddUint16(DHKEM_X25519_HKDF_SHA256) builder.AddUint16LengthPrefixed(func(builder *cryptobyte.Builder) { builder.AddBytes(pubKey) }) builder.AddUint16LengthPrefixed(func(builder *cryptobyte.Builder) { builder.AddUint16(KDF_HKDF_SHA256) builder.AddUint16(AEAD_AES_128_GCM) }) builder.AddUint8(maxNameLen) builder.AddUint8LengthPrefixed(func(builder *cryptobyte.Builder) { builder.AddBytes([]byte(publicName)) }) builder.AddUint16(0) // extensions }) return builder.BytesOrPanic() } echKey, err := ecdh.X25519().GenerateKey(rand.Reader) require.NoError(t, err) echConfig := marshalECHConfig(42, echKey.PublicKey().Bytes(), serverName, 32) builder := cryptobyte.NewBuilder(nil) builder.AddUint16LengthPrefixed(func(builder *cryptobyte.Builder) { builder.AddBytes(echConfig) }) c := tls.QUICClient(&tls.QUICConfig{ TLSConfig: &tls.Config{ ServerName: serverName, MinVersion: tls.VersionTLS13, EncryptedClientHelloConfigList: builder.BytesOrPanic(), InsecureSkipVerify: serverName == "", // disable post-quantum curves CurvePreferences: []tls.CurveID{tls.CurveP256}, }, }) b := make([]byte, mrand.IntN(200)) rand.Read(b) c.SetTransportParameters(b) require.NoError(t, c.Start(context.Background())) ev := c.NextEvent() require.Equal(t, tls.QUICWriteData, ev.Kind) checkClientHello(t, ev.Data) return ev.Data } // shuffleClientHelloExtensions takes a TLS 1.3 ClientHello message (without the record layer) // and returns a new ClientHello with its extensions shuffled. Returns nil if the input is invalid. 
func shuffleClientHelloExtensions(t testing.TB, clientHello []byte) []byte { t.Helper() // Basic validation: ensure minimum length and correct handshake type (0x01 for ClientHello) if len(clientHello) < 4 || clientHello[0] != 0x01 { t.Fatalf("not a ClientHello") } // Extract the 3-byte length (24-bit integer) and validate total length length := uint32(clientHello[1])<<16 | uint32(clientHello[2])<<8 | uint32(clientHello[3]) require.Equal(t, 4+int(length), len(clientHello)) // Body is everything after type and length body := clientHello[4 : 4+length] var pos int // Parse fixed and variable-length fields to reach extensions require.Greater(t, len(body), pos+2) // protocol version: 2 bytes pos += 2 require.Greater(t, len(body), pos+32) // random: 32 bytes pos += 32 require.Greater(t, len(body), pos+1) // session ID length: 1 byte sessionIDLen := int(body[pos]) pos += 1 require.Greater(t, len(body), pos+sessionIDLen) // session ID data pos += sessionIDLen require.Greater(t, len(body), pos+2) // cipher suites length: 2 bytes cipherSuitesLen := int(body[pos])<<8 | int(body[pos+1]) pos += 2 require.Greater(t, len(body), pos+cipherSuitesLen) // cipher suites data pos += cipherSuitesLen require.Greater(t, len(body), pos+1) // compression methods length: 1 byte compressionMethodsLen := int(body[pos]) pos += 1 require.Greater(t, len(body), pos+compressionMethodsLen) // compression methods data pos += compressionMethodsLen // Extensions: 2 bytes total length + data (may be absent) if pos+2 > len(body) { // No extensions present; return original return clientHello } extensionsLen := int(body[pos])<<8 | int(body[pos+1]) pos += 2 require.Equal(t, pos+extensionsLen, len(body)) // extensions length doesn't match remaining data extensionsData := body[pos : pos+extensionsLen] // parse extensions into a slice of byte slices var extensions [][]byte var extPos int for extPos < extensionsLen { require.Greater(t, extensionsLen, extPos+4) // type and length extLen := 
int(extensionsData[extPos+2])<<8 | int(extensionsData[extPos+3]) require.LessOrEqual(t, extPos+4+extLen, extensionsLen) // extension exceeds total length // extract entire extension (type: 2 bytes, length: 2 bytes, data) extData := extensionsData[extPos : extPos+4+extLen] extensions = append(extensions, extData) extPos += 4 + extLen } // shuffle extensions using a proper random source mrand.Shuffle(len(extensions), func(i, j int) { extensions[i], extensions[j] = extensions[j], extensions[i] }) // reconstruct extensions data var newExtensionsData []byte for _, ext := range extensions { newExtensionsData = append(newExtensionsData, ext...) } // reconstruct body: prefix (up to and including extensions length) + shuffled extensions prefix := body[:pos] newBody := append(prefix, newExtensionsData...) // reconstruct ClientHello: type (0x01) + original length + new body newClientHello := []byte{0x01} lengthBytes := clientHello[1:4] // length unchanged since only extensions are shuffled newClientHello = append(newClientHello, lengthBytes...) newClientHello = append(newClientHello, newBody...) 
// check that it's actually valid checkClientHello(t, newClientHello) return newClientHello } func TestFindSNIWithECH(t *testing.T) { const serverName = "public.example" clientHello := shuffleClientHelloExtensions(t, getClientHelloWithECH(t, serverName)) sniPos, sniLen, echPos, err := findSNIAndECH(clientHello) require.NoError(t, err) require.NotEqual(t, -1, echPos) require.Equal(t, uint16(extTypeECH), binary.BigEndian.Uint16(clientHello[echPos:echPos+2])) assert.Equal(t, len(serverName), sniLen) require.NotEqual(t, -1, sniPos) require.Equal(t, serverName, string(clientHello[sniPos:sniPos+sniLen])) for i := range clientHello { _, _, _, err := findSNIAndECH(clientHello[:i]) require.ErrorIs(t, err, io.ErrUnexpectedEOF) } } // findSNI is never run with attacker-controlled inputs (other than the session ticket), // so this is not a high-value target to begin with, // and doesn't need to be run in ClusterFuzz. // It's still useful to find potential corner cases in the parser. func FuzzFindSNI(f *testing.F) { f.Add(getClientHello(f, ""), 10) f.Add(getClientHello(f, "google.com"), 20) f.Add(getClientHello(f, "sub.do.ma.in.quic-go.net"), 30) f.Add(getClientHelloWithECH(f, "quic-go.net"), 40) f.Fuzz(func(t *testing.T, data []byte, maxSize int) { cs := newInitialCryptoStream(true) if _, err := cs.Write(data); err != nil { return } segments := make(map[protocol.ByteCount][]byte) if !cs.HasData() { // incomplete ClientHello return } for cs.HasData() { f := cs.PopCryptoFrame(5 + protocol.ByteCount(maxSize)) if f == nil { return } segments[f.Offset] = f.Data } reassembled := reassembleCryptoData(t, segments) require.Equal(t, data, reassembled) }) } quic-go-0.59.0/sni_test.go000066400000000000000000000037741513066070600153470ustar00rootroot00000000000000package quic import ( "context" "crypto/rand" "crypto/tls" "io" mrand "math/rand/v2" "testing" "github.com/quic-go/quic-go/internal/testdata" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func 
checkClientHello(t testing.TB, clientHello []byte) { t.Helper() conn := tls.QUICServer(&tls.QUICConfig{ TLSConfig: testdata.GetTLSConfig(), }) require.NoError(t, conn.Start(context.Background())) defer conn.Close() require.NoError(t, conn.HandleData(tls.QUICEncryptionLevelInitial, clientHello)) } func getClientHello(t testing.TB, serverName string) []byte { t.Helper() c := tls.QUICClient(&tls.QUICConfig{ TLSConfig: &tls.Config{ ServerName: serverName, MinVersion: tls.VersionTLS13, InsecureSkipVerify: serverName == "", // disable post-quantum curves CurvePreferences: []tls.CurveID{tls.CurveP256}, }, }) b := make([]byte, mrand.IntN(200)) rand.Read(b) c.SetTransportParameters(b) require.NoError(t, c.Start(context.Background())) ev := c.NextEvent() require.Equal(t, tls.QUICWriteData, ev.Kind) checkClientHello(t, ev.Data) return ev.Data } func TestFindSNI(t *testing.T) { t.Run("without SNI", func(t *testing.T) { testFindSNI(t, "") }) t.Run("without subdomain", func(t *testing.T) { testFindSNI(t, "quic-go.net") }) t.Run("with subdomain", func(t *testing.T) { testFindSNI(t, "sub.do.ma.in.quic-go.net") }) } func testFindSNI(t *testing.T, serverName string) { clientHello := getClientHello(t, serverName) sniPos, sniLen, echPos, err := findSNIAndECH(clientHello) require.NoError(t, err) assert.Equal(t, -1, echPos) if serverName == "" { require.Equal(t, -1, sniPos) return } assert.Equal(t, len(serverName), sniLen) require.NotEqual(t, -1, sniPos) require.Equal(t, serverName, string(clientHello[sniPos:sniPos+sniLen])) // incomplete ClientHellos result in an io.ErrUnexpectedEOF for i := range clientHello { _, _, _, err := findSNIAndECH(clientHello[:i]) require.ErrorIs(t, err, io.ErrUnexpectedEOF) } } quic-go-0.59.0/stateless_reset.go000066400000000000000000000017051513066070600167200ustar00rootroot00000000000000package quic import ( "crypto/hmac" "crypto/rand" "crypto/sha256" "hash" "sync" "github.com/quic-go/quic-go/internal/protocol" ) type statelessResetter struct { mx 
sync.Mutex h hash.Hash } // newStatelessRetter creates a new stateless reset generator. // It is valid to use a nil key. In that case, a random key will be used. // This makes is impossible for on-path attackers to shut down established connections. func newStatelessResetter(key *StatelessResetKey) *statelessResetter { var h hash.Hash if key != nil { h = hmac.New(sha256.New, key[:]) } else { b := make([]byte, 32) _, _ = rand.Read(b) h = hmac.New(sha256.New, b) } return &statelessResetter{h: h} } func (r *statelessResetter) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken { r.mx.Lock() defer r.mx.Unlock() var token protocol.StatelessResetToken r.h.Write(connID.Bytes()) copy(token[:], r.h.Sum(nil)) r.h.Reset() return token } quic-go-0.59.0/stateless_reset_test.go000066400000000000000000000020641513066070600177560ustar00rootroot00000000000000package quic import ( "crypto/rand" "testing" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func TestStatelessResetter(t *testing.T) { t.Run("no key", func(t *testing.T) { r1 := newStatelessResetter(nil) r2 := newStatelessResetter(nil) for i := 0; i < 100; i++ { b := make([]byte, 15) rand.Read(b) connID := protocol.ParseConnectionID(b) t1 := r1.GetStatelessResetToken(connID) t2 := r2.GetStatelessResetToken(connID) require.NotZero(t, t1) require.NotZero(t, t2) require.NotEqual(t, t1, t2) } }) t.Run("with key", func(t *testing.T) { var key StatelessResetKey rand.Read(key[:]) m := newStatelessResetter(&key) b := make([]byte, 8) rand.Read(b) connID := protocol.ParseConnectionID(b) token := m.GetStatelessResetToken(connID) require.NotZero(t, token) require.Equal(t, token, m.GetStatelessResetToken(connID)) // generate a new connection ID rand.Read(b) connID2 := protocol.ParseConnectionID(b) require.NotEqual(t, token, m.GetStatelessResetToken(connID2)) }) } 
quic-go-0.59.0/stream.go000066400000000000000000000173751513066070600150140ustar00rootroot00000000000000package quic

import (
	"context"
	"net"
	"os"
	"sync"
	"time"

	"github.com/quic-go/quic-go/internal/ackhandler"
	"github.com/quic-go/quic-go/internal/flowcontrol"
	"github.com/quic-go/quic-go/internal/monotime"
	"github.com/quic-go/quic-go/internal/protocol"
	"github.com/quic-go/quic-go/internal/wire"
)

// deadlineError is the net.Error used to signal an exceeded deadline.
// It reports itself as a timeout and unwraps to os.ErrDeadlineExceeded,
// so errors.Is(err, os.ErrDeadlineExceeded) holds for it.
type deadlineError struct{}

func (deadlineError) Error() string   { return "deadline exceeded" }
func (deadlineError) Temporary() bool { return true }
func (deadlineError) Timeout() bool   { return true }
func (deadlineError) Unwrap() error   { return os.ErrDeadlineExceeded }

// errDeadline is the shared deadlineError value.
var errDeadline net.Error = &deadlineError{}

// The streamSender is notified by the stream about various events.
type streamSender interface {
	onHasConnectionData()
	onHasStreamData(protocol.StreamID, *SendStream)
	onHasStreamControlFrame(protocol.StreamID, streamControlFrameGetter)
	// must be called without holding the mutex that is acquired by closeForShutdown
	onStreamCompleted(protocol.StreamID)
}

// Each of the two stream halves gets its own uniStreamSender.
// This is necessary in order to keep track of when both halves have been completed.
type uniStreamSender struct { streamSender onStreamCompletedImpl func() onHasStreamControlFrameImpl func(protocol.StreamID, streamControlFrameGetter) } func (s *uniStreamSender) onHasStreamData(id protocol.StreamID, str *SendStream) { s.streamSender.onHasStreamData(id, str) } func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) { s.onStreamCompletedImpl() } func (s *uniStreamSender) onHasStreamControlFrame(id protocol.StreamID, str streamControlFrameGetter) { s.onHasStreamControlFrameImpl(id, str) } var _ streamSender = &uniStreamSender{} type Stream struct { receiveStr *ReceiveStream sendStr *SendStream completedMutex sync.Mutex sender streamSender receiveStreamCompleted bool sendStreamCompleted bool } var ( _ outgoingStream = &Stream{} _ sendStreamFrameHandler = &Stream{} _ receiveStreamFrameHandler = &Stream{} ) // newStream creates a new Stream func newStream( ctx context.Context, streamID protocol.StreamID, sender streamSender, flowController flowcontrol.StreamFlowController, supportsResetStreamAt bool, ) *Stream { s := &Stream{sender: sender} senderForSendStream := &uniStreamSender{ streamSender: sender, onStreamCompletedImpl: func() { s.completedMutex.Lock() s.sendStreamCompleted = true s.checkIfCompleted() s.completedMutex.Unlock() }, onHasStreamControlFrameImpl: func(id protocol.StreamID, str streamControlFrameGetter) { sender.onHasStreamControlFrame(streamID, s) }, } s.sendStr = newSendStream(ctx, streamID, senderForSendStream, flowController, supportsResetStreamAt) senderForReceiveStream := &uniStreamSender{ streamSender: sender, onStreamCompletedImpl: func() { s.completedMutex.Lock() s.receiveStreamCompleted = true s.checkIfCompleted() s.completedMutex.Unlock() }, onHasStreamControlFrameImpl: func(id protocol.StreamID, str streamControlFrameGetter) { sender.onHasStreamControlFrame(streamID, s) }, } s.receiveStr = newReceiveStream(streamID, senderForReceiveStream, flowController) return s } // StreamID returns the stream ID. 
func (s *Stream) StreamID() protocol.StreamID {
	// the result is same for receiveStream and sendStream
	return s.sendStr.StreamID()
}

// Read reads data from the stream.
// Read can be made to time out using [Stream.SetReadDeadline] and [Stream.SetDeadline].
// If the stream was canceled, the error is a [StreamError].
func (s *Stream) Read(p []byte) (int, error) {
	return s.receiveStr.Read(p)
}

// Peek fills b with stream data, without consuming the stream data.
// It blocks until len(b) bytes are available, or an error occurs.
// It respects the stream deadline set by [Stream.SetReadDeadline].
// If the stream ends before len(b) bytes are available,
// it returns the number of bytes peeked along with io.EOF.
func (s *Stream) Peek(b []byte) (int, error) {
	return s.receiveStr.Peek(b)
}

// Write writes data to the stream.
// Write can be made to time out using [Stream.SetWriteDeadline] or [Stream.SetDeadline].
// If the stream was canceled, the error is a [StreamError].
func (s *Stream) Write(p []byte) (int, error) {
	return s.sendStr.Write(p)
}

// SetReliableBoundary marks the data written to this stream so far as reliable.
// It is valid to call this function multiple times, thereby increasing the reliable size.
// It only has an effect if the peer enabled support for the RESET_STREAM_AT extension,
// otherwise, it is a no-op.
func (s *Stream) SetReliableBoundary() {
	s.sendStr.SetReliableBoundary()
}

// CancelWrite aborts sending on this stream.
// See [SendStream.CancelWrite] for more details.
func (s *Stream) CancelWrite(errorCode StreamErrorCode) {
	s.sendStr.CancelWrite(errorCode)
}

// CancelRead aborts receiving on this stream.
// See [ReceiveStream.CancelRead] for more details.
func (s *Stream) CancelRead(errorCode StreamErrorCode) {
	s.receiveStr.CancelRead(errorCode)
}

// Context returns a context that is canceled as soon as the write-side of the stream is closed.
// See [SendStream.Context] for more details.
func (s *Stream) Context() context.Context {
	return s.sendStr.Context()
}

// Close closes the send-direction of the stream.
// It does not close the receive-direction of the stream.
func (s *Stream) Close() error {
	return s.sendStr.Close()
}

// handleResetStreamFrame passes a RESET_STREAM frame to the receive side.
func (s *Stream) handleResetStreamFrame(frame *wire.ResetStreamFrame, rcvTime monotime.Time) error {
	return s.receiveStr.handleResetStreamFrame(frame, rcvTime)
}

// handleStreamFrame passes a STREAM frame to the receive side.
func (s *Stream) handleStreamFrame(frame *wire.StreamFrame, rcvTime monotime.Time) error {
	return s.receiveStr.handleStreamFrame(frame, rcvTime)
}

// handleStopSendingFrame passes a STOP_SENDING frame to the send side.
func (s *Stream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
	s.sendStr.handleStopSendingFrame(frame)
}

// updateSendWindow passes a new send window limit to the send side.
func (s *Stream) updateSendWindow(limit protocol.ByteCount) {
	s.sendStr.updateSendWindow(limit)
}

func (s *Stream) enableResetStreamAt() {
	s.sendStr.enableResetStreamAt()
}

func (s *Stream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ ackhandler.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMore bool) {
	return s.sendStr.popStreamFrame(maxBytes, v)
}

// getControlFrame returns the next control frame, asking the send side first.
// If the send side has a frame, hasMore is reported as true without consulting
// the receive side. NOTE(review): presumably so that the caller polls again and
// then reaches any frames queued by the receive side.
func (s *Stream) getControlFrame(now monotime.Time) (_ ackhandler.Frame, ok, hasMore bool) {
	f, ok, _ := s.sendStr.getControlFrame(now)
	if ok {
		return f, true, true
	}
	return s.receiveStr.getControlFrame(now)
}

// SetReadDeadline sets the deadline for future Read calls.
// See [ReceiveStream.SetReadDeadline] for more details.
func (s *Stream) SetReadDeadline(t time.Time) error {
	return s.receiveStr.SetReadDeadline(t)
}

// SetWriteDeadline sets the deadline for future Write calls.
// See [SendStream.SetWriteDeadline] for more details.
func (s *Stream) SetWriteDeadline(t time.Time) error {
	return s.sendStr.SetWriteDeadline(t)
}

// SetDeadline sets the read and write deadlines associated with the stream.
// It is equivalent to calling both SetReadDeadline and SetWriteDeadline.
func (s *Stream) SetDeadline(t time.Time) error {
	_ = s.receiveStr.SetReadDeadline(t) // SetReadDeadline never errors
	_ = s.sendStr.SetWriteDeadline(t)   // SetWriteDeadline never errors
	return nil
}

// closeForShutdown closes a stream abruptly.
// It makes Read and Write unblock (and return the error) immediately.
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
func (s *Stream) closeForShutdown(err error) {
	s.sendStr.closeForShutdown(err)
	s.receiveStr.closeForShutdown(err)
}

// checkIfCompleted is called from the uniStreamSender, when one of the stream halves is completed.
// It makes sure that the onStreamCompleted callback is only called if both receive and send side have completed.
// newStream's completion callbacks call it while holding s.completedMutex.
func (s *Stream) checkIfCompleted() {
	if s.sendStreamCompleted && s.receiveStreamCompleted {
		s.sender.onStreamCompleted(s.StreamID())
	}
}

quic-go-0.59.0/stream_test.go000066400000000000000000000067161513066070600160500ustar00rootroot00000000000000package quic

import (
	"context"
	"io"
	"os"
	"testing"
	"time"

	"github.com/quic-go/quic-go/internal/mocks"
	"github.com/quic-go/quic-go/internal/monotime"
	"github.com/quic-go/quic-go/internal/protocol"
	"github.com/quic-go/quic-go/internal/wire"

	"github.com/stretchr/testify/require"
	"go.uber.org/mock/gomock"
)

func TestStreamDeadlines(t *testing.T) {
	const streamID protocol.StreamID = 1337
	mockCtrl := gomock.NewController(t)
	mockSender := NewMockStreamSender(mockCtrl)
	mockFC := mocks.NewMockStreamFlowController(mockCtrl)
	str := newStream(context.Background(), streamID, mockSender, mockFC, false)

	// SetDeadline sets both read and write deadlines
	str.SetDeadline(time.Now().Add(-time.Second))
	n, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar"))
	require.ErrorIs(t, err, os.ErrDeadlineExceeded)
	require.Zero(t, n)

	// data is available, but the (already expired) read deadline still applies
	mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any()).AnyTimes()
	require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, monotime.Now()))
	n, err = (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(make([]byte, 6))
	require.ErrorIs(t, err, os.ErrDeadlineExceeded)
	require.Zero(t, n)
}

func TestStreamCompletion(t *testing.T) {
	// completeReadSide delivers and reads a FIN-terminated STREAM frame.
	completeReadSide := func(
		t *testing.T,
		str *Stream,
		mockCtrl *gomock.Controller,
		mockFC *mocks.MockStreamFlowController,
	) {
		t.Helper()
		mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any())
		mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6))
		require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{
			StreamID: str.StreamID(),
			Data:     []byte("foobar"),
			Fin:      true,
		}, monotime.Now()))
		_, err := (&readerWithTimeout{Reader: str, Timeout: time.Second}).Read(make([]byte, 6))
		require.ErrorIs(t, err, io.EOF)
		require.True(t, mockCtrl.Satisfied())
	}

	// completeWriteSide writes and closes the stream, pops the frame carrying the FIN and acknowledges it.
	completeWriteSide := func(
		t *testing.T,
		str *Stream,
		mockCtrl *gomock.Controller,
		mockFC *mocks.MockStreamFlowController,
		mockSender *MockStreamSender,
	) {
		t.Helper()
		mockSender.EXPECT().onHasStreamData(str.StreamID(), gomock.Any()).Times(2)
		_, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar"))
		require.NoError(t, err)
		require.NoError(t, str.Close())
		mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount)
		mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6))
		f, _, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1)
		require.NotNil(t, f.Frame)
		require.True(t, f.Frame.Fin)
		f.Handler.OnAcked(f.Frame)
		require.True(t, mockCtrl.Satisfied())
	}

	const streamID protocol.StreamID = 1337

	// onStreamCompleted is only expected after the first half has completed
	t.Run("first read, then write", func(t *testing.T) {
		mockCtrl := gomock.NewController(t)
		mockSender := NewMockStreamSender(mockCtrl)
		mockFC := mocks.NewMockStreamFlowController(mockCtrl)
		str := newStream(context.Background(), streamID, mockSender, mockFC, false)
		completeReadSide(t, str, mockCtrl, mockFC)
		mockSender.EXPECT().onStreamCompleted(streamID)
		completeWriteSide(t, str, mockCtrl, mockFC, mockSender)
	})

	t.Run("first write, then read", func(t *testing.T) {
		mockCtrl := gomock.NewController(t)
		mockSender := NewMockStreamSender(mockCtrl)
		mockFC := mocks.NewMockStreamFlowController(mockCtrl)
		str := newStream(context.Background(), streamID, mockSender, mockFC, false)
		completeWriteSide(t, str, mockCtrl, mockFC, mockSender)
		mockSender.EXPECT().onStreamCompleted(streamID)
		completeReadSide(t, str, mockCtrl, mockFC)
	})
}

quic-go-0.59.0/streams_map.go000066400000000000000000000237771513066070600160370ustar00rootroot00000000000000package quic

import (
	"context"
	"fmt"
	"sync"

	"github.com/quic-go/quic-go/internal/flowcontrol"
	"github.com/quic-go/quic-go/internal/monotime"
	"github.com/quic-go/quic-go/internal/protocol"
	"github.com/quic-go/quic-go/internal/qerr"
	"github.com/quic-go/quic-go/internal/wire"
)

// StreamLimitReachedError is returned from Conn.OpenStream and Conn.OpenUniStream
// when it is not possible to open a new stream because the number of open streams reached
// the peer's stream limit.
type StreamLimitReachedError struct{}

func (e StreamLimitReachedError) Error() string { return "too many open streams" }

// streamsMap holds the four maps of a connection's streams:
// outgoing and incoming, bidirectional and unidirectional.
type streamsMap struct {
	ctx         context.Context // not used for cancellations, but carries the values associated with the connection
	perspective protocol.Perspective

	maxIncomingBidiStreams uint64
	maxIncomingUniStreams  uint64

	sender            streamSender
	queueControlFrame func(wire.Frame)
	newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController

	mutex               sync.Mutex
	outgoingBidiStreams *outgoingStreamsMap[*Stream]
	outgoingUniStreams  *outgoingStreamsMap[*SendStream]
	incomingBidiStreams *incomingStreamsMap[*Stream]
	incomingUniStreams  *incomingStreamsMap[*ReceiveStream]
	reset               bool

	supportsResetStreamAt bool
}

func newStreamsMap(
	ctx context.Context,
	sender streamSender,
	queueControlFrame func(wire.Frame),
	newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
	maxIncomingBidiStreams uint64,
	maxIncomingUniStreams uint64,
	perspective protocol.Perspective,
) *streamsMap {
	m := &streamsMap{
		ctx:                    ctx,
		perspective:            perspective,
		queueControlFrame:      queueControlFrame,
		newFlowController:      newFlowController,
		maxIncomingBidiStreams: maxIncomingBidiStreams,
		maxIncomingUniStreams:  maxIncomingUniStreams,
		sender:                 sender,
	}
	m.initMaps()
	return m
}

// initMaps creates the four stream maps.
// The stream constructors read m.supportsResetStreamAt (and the other fields of m)
// when a stream is created, not when initMaps runs.
func (m *streamsMap) initMaps() {
	m.outgoingBidiStreams = newOutgoingStreamsMap(
		protocol.StreamTypeBidi,
		func(id protocol.StreamID) *Stream {
			return newStream(m.ctx, id, m.sender, m.newFlowController(id), m.supportsResetStreamAt)
		},
		m.queueControlFrame,
		m.perspective,
	)
	m.incomingBidiStreams = newIncomingStreamsMap(
		protocol.StreamTypeBidi,
		func(id protocol.StreamID) *Stream {
			return newStream(m.ctx, id, m.sender, m.newFlowController(id), m.supportsResetStreamAt)
		},
		m.maxIncomingBidiStreams,
		m.queueControlFrame,
		m.perspective,
	)
	m.outgoingUniStreams = newOutgoingStreamsMap(
		protocol.StreamTypeUni,
		func(id protocol.StreamID) *SendStream {
			return newSendStream(m.ctx, id, m.sender, m.newFlowController(id), m.supportsResetStreamAt)
		},
		m.queueControlFrame,
		m.perspective,
	)
	m.incomingUniStreams = newIncomingStreamsMap(
		protocol.StreamTypeUni,
		func(id protocol.StreamID) *ReceiveStream {
			return newReceiveStream(id, m.sender, m.newFlowController(id))
		},
		m.maxIncomingUniStreams,
		m.queueControlFrame,
		m.perspective,
	)
}

// OpenStream opens a new outgoing bidirectional stream.
// It returns Err0RTTRejected if the map has been reset.
// The mutex only guards reading reset and the current map; the stream is opened outside of it.
func (m *streamsMap) OpenStream() (*Stream, error) {
	m.mutex.Lock()
	reset := m.reset
	mm := m.outgoingBidiStreams
	m.mutex.Unlock()
	if reset {
		return nil, Err0RTTRejected
	}
	return mm.OpenStream()
}

// OpenStreamSync is like OpenStream, but delegates to the map's OpenStreamSync with ctx.
func (m *streamsMap) OpenStreamSync(ctx context.Context) (*Stream, error) {
	m.mutex.Lock()
	reset := m.reset
	mm := m.outgoingBidiStreams
	m.mutex.Unlock()
	if reset {
		return nil, Err0RTTRejected
	}
	return mm.OpenStreamSync(ctx)
}

// OpenUniStream opens a new outgoing unidirectional stream.
// It returns Err0RTTRejected if the map has been reset.
func (m *streamsMap) OpenUniStream() (*SendStream, error) {
	m.mutex.Lock()
	reset := m.reset
	mm := m.outgoingUniStreams
	m.mutex.Unlock()
	if reset {
		return nil, Err0RTTRejected
	}
	return mm.OpenStream()
}

func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (*SendStream, error) {
m.mutex.Lock() reset := m.reset mm := m.outgoingUniStreams m.mutex.Unlock() if reset { return nil, Err0RTTRejected } return mm.OpenStreamSync(ctx) } func (m *streamsMap) AcceptStream(ctx context.Context) (*Stream, error) { m.mutex.Lock() reset := m.reset mm := m.incomingBidiStreams m.mutex.Unlock() if reset { return nil, Err0RTTRejected } return mm.AcceptStream(ctx) } func (m *streamsMap) AcceptUniStream(ctx context.Context) (*ReceiveStream, error) { m.mutex.Lock() reset := m.reset mm := m.incomingUniStreams m.mutex.Unlock() if reset { return nil, Err0RTTRejected } return mm.AcceptStream(ctx) } func (m *streamsMap) DeleteStream(id protocol.StreamID) error { switch id.Type() { case protocol.StreamTypeUni: if id.InitiatedBy() == m.perspective { return m.outgoingUniStreams.DeleteStream(id) } return m.incomingUniStreams.DeleteStream(id) case protocol.StreamTypeBidi: if id.InitiatedBy() == m.perspective { return m.outgoingBidiStreams.DeleteStream(id) } return m.incomingBidiStreams.DeleteStream(id) } panic("") } func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) { switch f.Type { case protocol.StreamTypeUni: m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum.StreamID(protocol.StreamTypeUni, m.perspective)) case protocol.StreamTypeBidi: m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum.StreamID(protocol.StreamTypeBidi, m.perspective)) } } type sendStreamFrameHandler interface { updateSendWindow(protocol.ByteCount) handleStopSendingFrame(*wire.StopSendingFrame) } func (m *streamsMap) getSendStream(id protocol.StreamID) (sendStreamFrameHandler, error) { switch id.Type() { case protocol.StreamTypeUni: if id.InitiatedBy() != m.perspective { // an outgoing unidirectional stream is a send stream, not a receive stream return nil, &qerr.TransportError{ ErrorCode: qerr.StreamStateError, ErrorMessage: fmt.Sprintf("invalid frame for send stream %d", id), } } str, err := m.outgoingUniStreams.GetStream(id) if str == nil || err != nil { return nil, err } return str, nil 
case protocol.StreamTypeBidi: if id.InitiatedBy() == m.perspective { str, err := m.outgoingBidiStreams.GetStream(id) if str == nil || err != nil { return nil, err } return str, nil } str, err := m.incomingBidiStreams.GetOrOpenStream(id) if str == nil || err != nil { return nil, err } return str, nil } panic("unreachable") } func (m *streamsMap) HandleMaxStreamDataFrame(f *wire.MaxStreamDataFrame) error { str, err := m.getSendStream(f.StreamID) if err != nil { return err } if str == nil { // stream already deleted return nil } str.updateSendWindow(f.MaximumStreamData) return nil } func (m *streamsMap) HandleStopSendingFrame(f *wire.StopSendingFrame) error { str, err := m.getSendStream(f.StreamID) if err != nil { return err } if str == nil { // stream already deleted return nil } str.handleStopSendingFrame(f) return nil } type receiveStreamFrameHandler interface { handleResetStreamFrame(*wire.ResetStreamFrame, monotime.Time) error handleStreamFrame(*wire.StreamFrame, monotime.Time) error } func (m *streamsMap) getReceiveStream(id protocol.StreamID) (receiveStreamFrameHandler, error) { switch id.Type() { case protocol.StreamTypeUni: // an outgoing unidirectional stream is a send stream, not a receive stream if id.InitiatedBy() == m.perspective { return nil, &qerr.TransportError{ ErrorCode: qerr.StreamStateError, ErrorMessage: fmt.Sprintf("invalid frame for receive stream %d", id), } } str, err := m.incomingUniStreams.GetOrOpenStream(id) if err != nil || str == nil { return nil, err } return str, nil case protocol.StreamTypeBidi: var str *Stream var err error if id.InitiatedBy() == m.perspective { str, err = m.outgoingBidiStreams.GetStream(id) } else { str, err = m.incomingBidiStreams.GetOrOpenStream(id) } if str == nil || err != nil { return nil, err } return str, nil } panic("unreachable") } func (m *streamsMap) HandleStreamDataBlockedFrame(f *wire.StreamDataBlockedFrame) error { if _, err := m.getReceiveStream(f.StreamID); err != nil { return err } // We don't need 
to do anything in response to a STREAM_DATA_BLOCKED frame, // but we need to make sure that the stream ID is valid. return nil // we don't need to do anything in response to a STREAM_DATA_BLOCKED frame } func (m *streamsMap) HandleResetStreamFrame(f *wire.ResetStreamFrame, rcvTime monotime.Time) error { str, err := m.getReceiveStream(f.StreamID) if err != nil { return err } if str == nil { // stream already deleted return nil } return str.handleResetStreamFrame(f, rcvTime) } func (m *streamsMap) HandleStreamFrame(f *wire.StreamFrame, rcvTime monotime.Time) error { str, err := m.getReceiveStream(f.StreamID) if err != nil { return err } if str == nil { // stream already deleted return nil } return str.handleStreamFrame(f, rcvTime) } func (m *streamsMap) HandleTransportParameters(p *wire.TransportParameters) { m.supportsResetStreamAt = p.EnableResetStreamAt m.outgoingBidiStreams.EnableResetStreamAt() m.outgoingUniStreams.EnableResetStreamAt() m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote) m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum.StreamID(protocol.StreamTypeBidi, m.perspective)) m.outgoingUniStreams.UpdateSendWindow(p.InitialMaxStreamDataUni) m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum.StreamID(protocol.StreamTypeUni, m.perspective)) } func (m *streamsMap) CloseWithError(err error) { m.outgoingBidiStreams.CloseWithError(err) m.outgoingUniStreams.CloseWithError(err) m.incomingBidiStreams.CloseWithError(err) m.incomingUniStreams.CloseWithError(err) } // ResetFor0RTT resets is used when 0-RTT is rejected. In that case, the streams maps are // 1. closed with an Err0RTTRejected, making calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream return that error. // 2. reset to their initial state, such that we can immediately process new incoming stream data. // Afterwards, calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream will continue to return the error, // until UseResetMaps() has been called. 
func (m *streamsMap) ResetFor0RTT() { m.mutex.Lock() defer m.mutex.Unlock() m.reset = true m.CloseWithError(Err0RTTRejected) m.initMaps() } func (m *streamsMap) UseResetMaps() { m.mutex.Lock() m.reset = false m.mutex.Unlock() } quic-go-0.59.0/streams_map_incoming.go000066400000000000000000000144161513066070600177100ustar00rootroot00000000000000package quic import ( "context" "fmt" "sync" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/wire" ) type incomingStream interface { closeForShutdown(error) } // When a stream is deleted before it was accepted, we can't delete it from the map immediately. // We need to wait until the application accepts it, and delete it then. type incomingStreamEntry[T incomingStream] struct { stream T shouldDelete bool } type incomingStreamsMap[T incomingStream] struct { mutex sync.RWMutex newStreamChan chan struct{} streamType protocol.StreamType streams map[protocol.StreamID]incomingStreamEntry[T] nextStreamToAccept protocol.StreamID // the next stream that will be returned by AcceptStream() nextStreamToOpen protocol.StreamID // the highest stream that the peer opened maxStream protocol.StreamID // the highest stream that the peer is allowed to open maxNumStreams uint64 // maximum number of streams newStream func(protocol.StreamID) T queueMaxStreamID func(*wire.MaxStreamsFrame) closeErr error } func newIncomingStreamsMap[T incomingStream]( streamType protocol.StreamType, newStream func(protocol.StreamID) T, maxStreams uint64, queueControlFrame func(wire.Frame), pers protocol.Perspective, ) *incomingStreamsMap[T] { var nextStreamToAccept protocol.StreamID switch { case streamType == protocol.StreamTypeBidi && pers == protocol.PerspectiveServer: nextStreamToAccept = protocol.FirstIncomingBidiStreamServer case streamType == protocol.StreamTypeBidi && pers == protocol.PerspectiveClient: nextStreamToAccept = protocol.FirstIncomingBidiStreamClient case streamType == 
protocol.StreamTypeUni && pers == protocol.PerspectiveServer: nextStreamToAccept = protocol.FirstIncomingUniStreamServer case streamType == protocol.StreamTypeUni && pers == protocol.PerspectiveClient: nextStreamToAccept = protocol.FirstIncomingUniStreamClient } return &incomingStreamsMap[T]{ newStreamChan: make(chan struct{}, 1), streamType: streamType, streams: make(map[protocol.StreamID]incomingStreamEntry[T]), maxStream: protocol.StreamNum(maxStreams).StreamID(streamType, pers.Opposite()), maxNumStreams: maxStreams, newStream: newStream, nextStreamToOpen: nextStreamToAccept, nextStreamToAccept: nextStreamToAccept, queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, } } func (m *incomingStreamsMap[T]) AcceptStream(ctx context.Context) (T, error) { // drain the newStreamChan, so we don't check the map twice if the stream doesn't exist select { case <-m.newStreamChan: default: } m.mutex.Lock() var id protocol.StreamID var entry incomingStreamEntry[T] for { id = m.nextStreamToAccept if m.closeErr != nil { m.mutex.Unlock() return *new(T), m.closeErr } var ok bool entry, ok = m.streams[id] if ok { break } m.mutex.Unlock() select { case <-ctx.Done(): return *new(T), ctx.Err() case <-m.newStreamChan: } m.mutex.Lock() } m.nextStreamToAccept += 4 // If this stream was completed before being accepted, we can delete it now. 
if entry.shouldDelete { if err := m.deleteStream(id); err != nil { m.mutex.Unlock() return *new(T), err } } m.mutex.Unlock() return entry.stream, nil } func (m *incomingStreamsMap[T]) GetOrOpenStream(id protocol.StreamID) (T, error) { m.mutex.RLock() if id > m.maxStream { m.mutex.RUnlock() return *new(T), &qerr.TransportError{ ErrorCode: qerr.StreamLimitError, ErrorMessage: fmt.Sprintf("peer tried to open stream %d (current limit: %d)", id, m.maxStream), } } // if the num is smaller than the highest we accepted // * this stream exists in the map, and we can return it, or // * this stream was already closed, then we can return the nil if id < m.nextStreamToOpen { var s T // If the stream was already queued for deletion, and is just waiting to be accepted, don't return it. if entry, ok := m.streams[id]; ok && !entry.shouldDelete { s = entry.stream } m.mutex.RUnlock() return s, nil } m.mutex.RUnlock() m.mutex.Lock() // no need to check the two error conditions from above again // * maxStream can only increase, so if the id was valid before, it definitely is valid now // * highestStream is only modified by this function for newNum := m.nextStreamToOpen; newNum <= id; newNum += 4 { m.streams[newNum] = incomingStreamEntry[T]{stream: m.newStream(newNum)} select { case m.newStreamChan <- struct{}{}: default: } } m.nextStreamToOpen = id + 4 entry := m.streams[id] m.mutex.Unlock() return entry.stream, nil } func (m *incomingStreamsMap[T]) DeleteStream(id protocol.StreamID) error { m.mutex.Lock() defer m.mutex.Unlock() if err := m.deleteStream(id); err != nil { return &qerr.TransportError{ ErrorCode: qerr.StreamStateError, ErrorMessage: err.Error(), } } return nil } func (m *incomingStreamsMap[T]) deleteStream(id protocol.StreamID) error { if _, ok := m.streams[id]; !ok { return fmt.Errorf("tried to delete unknown incoming stream %d", id) } // Don't delete this stream yet, if it was not yet accepted. 
// Just save it to streamsToDelete map, to make sure it is deleted as soon as it gets accepted. if id >= m.nextStreamToAccept { entry, ok := m.streams[id] if ok && entry.shouldDelete { return fmt.Errorf("tried to delete incoming stream %d multiple times", id) } entry.shouldDelete = true m.streams[id] = entry // can't assign to struct in map, so we need to reassign return nil } delete(m.streams, id) // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream if m.maxNumStreams > uint64(len(m.streams)) { maxStream := m.nextStreamToOpen + 4*protocol.StreamID(m.maxNumStreams-uint64(len(m.streams))-1) // never send a value larger than the maximum value for a stream number if maxStream <= protocol.MaxStreamID { m.maxStream = maxStream m.queueMaxStreamID(&wire.MaxStreamsFrame{ Type: m.streamType, MaxStreamNum: m.maxStream.StreamNum(), }) } } return nil } func (m *incomingStreamsMap[T]) CloseWithError(err error) { m.mutex.Lock() m.closeErr = err for _, entry := range m.streams { entry.stream.closeForShutdown(err) } m.mutex.Unlock() close(m.newStreamChan) } quic-go-0.59.0/streams_map_incoming_test.go000066400000000000000000000241621513066070600207460ustar00rootroot00000000000000package quic import ( "context" "math/rand/v2" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type mockStream struct { id protocol.StreamID closed bool closeErr error sendWindow protocol.ByteCount supportsResetStreamAt bool } func (s *mockStream) closeForShutdown(err error) { s.closed = true s.closeErr = err } func (s *mockStream) updateSendWindow(limit protocol.ByteCount) { s.sendWindow = limit } func (s *mockStream) enableResetStreamAt() { s.supportsResetStreamAt = true } func TestStreamsMapIncomingGettingStreams(t *testing.T) { t.Run("client", func(t 
*testing.T) {
		testStreamsMapIncomingGettingStreams(t, protocol.PerspectiveClient, protocol.FirstIncomingUniStreamClient)
	})
	t.Run("server", func(t *testing.T) {
		testStreamsMapIncomingGettingStreams(t, protocol.PerspectiveServer, protocol.FirstIncomingUniStreamServer)
	})
}

// testStreamsMapIncomingGettingStreams checks that GetOrOpenStream opens all streams up to the
// requested ID, that AcceptStream returns them in order, and that the stream limit is enforced.
func testStreamsMapIncomingGettingStreams(t *testing.T, perspective protocol.Perspective, firstStream protocol.StreamID) {
	var newStreamCounter int
	const maxNumStreams = 10
	m := newIncomingStreamsMap(
		protocol.StreamTypeUni,
		func(id protocol.StreamID) *mockStream {
			newStreamCounter++
			return &mockStream{id: id}
		},
		maxNumStreams,
		func(f wire.Frame) {},
		perspective,
	)

	// all streams up to the id on GetOrOpenStream are opened
	str, err := m.GetOrOpenStream(firstStream + 4)
	require.NoError(t, err)
	require.NotNil(t, str)
	require.Equal(t, 2, newStreamCounter)
	require.Equal(t, firstStream+4, str.id)
	// accept one of the streams
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	str, err = m.AcceptStream(ctx)
	require.NoError(t, err)
	require.Equal(t, firstStream, str.id)
	// open some more streams
	str, err = m.GetOrOpenStream(firstStream + 16)
	require.NoError(t, err)
	require.Equal(t, 5, newStreamCounter)
	require.Equal(t, firstStream+16, str.id)
	// and accept all of them
	for i := 1; i < 5; i++ {
		str, err := m.AcceptStream(ctx)
		require.NoError(t, err)
		require.Equal(t, firstStream+4*protocol.StreamID(i), str.id)
	}

	// the last stream within the limit can be opened, the next one exceeds the limit
	_, err = m.GetOrOpenStream(firstStream + 4*maxNumStreams - 4)
	require.NoError(t, err)
	_, err = m.GetOrOpenStream(firstStream + 4*maxNumStreams)
	require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamLimitError})
	require.ErrorContains(t, err, "peer tried to open stream")
	require.Equal(t, maxNumStreams, newStreamCounter)
}

// TestStreamsMapIncomingAcceptingStreams checks that AcceptStream honors context cancellation
// and blocks until the peer opens a stream.
func TestStreamsMapIncomingAcceptingStreams(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		m := newIncomingStreamsMap(
			protocol.StreamTypeUni,
			func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
			5,
			func(f wire.Frame) {},
			protocol.PerspectiveClient,
		)

		// AcceptStream should respect the context
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		errChan := make(chan error, 1)
		go func() {
			_, err := m.AcceptStream(ctx)
			errChan <- err
		}()

		// synctest.Wait lets the goroutine run until it is durably blocked
		synctest.Wait()
		select {
		case <-errChan:
			t.Fatal("AcceptStream should not return")
		default:
		}

		cancel()
		synctest.Wait()
		select {
		case err := <-errChan:
			require.Equal(t, context.Canceled, err)
		default:
			t.Fatal("timeout")
		}

		// AcceptStream should block if there are no streams available
		go func() {
			_, err := m.AcceptStream(context.Background())
			errChan <- err
		}()

		synctest.Wait()
		select {
		case <-errChan:
			t.Fatal("AcceptStream should block")
		default:
		}

		// opening a stream unblocks the pending AcceptStream call
		_, err := m.GetOrOpenStream(protocol.FirstIncomingUniStreamClient)
		require.NoError(t, err)
		synctest.Wait()
		select {
		case err := <-errChan:
			require.NoError(t, err)
		default:
			t.Fatal("timeout")
		}
	})
}

func TestStreamsMapIncomingDeletingStreams(t *testing.T) {
	t.Run("client", func(t *testing.T) {
		testStreamsMapIncomingDeletingStreams(t, protocol.PerspectiveClient, protocol.FirstIncomingUniStreamClient)
	})
	t.Run("server", func(t *testing.T) {
		testStreamsMapIncomingDeletingStreams(t, protocol.PerspectiveServer, protocol.FirstIncomingUniStreamServer)
	})
}

// testStreamsMapIncomingDeletingStreams checks that deleting a not-yet-accepted stream is deferred
// until it is accepted, and that each actual deletion queues a MAX_STREAMS frame.
func testStreamsMapIncomingDeletingStreams(t *testing.T, perspective protocol.Perspective, firstStream protocol.StreamID) {
	var frameQueue []wire.Frame
	m := newIncomingStreamsMap(
		protocol.StreamTypeUni,
		func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
		5,
		func(f wire.Frame) { frameQueue = append(frameQueue, f) },
		perspective,
	)
	err := m.DeleteStream(firstStream + 1337*4)
	require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
	require.ErrorContains(t, err, "tried to delete unknown incoming stream")

	s, err := m.GetOrOpenStream(firstStream + 4)
	require.NoError(t, err)
	require.NotNil(t, s)
	// delete the stream
	require.NoError(t, m.DeleteStream(firstStream+4))
	require.Empty(t, frameQueue)
	// it's not returned by GetOrOpenStream anymore
	s, err = m.GetOrOpenStream(firstStream + 4)
	require.NoError(t, err)
	require.Nil(t, s)

	// AcceptStream still returns this stream
	str, err := m.AcceptStream(context.Background())
	require.NoError(t, err)
	require.Equal(t, firstStream, str.id)
	require.Empty(t, frameQueue)

	str, err = m.AcceptStream(context.Background())
	require.NoError(t, err)
	require.Equal(t, firstStream+4, str.id)
	// now the stream is deleted and new stream credit is issued
	require.Len(t, frameQueue, 1)
	require.Equal(t, &wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: 6}, frameQueue[0])
	frameQueue = frameQueue[:0]

	require.NoError(t, m.DeleteStream(firstStream))
	require.Len(t, frameQueue, 1)
	require.Equal(t, &wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: 7}, frameQueue[0])
}

// There's a maximum number that can be encoded in a MAX_STREAMS frame.
// Since the stream limit is configurable by the user, we can't rely on this number
// being high enough that it will never be reached in practice.
func TestStreamsMapIncomingDeletingStreamsWithHighLimits(t *testing.T) {
	t.Run("client", func(t *testing.T) {
		testStreamsMapIncomingDeletingStreamsWithHighLimits(t, protocol.PerspectiveClient, protocol.FirstIncomingUniStreamClient)
	})
	t.Run("server", func(t *testing.T) {
		testStreamsMapIncomingDeletingStreamsWithHighLimits(t, protocol.PerspectiveServer, protocol.FirstIncomingUniStreamServer)
	})
}

// testStreamsMapIncomingDeletingStreamsWithHighLimits checks that MAX_STREAMS frames stop
// once the maximum encodable stream count is reached.
func testStreamsMapIncomingDeletingStreamsWithHighLimits(t *testing.T, pers protocol.Perspective, firstStream protocol.StreamID) {
	var frameQueue []wire.Frame
	m := newIncomingStreamsMap(
		protocol.StreamTypeUni,
		func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
		uint64(protocol.MaxStreamCount-2),
		func(f wire.Frame) { frameQueue = append(frameQueue, f) },
		pers,
	)

	// open a bunch of streams
	_, err := m.GetOrOpenStream(firstStream + 16)
	require.NoError(t, err)
	// accept all streams
	for range 5 {
		_, err := m.AcceptStream(context.Background())
		require.NoError(t, err)
	}
	require.Empty(t, frameQueue)

	require.NoError(t, m.DeleteStream(firstStream+12))
	require.Len(t, frameQueue, 1)
	require.Equal(t,
		&wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: protocol.MaxStreamCount - 1},
		frameQueue[0],
	)
	require.NoError(t, m.DeleteStream(firstStream+8))
	require.Len(t, frameQueue, 2)
	require.Equal(t,
		&wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: protocol.MaxStreamCount},
		frameQueue[1],
	)
	// at this point, we can't increase the stream limit any further, so no more MAX_STREAMS frames will be sent
	require.NoError(t, m.DeleteStream(firstStream+4))
	require.NoError(t, m.DeleteStream(firstStream))
	require.Len(t, frameQueue, 2)
}

// TestStreamsMapIncomingClosing checks that CloseWithError closes accepted streams
// and unblocks pending AcceptStream calls with the close error.
func TestStreamsMapIncomingClosing(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		m := newIncomingStreamsMap(
			protocol.StreamTypeUni,
			func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
			5,
			func(f wire.Frame) {},
			protocol.PerspectiveServer,
		)

		var streams []*mockStream
		_, err := m.GetOrOpenStream(protocol.FirstIncomingUniStreamServer + 8)
		require.NoError(t, err)
		for range 3 {
			str, err := m.AcceptStream(context.Background())
			require.NoError(t, err)
			streams = append(streams, str)
		}

		errChan := make(chan error, 1)
		go func() {
			_, err := m.AcceptStream(context.Background())
			errChan <- err
		}()

		m.CloseWithError(assert.AnError)
		synctest.Wait()

		// accepted streams should be closed
		for _, str := range streams {
			require.True(t, str.closed)
			require.ErrorIs(t, str.closeErr, assert.AnError)
		}
		// AcceptStream should return the error
		select {
		case err := <-errChan:
			require.ErrorIs(t, err, assert.AnError)
		default:
			t.Fatal("timeout")
		}
	})
}

// TestStreamsMapIncomingRandomized opens streams in random order while another goroutine
// accepts them, checking that every stream is eventually accepted.
func TestStreamsMapIncomingRandomized(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		const num = 1000
		streamType := []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi}[rand.IntN(2)]
		firstStream := protocol.FirstIncomingUniStreamServer
		if streamType == protocol.StreamTypeBidi {
			firstStream = protocol.FirstIncomingBidiStreamServer
		}
		m := newIncomingStreamsMap(
			streamType,
			func(id protocol.StreamID) *mockStream { return &mockStream{id: id} },
			num,
			func(f wire.Frame) {},
			protocol.PerspectiveServer,
		)

		ids := make([]protocol.StreamID, num)
		for i := range num {
			ids[i] = firstStream + 4*protocol.StreamID(i)
		}
		rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] })

		errChan1 := make(chan error, 1)
		go func() {
			for range num {
				if _, err := m.AcceptStream(context.Background()); err != nil {
					errChan1 <- err
					return
				}
			}
			close(errChan1)
		}()

		errChan2 := make(chan error, 1)
		go func() {
			for i := range num {
				if _, err := m.GetOrOpenStream(ids[i]); err != nil {
					errChan2 <- err
					return
				}
			}
			close(errChan2)
		}()

		synctest.Wait()

		select {
		case err := <-errChan1:
			require.NoError(t, err)
		default:
			t.Fatal("should have accepted all streams")
		}
		select {
		case err := <-errChan2:
			require.NoError(t, err)
		default:
			t.Fatal("should have opened all streams")
		}
	})
}
quic-go-0.59.0/streams_map_outgoing.go000066400000000000000000000144621513066070600177410ustar00rootroot00000000000000package quic import ( "context" "fmt" "slices" "sync" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/wire" ) type outgoingStream interface { updateSendWindow(protocol.ByteCount) enableResetStreamAt() closeForShutdown(error) } type outgoingStreamsMap[T outgoingStream] struct { mutex sync.RWMutex streamType protocol.StreamType streams map[protocol.StreamID]T openQueue []chan struct{} nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync) maxStream protocol.StreamID // the maximum stream ID we're allowed to open blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream newStream func(protocol.StreamID) T queueStreamIDBlocked func(*wire.StreamsBlockedFrame) closeErr error } func newOutgoingStreamsMap[T outgoingStream]( streamType protocol.StreamType, newStream func(protocol.StreamID) T, queueControlFrame func(wire.Frame), pers protocol.Perspective, ) *outgoingStreamsMap[T] { var nextStream protocol.StreamID switch { case streamType == protocol.StreamTypeBidi && pers == protocol.PerspectiveServer: nextStream = protocol.FirstOutgoingBidiStreamServer case streamType == protocol.StreamTypeBidi && pers == protocol.PerspectiveClient: nextStream = protocol.FirstOutgoingBidiStreamClient case streamType == protocol.StreamTypeUni && pers == protocol.PerspectiveServer: nextStream = protocol.FirstOutgoingUniStreamServer case streamType == protocol.StreamTypeUni && pers == protocol.PerspectiveClient: nextStream = protocol.FirstOutgoingUniStreamClient } return &outgoingStreamsMap[T]{ streamType: streamType, streams: make(map[protocol.StreamID]T), maxStream: protocol.InvalidStreamNum, nextStream: nextStream, newStream: newStream, queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) }, } } func (m *outgoingStreamsMap[T]) 
OpenStream() (T, error) { m.mutex.Lock() defer m.mutex.Unlock() if m.closeErr != nil { return *new(T), m.closeErr } // if there are OpenStreamSync calls waiting, return an error here if len(m.openQueue) > 0 || m.nextStream > m.maxStream { m.maybeSendBlockedFrame() return *new(T), &StreamLimitReachedError{} } return m.openStream(), nil } func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) { m.mutex.Lock() defer m.mutex.Unlock() if m.closeErr != nil { return *new(T), m.closeErr } if err := ctx.Err(); err != nil { return *new(T), err } if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { return m.openStream(), nil } waitChan := make(chan struct{}, 1) m.openQueue = append(m.openQueue, waitChan) m.maybeSendBlockedFrame() for { m.mutex.Unlock() select { case <-ctx.Done(): m.mutex.Lock() m.openQueue = slices.DeleteFunc(m.openQueue, func(c chan struct{}) bool { return c == waitChan }) // If we just received a MAX_STREAMS frame, this might have been the next stream // that could be opened. Make sure we unblock the next OpenStreamSync call. m.maybeUnblockOpenSync() return *new(T), ctx.Err() case <-waitChan: } m.mutex.Lock() if m.closeErr != nil { return *new(T), m.closeErr } if m.nextStream > m.maxStream { // no stream available. 
Continue waiting continue } str := m.openStream() m.openQueue = m.openQueue[1:] m.maybeUnblockOpenSync() return str, nil } } func (m *outgoingStreamsMap[T]) openStream() T { s := m.newStream(m.nextStream) m.streams[m.nextStream] = s m.nextStream += 4 return s } // maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset, // if we haven't sent one for this offset yet func (m *outgoingStreamsMap[T]) maybeSendBlockedFrame() { if m.blockedSent { return } var streamLimit protocol.StreamNum if m.maxStream != protocol.InvalidStreamID { streamLimit = m.maxStream.StreamNum() } m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ Type: m.streamType, StreamLimit: streamLimit, }) m.blockedSent = true } func (m *outgoingStreamsMap[T]) GetStream(id protocol.StreamID) (T, error) { m.mutex.RLock() if id >= m.nextStream { m.mutex.RUnlock() return *new(T), &qerr.TransportError{ ErrorCode: qerr.StreamStateError, ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id), } } s := m.streams[id] m.mutex.RUnlock() return s, nil } func (m *outgoingStreamsMap[T]) DeleteStream(id protocol.StreamID) error { m.mutex.Lock() defer m.mutex.Unlock() if _, ok := m.streams[id]; !ok { return &qerr.TransportError{ ErrorCode: qerr.StreamStateError, ErrorMessage: fmt.Sprintf("tried to delete unknown outgoing stream %d", id), } } delete(m.streams, id) return nil } func (m *outgoingStreamsMap[T]) SetMaxStream(id protocol.StreamID) { m.mutex.Lock() defer m.mutex.Unlock() if id <= m.maxStream { return } m.maxStream = id m.blockedSent = false if m.maxStream < m.nextStream-4+4*protocol.StreamID(len(m.openQueue)) { m.maybeSendBlockedFrame() } m.maybeUnblockOpenSync() } // UpdateSendWindow is called when the peer's transport parameters are received. // Only in the case of a 0-RTT handshake will we have open streams at this point. // We might need to update the send window, in case the server increased it. 
func (m *outgoingStreamsMap[T]) UpdateSendWindow(limit protocol.ByteCount) { m.mutex.Lock() for _, str := range m.streams { str.updateSendWindow(limit) } m.mutex.Unlock() } func (m *outgoingStreamsMap[T]) EnableResetStreamAt() { m.mutex.Lock() for _, str := range m.streams { str.enableResetStreamAt() } m.mutex.Unlock() } // unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream func (m *outgoingStreamsMap[T]) maybeUnblockOpenSync() { if len(m.openQueue) == 0 { return } if m.nextStream > m.maxStream { return } // unblockOpenSync is called both from OpenStreamSync and from SetMaxStream. // It's sufficient to only unblock OpenStreamSync once. select { case m.openQueue[0] <- struct{}{}: default: } } func (m *outgoingStreamsMap[T]) CloseWithError(err error) { m.mutex.Lock() defer m.mutex.Unlock() m.closeErr = err for _, str := range m.streams { str.closeForShutdown(err) } for _, c := range m.openQueue { if c != nil { close(c) } } m.openQueue = nil } quic-go-0.59.0/streams_map_outgoing_test.go000066400000000000000000000365071513066070600210040ustar00rootroot00000000000000package quic import ( "context" "errors" "fmt" "math/rand/v2" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestStreamsMapOutgoingOpenAndDelete(t *testing.T) { t.Run("client", func(t *testing.T) { testStreamsMapOutgoingOpenAndDelete(t, protocol.PerspectiveClient, protocol.FirstOutgoingBidiStreamClient) }) t.Run("server", func(t *testing.T) { testStreamsMapOutgoingOpenAndDelete(t, protocol.PerspectiveServer, protocol.FirstOutgoingBidiStreamServer) }) } func testStreamsMapOutgoingOpenAndDelete(t *testing.T, perspective protocol.Perspective, firstStream protocol.StreamID) { m := newOutgoingStreamsMap( protocol.StreamTypeBidi, func(id protocol.StreamID) 
*mockStream { return &mockStream{id: id} }, func(f wire.Frame) {}, perspective, ) m.SetMaxStream(protocol.MaxStreamID) _, err := m.GetStream(firstStream) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) require.ErrorContains(t, err, fmt.Sprintf("peer attempted to open stream %d", firstStream)) str1, err := m.OpenStream() require.NoError(t, err) require.Equal(t, firstStream, str1.id) s, err := m.GetStream(firstStream) require.NoError(t, err) require.Equal(t, s, str1) str2, err := m.OpenStream() require.NoError(t, err) require.Equal(t, firstStream+4, str2.id) // update send window m.UpdateSendWindow(1000) require.Equal(t, protocol.ByteCount(1000), str1.sendWindow) require.Equal(t, protocol.ByteCount(1000), str2.sendWindow) // enable reset stream at m.EnableResetStreamAt() require.True(t, str1.supportsResetStreamAt) require.True(t, str2.supportsResetStreamAt) err = m.DeleteStream(firstStream + 1337*4) require.Error(t, err) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) require.ErrorContains(t, err, "tried to delete unknown outgoing stream") require.NoError(t, m.DeleteStream(firstStream)) // deleting the same stream twice will fail err = m.DeleteStream(firstStream) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) require.ErrorContains(t, err, "tried to delete unknown outgoing stream") // after deleting the stream it's not available anymore str, err := m.GetStream(firstStream) require.NoError(t, err) require.Nil(t, str) } func TestStreamsMapOutgoingLimits(t *testing.T) { t.Run("client", func(t *testing.T) { testStreamsMapOutgoingLimits(t, protocol.PerspectiveClient, protocol.FirstOutgoingUniStreamClient) }) t.Run("server", func(t *testing.T) { testStreamsMapOutgoingLimits(t, protocol.PerspectiveServer, protocol.FirstOutgoingUniStreamServer) }) } func testStreamsMapOutgoingLimits(t *testing.T, perspective protocol.Perspective, firstStream protocol.StreamID) { synctest.Test(t, 
func(t *testing.T) { m := newOutgoingStreamsMap( protocol.StreamTypeUni, func(id protocol.StreamID) *mockStream { return &mockStream{id: id} }, func(f wire.Frame) {}, perspective, ) m.SetMaxStream(firstStream) str, err := m.OpenStream() require.NoError(t, err) require.Equal(t, firstStream, str.id) // We've now reached the limit. OpenStream returns an error _, err = m.OpenStream() require.ErrorIs(t, err, &StreamLimitReachedError{}) // OpenStreamSync with a canceled context will return an error immediately ctx, cancel := context.WithCancel(context.Background()) cancel() _, err = m.OpenStreamSync(ctx) require.ErrorIs(t, err, context.Canceled) // OpenStreamSync blocks until the context is canceled... ctx, cancel = context.WithCancel(context.Background()) errChan := make(chan error, 1) go func() { _, err := m.OpenStreamSync(ctx) errChan <- err }() synctest.Wait() select { case <-errChan: t.Fatal("didn't expect OpenStreamSync to return") default: } // OpenStream still returns an error _, err = m.OpenStream() require.ErrorIs(t, err, &StreamLimitReachedError{}) // cancelling the context unblocks OpenStreamSync cancel() synctest.Wait() select { case err := <-errChan: require.ErrorIs(t, err, context.Canceled) default: t.Fatal("OpenStreamSync did not return after the context was canceled") } // ... 
or until it's possible to open a new stream var openedStream *mockStream go func() { str, err := m.OpenStreamSync(context.Background()) openedStream = str errChan <- err }() m.SetMaxStream(firstStream + 4) synctest.Wait() select { case err := <-errChan: require.NoError(t, err) require.Equal(t, firstStream+4, openedStream.id) default: t.Fatal("OpenStreamSync did not return after the stream limit was increased") } }) } func TestStreamsMapOutgoingConcurrentOpenStreamSync(t *testing.T) { synctest.Test(t, func(t *testing.T) { m := newOutgoingStreamsMap( protocol.StreamTypeUni, func(id protocol.StreamID) *mockStream { return &mockStream{id: id} }, func(f wire.Frame) {}, protocol.PerspectiveClient, ) type result struct { index int stream *mockStream err error } results := make(chan result, 3) for i := range 3 { go func(i int) { str, err := m.OpenStreamSync(context.Background()) results <- result{index: i, stream: str, err: err} }(i) time.Sleep(time.Minute) } m.SetMaxStream(protocol.FirstOutgoingUniStreamClient + 4) synctest.Wait() received := make(map[protocol.StreamID]struct{}) for range 2 { select { case res := <-results: require.NoError(t, res.err) require.Equal(t, protocol.FirstOutgoingUniStreamClient+4*protocol.StreamID(res.index), res.stream.id) received[res.stream.id] = struct{}{} default: t.Fatal("OpenStreamSync did not return after the stream limit was increased") } } require.Contains(t, received, protocol.FirstOutgoingUniStreamClient) require.Contains(t, received, protocol.FirstOutgoingUniStreamClient+4) // the call to stream 3 is still blocked select { case <-results: t.Fatal("expected OpenStreamSync to be blocked") default: } m.SetMaxStream(protocol.FirstOutgoingUniStreamClient + 8) synctest.Wait() select { case res := <-results: require.NoError(t, res.err) require.Equal(t, protocol.FirstOutgoingUniStreamClient+8, res.stream.id) default: t.Fatal("OpenStreamSync did not return after the stream limit was increased") } }) } func TestStreamsMapOutgoingClosing(t 
*testing.T) { synctest.Test(t, func(t *testing.T) { m := newOutgoingStreamsMap( protocol.StreamTypeUni, func(id protocol.StreamID) *mockStream { return &mockStream{id: id} }, func(f wire.Frame) {}, protocol.PerspectiveServer, ) m.SetMaxStream(protocol.FirstOutgoingUniStreamServer + 4) str1, err := m.OpenStream() require.NoError(t, err) str2, err := m.OpenStream() require.NoError(t, err) errChan := make(chan error, 1) go func() { _, err := m.OpenStreamSync(context.Background()) errChan <- err }() m.CloseWithError(assert.AnError) synctest.Wait() // both stream should be closed assert.True(t, str1.closed) assert.Equal(t, assert.AnError, str1.closeErr) assert.True(t, str2.closed) assert.Equal(t, assert.AnError, str2.closeErr) select { case err := <-errChan: require.Error(t, err) default: t.Fatal("OpenStreamSync did not return after the stream was closed") } }) } func TestStreamsMapOutgoingBlockedFrames(t *testing.T) { synctest.Test(t, func(t *testing.T) { var frameQueue []wire.Frame m := newOutgoingStreamsMap( protocol.StreamTypeBidi, func(id protocol.StreamID) *mockStream { return &mockStream{id: id} }, func(f wire.Frame) { frameQueue = append(frameQueue, f) }, protocol.PerspectiveClient, ) m.SetMaxStream(protocol.FirstOutgoingBidiStreamClient + 8) for range 3 { _, err := m.OpenStream() require.NoError(t, err) } require.Empty(t, frameQueue) _, err := m.OpenStream() require.ErrorIs(t, err, &StreamLimitReachedError{}) require.Equal(t, []wire.Frame{ &wire.StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 3}, }, frameQueue) frameQueue = frameQueue[:0] // only a single STREAMS_BLOCKED frame is queued per offset for range 5 { _, err = m.OpenStream() require.ErrorIs(t, err, &StreamLimitReachedError{}) require.Empty(t, frameQueue) } errChan := make(chan error, 3) for range 3 { go func() { _, err := m.OpenStreamSync(context.Background()) errChan <- err }() } synctest.Wait() // allow 2 more streams m.SetMaxStream(protocol.FirstOutgoingBidiStreamClient + 16) 
synctest.Wait() for range 2 { select { case err := <-errChan: require.NoError(t, err) default: t.Fatal("OpenStreamSync did not return after the stream limit was increased") } } require.Equal(t, []wire.Frame{ &wire.StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 5}, }, frameQueue) frameQueue = frameQueue[:0] // now accept the last stream m.SetMaxStream(protocol.FirstOutgoingBidiStreamClient + 20) synctest.Wait() select { case err := <-errChan: require.NoError(t, err) default: t.Fatal("OpenStreamSync did not return after the stream limit was increased") } require.Empty(t, frameQueue) }) } func TestStreamsMapOutgoingRandomizedOpenStreamSync(t *testing.T) { synctest.Test(t, func(t *testing.T) { streamType := []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi}[rand.IntN(2)] firstStream := protocol.FirstOutgoingUniStreamServer if streamType == protocol.StreamTypeBidi { firstStream = protocol.FirstOutgoingBidiStreamServer } const n = 100 frameQueue := make(chan wire.Frame, n) m := newOutgoingStreamsMap( streamType, func(id protocol.StreamID) *mockStream { return &mockStream{id: id} }, func(f wire.Frame) { frameQueue <- f }, protocol.PerspectiveServer, ) type result struct { id protocol.StreamID err error } resultChan := make(chan result, n) for range n { go func() { str, err := m.OpenStreamSync(context.Background()) resultChan <- result{id: str.id, err: err} }() } synctest.Wait() select { case f := <-frameQueue: require.IsType(t, &wire.StreamsBlockedFrame{}, f) require.Zero(t, f.(*wire.StreamsBlockedFrame).StreamLimit) default: t.Fatal("timed out waiting for STREAMS_BLOCKED frame") } limit := firstStream - 4 var limits []protocol.StreamID seen := make(map[protocol.StreamID]struct{}) maxStream := firstStream + 4*(n-1) for limit < maxStream { add := 4 * protocol.StreamID(rand.IntN(n/5)+1) limit += add if limit <= maxStream { limits = append(limits, limit) } t.Logf("setting stream limit to %d", limit) m.SetMaxStream(limit) synctest.Wait() 
loop: for { select { case res := <-resultChan: require.NoError(t, res.err) require.NotContains(t, seen, res.id) require.LessOrEqual(t, res.id, limit) seen[res.id] = struct{}{} if len(seen) == int(limit.StreamNum()) || len(seen) == n { break loop } default: t.Fatalf("timed out waiting for stream to open") } } str, err := m.OpenStream() if limit <= maxStream { require.ErrorIs(t, err, &StreamLimitReachedError{}) } else { require.NoError(t, err) require.Equal(t, maxStream+4, str.id) } } require.Len(t, seen, n) close(frameQueue) var blockedAt []protocol.StreamID for f := range frameQueue { if l := f.(*wire.StreamsBlockedFrame).StreamLimit; l <= n { blockedAt = append(blockedAt, l.StreamID(streamType, protocol.PerspectiveServer)) } } require.Equal(t, limits, blockedAt) }) } func TestStreamsMapOutgoingRandomizedWithCancellation(t *testing.T) { synctest.Test(t, func(t *testing.T) { const n = 100 streamType := []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi}[rand.IntN(2)] firstStream := protocol.FirstOutgoingUniStreamClient if streamType == protocol.StreamTypeBidi { firstStream = protocol.FirstOutgoingBidiStreamClient } frameQueue := make(chan wire.Frame, n) m := newOutgoingStreamsMap( streamType, func(id protocol.StreamID) *mockStream { return &mockStream{id: id} }, func(f wire.Frame) { frameQueue <- f }, protocol.PerspectiveClient, ) type result struct { str *mockStream err error } ctx, cancel := context.WithCancel(context.Background()) resultChan := make(chan result, 10*n) var count int var numCancelled int for count < n { shouldCancel := rand.IntN(n)%5 == 0 if shouldCancel { numCancelled++ } else { count++ } go func() { var str *mockStream var err error if shouldCancel { str, err = m.OpenStreamSync(ctx) } else { str, err = m.OpenStreamSync(context.Background()) } resultChan <- result{str: str, err: err} }() } synctest.Wait() select { case f := <-frameQueue: require.IsType(t, &wire.StreamsBlockedFrame{}, f) require.Zero(t, 
f.(*wire.StreamsBlockedFrame).StreamLimit) default: t.Fatal("timed out waiting for STREAMS_BLOCKED frame") } synctest.Wait() cancel() limit := firstStream - 4 maxStream := firstStream + 4*(n-1) var limits []protocol.StreamID seen := make(map[protocol.StreamID]struct{}) var lastStreamSeen protocol.StreamID var numCancelledSeen int for limit < maxStream { add := 4 * protocol.StreamID(rand.IntN(n/5)+1) limit += add if limit < maxStream { limits = append(limits, limit) } t.Logf("setting stream limit to %d", limit) m.SetMaxStream(limit) for lastStreamSeen < min(maxStream, limit) { select { case res := <-resultChan: if errors.Is(res.err, context.Canceled) { numCancelledSeen++ } else { require.NoError(t, res.err) require.NotContains(t, seen, res.str.id) seen[res.str.id] = struct{}{} lastStreamSeen = res.str.id } case <-time.After(time.Second): t.Fatalf("timed out waiting for stream to open") } } } require.Len(t, seen, n) t.Logf("saw %d streams, %d cancelled", len(seen), numCancelledSeen) require.Equal(t, numCancelled, numCancelledSeen) close(frameQueue) var blockedAt []protocol.StreamID for f := range frameQueue { sbf := f.(*wire.StreamsBlockedFrame) require.Equal(t, streamType, sbf.Type) blockedAt = append(blockedAt, sbf.StreamLimit.StreamID(streamType, protocol.PerspectiveClient)) } require.Equal(t, limits, blockedAt) }) } func TestStreamsMapConcurrent(t *testing.T) { for i := range 5 { t.Run(fmt.Sprintf("iteration %d", i+1), func(t *testing.T) { testStreamsMapConcurrent(t) }) } } func testStreamsMapConcurrent(t *testing.T) { synctest.Test(t, func(t *testing.T) { m := newOutgoingStreamsMap( protocol.StreamTypeBidi, func(id protocol.StreamID) *mockStream { return &mockStream{id: id} }, func(f wire.Frame) {}, protocol.PerspectiveClient, ) const num = 100 ctx, cancel := context.WithCancel(context.Background()) defer cancel() errChan := make(chan error, num) for range num { go func() { _, err := m.OpenStreamSync(ctx) errChan <- err }() } go m.CloseWithError(assert.AnError) 
go cancel() go m.SetMaxStream(protocol.FirstOutgoingBidiStreamClient + 4*num/2) synctest.Wait() for range num { select { case err := <-errChan: if err != nil { require.True(t, errors.Is(err, assert.AnError) || errors.Is(err, context.Canceled)) } default: t.Fatal("OpenStreamSync should have returned") } } }) } quic-go-0.59.0/streams_map_test.go000066400000000000000000000534101513066070600170610ustar00rootroot00000000000000package quic import ( "context" "fmt" "testing" "time" "github.com/quic-go/quic-go/internal/flowcontrol" "github.com/quic-go/quic-go/internal/mocks" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/wire" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) func TestStreamsMapCreatingStreams(t *testing.T) { t.Run("client", func(t *testing.T) { testStreamsMapCreatingStreams(t, protocol.PerspectiveClient, protocol.FirstIncomingBidiStreamClient, protocol.FirstOutgoingBidiStreamClient, protocol.FirstIncomingUniStreamClient, protocol.FirstOutgoingUniStreamClient, ) }) t.Run("server", func(t *testing.T) { testStreamsMapCreatingStreams(t, protocol.PerspectiveServer, protocol.FirstIncomingBidiStreamServer, protocol.FirstOutgoingBidiStreamServer, protocol.FirstIncomingUniStreamServer, protocol.FirstOutgoingUniStreamServer, ) }) } func testStreamsMapCreatingStreams(t *testing.T, perspective protocol.Perspective, firstIncomingBidiStream protocol.StreamID, firstOutgoingBidiStream protocol.StreamID, firstIncomingUniStream protocol.StreamID, firstOutgoingUniStream protocol.StreamID, ) { mockCtrl := gomock.NewController(t) mockSender := NewMockStreamSender(mockCtrl) m := newStreamsMap( context.Background(), mockSender, func(wire.Frame) {}, func(protocol.StreamID) flowcontrol.StreamFlowController { fc := mocks.NewMockStreamFlowController(mockCtrl) 
fc.EXPECT().UpdateHighestReceived(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() return fc }, 1, 1, perspective, ) m.HandleTransportParameters(&wire.TransportParameters{ MaxBidiStreamNum: protocol.MaxStreamCount, MaxUniStreamNum: protocol.MaxStreamCount, }) // opening streams str1, err := m.OpenStream() require.NoError(t, err) str2, err := m.OpenStream() require.NoError(t, err) ustr1, err := m.OpenUniStream() require.NoError(t, err) ustr2, err := m.OpenUniStream() require.NoError(t, err) assert.Equal(t, str1.StreamID(), firstOutgoingBidiStream) assert.Equal(t, str2.StreamID(), firstOutgoingBidiStream+4) assert.Equal(t, ustr1.StreamID(), firstOutgoingUniStream) assert.Equal(t, ustr2.StreamID(), firstOutgoingUniStream+4) // accepting streams is triggered by receiving a frame referencing this stream require.NoError(t, m.HandleStreamFrame(&wire.StreamFrame{StreamID: firstIncomingBidiStream}, monotime.Now())) require.NoError(t, m.HandleStreamFrame(&wire.StreamFrame{StreamID: firstIncomingUniStream}, monotime.Now())) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() str, err := m.AcceptStream(ctx) require.NoError(t, err) ustr, err := m.AcceptUniStream(ctx) require.NoError(t, err) assert.Equal(t, str.StreamID(), firstIncomingBidiStream) assert.Equal(t, ustr.StreamID(), firstIncomingUniStream) } func TestStreamsMapDeletingStreams(t *testing.T) { t.Run("client", func(t *testing.T) { testStreamsMapDeletingStreams(t, protocol.PerspectiveClient, protocol.FirstIncomingBidiStreamClient, protocol.FirstOutgoingBidiStreamClient, protocol.FirstIncomingUniStreamClient, protocol.FirstOutgoingUniStreamClient, ) }) t.Run("server", func(t *testing.T) { testStreamsMapDeletingStreams(t, protocol.PerspectiveServer, protocol.FirstIncomingBidiStreamServer, protocol.FirstOutgoingBidiStreamServer, protocol.FirstIncomingUniStreamServer, protocol.FirstOutgoingUniStreamServer, ) }) } func testStreamsMapDeletingStreams(t *testing.T, perspective 
protocol.Perspective, firstIncomingBidiStream protocol.StreamID, firstOutgoingBidiStream protocol.StreamID, firstIncomingUniStream protocol.StreamID, firstOutgoingUniStream protocol.StreamID, ) { mockCtrl := gomock.NewController(t) mockSender := NewMockStreamSender(mockCtrl) var frameQueue []wire.Frame m := newStreamsMap( context.Background(), mockSender, func(frame wire.Frame) { frameQueue = append(frameQueue, frame) }, func(protocol.StreamID) flowcontrol.StreamFlowController { fc := mocks.NewMockStreamFlowController(mockCtrl) fc.EXPECT().UpdateHighestReceived(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() return fc }, 100, 100, perspective, ) m.HandleTransportParameters(&wire.TransportParameters{ MaxBidiStreamNum: 10, MaxUniStreamNum: 10, }) _, err := m.OpenStream() require.NoError(t, err) require.NoError(t, m.DeleteStream(firstOutgoingBidiStream)) err = m.DeleteStream(firstOutgoingBidiStream + 400) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) require.ErrorContains(t, err, fmt.Sprintf("tried to delete unknown outgoing stream %d", firstOutgoingBidiStream+400)) _, err = m.OpenUniStream() require.NoError(t, err) require.NoError(t, m.DeleteStream(firstOutgoingUniStream)) err = m.DeleteStream(firstOutgoingUniStream + 400) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) require.ErrorContains(t, err, fmt.Sprintf("tried to delete unknown outgoing stream %d", firstOutgoingUniStream+400)) require.Empty(t, frameQueue) // deleting incoming bidirectional streams require.NoError(t, m.HandleStreamFrame(&wire.StreamFrame{StreamID: firstIncomingBidiStream}, monotime.Now())) require.NoError(t, m.DeleteStream(firstIncomingBidiStream)) err = m.DeleteStream(firstIncomingBidiStream + 400) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) require.ErrorContains(t, err, fmt.Sprintf("tried to delete unknown incoming stream %d", firstIncomingBidiStream+400)) // the MAX_STREAMS frame is 
only queued once the stream is accepted require.Empty(t, frameQueue) _, err = m.AcceptStream(context.Background()) require.NoError(t, err) require.Equal(t, frameQueue, []wire.Frame{ &wire.MaxStreamsFrame{ Type: protocol.StreamTypeBidi, MaxStreamNum: 101, }, }) frameQueue = frameQueue[:0] // deleting incoming unidirectional streams require.NoError(t, m.HandleStreamFrame(&wire.StreamFrame{StreamID: firstIncomingUniStream}, monotime.Now())) require.NoError(t, m.DeleteStream(firstIncomingUniStream)) err = m.DeleteStream(firstIncomingUniStream + 400) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) require.ErrorContains(t, err, fmt.Sprintf("tried to delete unknown incoming stream %d", firstIncomingUniStream+400)) // the MAX_STREAMS frame is only queued once the stream is accepted require.Empty(t, frameQueue) _, err = m.AcceptUniStream(context.Background()) require.NoError(t, err) require.Equal(t, frameQueue, []wire.Frame{ &wire.MaxStreamsFrame{ Type: protocol.StreamTypeUni, MaxStreamNum: 101, }, }) frameQueue = frameQueue[:0] } func TestStreamsMapStreamLimits(t *testing.T) { t.Run("client", func(t *testing.T) { testStreamsMapStreamLimits(t, protocol.PerspectiveClient) }) t.Run("server", func(t *testing.T) { testStreamsMapStreamLimits(t, protocol.PerspectiveServer) }) } func testStreamsMapStreamLimits(t *testing.T, perspective protocol.Perspective) { mockCtrl := gomock.NewController(t) mockSender := NewMockStreamSender(mockCtrl) var frameQueue []wire.Frame m := newStreamsMap( context.Background(), mockSender, func(frame wire.Frame) { frameQueue = append(frameQueue, frame) }, func(protocol.StreamID) flowcontrol.StreamFlowController { fc := mocks.NewMockStreamFlowController(mockCtrl) fc.EXPECT().UpdateSendWindow(gomock.Any()).AnyTimes() return fc }, 100, 100, perspective, ) // increase via transport parameters _, err := m.OpenStream() require.ErrorIs(t, err, &StreamLimitReachedError{}) require.ErrorContains(t, err, "too many open streams") 
m.HandleTransportParameters(&wire.TransportParameters{MaxBidiStreamNum: 1}) _, err = m.OpenStream() require.NoError(t, err) _, err = m.OpenStream() require.ErrorIs(t, err, &StreamLimitReachedError{}) _, err = m.OpenUniStream() require.ErrorIs(t, err, &StreamLimitReachedError{}) m.HandleTransportParameters(&wire.TransportParameters{MaxUniStreamNum: 1}) _, err = m.OpenUniStream() require.NoError(t, err) _, err = m.OpenUniStream() require.ErrorIs(t, err, &StreamLimitReachedError{}) // increase via MAX_STREAMS frames m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{ Type: protocol.StreamTypeBidi, MaxStreamNum: 2, }) _, err = m.OpenStream() require.NoError(t, err) _, err = m.OpenStream() require.ErrorIs(t, err, &StreamLimitReachedError{}) m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{ Type: protocol.StreamTypeUni, MaxStreamNum: 2, }) _, err = m.OpenUniStream() require.NoError(t, err) _, err = m.OpenUniStream() require.ErrorIs(t, err, &StreamLimitReachedError{}) // decrease via transport parameters m.HandleTransportParameters(&wire.TransportParameters{MaxBidiStreamNum: 0}) _, err = m.OpenStream() require.ErrorIs(t, err, &StreamLimitReachedError{}) } func TestStreamsMapHandleReceiveStreamFrames(t *testing.T) { for _, pers := range []protocol.Perspective{protocol.PerspectiveClient, protocol.PerspectiveServer} { t.Run(pers.String(), func(t *testing.T) { t.Run("STREAM frame", func(t *testing.T) { testStreamsMapHandleReceiveStreamFrames(t, pers, func(m *streamsMap, id protocol.StreamID) error { return m.HandleStreamFrame(&wire.StreamFrame{StreamID: id}, monotime.Now()) }, ) }) t.Run("STREAM_DATA_BLOCKED frame", func(t *testing.T) { testStreamsMapHandleReceiveStreamFrames(t, pers, func(m *streamsMap, id protocol.StreamID) error { return m.HandleStreamDataBlockedFrame(&wire.StreamDataBlockedFrame{StreamID: id}) }, ) }) t.Run("RESET_STREAM frame", func(t *testing.T) { testStreamsMapHandleReceiveStreamFrames(t, pers, func(m *streamsMap, id protocol.StreamID) error { return 
m.HandleResetStreamFrame(&wire.ResetStreamFrame{StreamID: id}, monotime.Now()) }, ) }) }) } } func testStreamsMapHandleReceiveStreamFrames(t *testing.T, pers protocol.Perspective, handleFrame func(*streamsMap, protocol.StreamID) error) { mockCtrl := gomock.NewController(t) mockSender := NewMockStreamSender(mockCtrl) var streamsCreated []protocol.StreamID m := newStreamsMap( context.Background(), mockSender, func(frame wire.Frame) {}, func(id protocol.StreamID) flowcontrol.StreamFlowController { streamsCreated = append(streamsCreated, id) fc := mocks.NewMockStreamFlowController(mockCtrl) fc.EXPECT().UpdateHighestReceived(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() fc.EXPECT().Abandon().AnyTimes() return fc }, 100, 100, pers, ) m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: protocol.MaxStreamCount}) m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: protocol.MaxStreamCount}) var firstOutgoingUniStream, firstOutgoingBidiStream, firstIncomingUniStream, firstIncomingBidiStream protocol.StreamID if pers == protocol.PerspectiveClient { firstOutgoingBidiStream = protocol.FirstOutgoingBidiStreamClient firstOutgoingUniStream = protocol.FirstOutgoingUniStreamClient firstIncomingUniStream = protocol.FirstIncomingUniStreamClient firstIncomingBidiStream = protocol.FirstIncomingBidiStreamClient } else { firstOutgoingBidiStream = protocol.FirstOutgoingBidiStreamServer firstOutgoingUniStream = protocol.FirstOutgoingUniStreamServer firstIncomingUniStream = protocol.FirstIncomingUniStreamServer firstIncomingBidiStream = protocol.FirstIncomingBidiStreamServer } // 1. The peer can't open a unidirectional send stream... err := handleFrame(m, firstOutgoingUniStream) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) require.ErrorContains(t, err, fmt.Sprintf("invalid frame for receive stream %d", firstOutgoingUniStream)) require.Empty(t, streamsCreated) // ... 
and a STREAM frame for a unidirectional send stream is invalid even if the stream is open. _, err = m.OpenUniStream() require.NoError(t, err) err = handleFrame(m, firstOutgoingUniStream) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) require.ErrorContains(t, err, fmt.Sprintf("invalid frame for receive stream %d", firstOutgoingUniStream)) streamsCreated = streamsCreated[:0] // 2. The peer can't open a bidirectional stream initiated by us... err = handleFrame(m, firstOutgoingBidiStream) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) require.ErrorContains(t, err, fmt.Sprintf("peer attempted to open stream %d", firstOutgoingBidiStream)) require.Empty(t, streamsCreated) // ... but it's valid once we have opened the stream. _, err = m.OpenStream() require.NoError(t, err) require.NoError(t, handleFrame(m, firstOutgoingBidiStream)) streamsCreated = streamsCreated[:0] // Delayed frames for deleted streams are absorbed. require.NoError(t, m.DeleteStream(firstOutgoingBidiStream)) require.NoError(t, handleFrame(m, firstOutgoingBidiStream)) require.Empty(t, streamsCreated) // 3. The peer can send STREAM frames for unidirectional receive streams, // as long as they're below the stream limit. require.ErrorIs(t, handleFrame(m, firstIncomingUniStream+400), &qerr.TransportError{ErrorCode: qerr.StreamLimitError}, ) require.Empty(t, streamsCreated) require.NoError(t, handleFrame(m, firstIncomingUniStream)) require.Equal(t, streamsCreated, []protocol.StreamID{firstIncomingUniStream}) streamsCreated = streamsCreated[:0] // Delayed frames for deleted streams are absorbed. require.NoError(t, m.DeleteStream(firstIncomingUniStream)) require.NoError(t, handleFrame(m, firstIncomingUniStream)) require.Empty(t, streamsCreated) // 4. The peer can send STREAM frames for bidirectional receive streams, // as long as they're below the stream limit. 
require.ErrorIs(t, handleFrame(m, firstIncomingBidiStream+400), &qerr.TransportError{ErrorCode: qerr.StreamLimitError}, ) require.Empty(t, streamsCreated) require.NoError(t, handleFrame(m, firstIncomingBidiStream)) require.Equal(t, streamsCreated, []protocol.StreamID{firstIncomingBidiStream}) } func TestStreamsMapHandleSendStreamFrames(t *testing.T) { for _, pers := range []protocol.Perspective{protocol.PerspectiveClient, protocol.PerspectiveServer} { t.Run(pers.String(), func(t *testing.T) { t.Run("STOP_SENDING frame", func(t *testing.T) { testStreamsMapHandleSendStreamFrames(t, pers, func(m *streamsMap, id protocol.StreamID) error { return m.HandleStopSendingFrame(&wire.StopSendingFrame{StreamID: id}) }, ) }) t.Run("MAX_STREAM_DATA frame", func(t *testing.T) { testStreamsMapHandleSendStreamFrames(t, pers, func(m *streamsMap, id protocol.StreamID) error { return m.HandleMaxStreamDataFrame(&wire.MaxStreamDataFrame{StreamID: id, MaximumStreamData: 1000}) }, ) }) }) } } func testStreamsMapHandleSendStreamFrames(t *testing.T, pers protocol.Perspective, handleFrame func(m *streamsMap, id protocol.StreamID) error) { mockCtrl := gomock.NewController(t) mockSender := NewMockStreamSender(mockCtrl) mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()).AnyTimes() var streamsCreated []protocol.StreamID m := newStreamsMap( context.Background(), mockSender, func(frame wire.Frame) {}, func(id protocol.StreamID) flowcontrol.StreamFlowController { streamsCreated = append(streamsCreated, id) fc := mocks.NewMockStreamFlowController(mockCtrl) fc.EXPECT().UpdateSendWindow(gomock.Any()).AnyTimes() return fc }, 100, 100, pers, ) m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: protocol.MaxStreamCount}) m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: protocol.MaxStreamCount}) var firstOutgoingUniStream, firstOutgoingBidiStream, firstIncomingUniStream, firstIncomingBidiStream 
protocol.StreamID if pers == protocol.PerspectiveClient { firstOutgoingBidiStream = protocol.FirstOutgoingBidiStreamClient firstOutgoingUniStream = protocol.FirstOutgoingUniStreamClient firstIncomingUniStream = protocol.FirstIncomingUniStreamClient firstIncomingBidiStream = protocol.FirstIncomingBidiStreamClient } else { firstOutgoingBidiStream = protocol.FirstOutgoingBidiStreamServer firstOutgoingUniStream = protocol.FirstOutgoingUniStreamServer firstIncomingUniStream = protocol.FirstIncomingUniStreamServer firstIncomingBidiStream = protocol.FirstIncomingBidiStreamServer } // 1. The peer can't open a unidirectional send stream... err := handleFrame(m, firstOutgoingUniStream) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) require.ErrorContains(t, err, fmt.Sprintf("peer attempted to open stream %d", firstOutgoingUniStream)) require.Empty(t, streamsCreated) // ... but once we have opened the stream, it's valid. _, err = m.OpenUniStream() require.NoError(t, err) require.NoError(t, handleFrame(m, firstOutgoingUniStream)) streamsCreated = streamsCreated[:0] // Delayed frames for deleted streams are absorbed. require.NoError(t, m.DeleteStream(firstOutgoingUniStream)) require.NoError(t, handleFrame(m, firstOutgoingUniStream)) require.Empty(t, streamsCreated) // 2. The peer can't open a bidirectional stream initiated by us... err = handleFrame(m, firstOutgoingBidiStream) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) require.ErrorContains(t, err, fmt.Sprintf("peer attempted to open stream %d", firstOutgoingBidiStream)) require.Empty(t, streamsCreated) // ... but once we have opened the stream, it's valid. _, err = m.OpenStream() require.NoError(t, err) require.NoError(t, handleFrame(m, firstOutgoingBidiStream)) streamsCreated = streamsCreated[:0] // Delayed frames for deleted streams are absorbed. 
require.NoError(t, m.DeleteStream(firstOutgoingBidiStream)) require.NoError(t, handleFrame(m, firstOutgoingBidiStream)) require.Empty(t, streamsCreated) // 3. The peer can't send STOP_SENDING frames for unidirectional send streams err = handleFrame(m, firstIncomingUniStream) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) require.ErrorContains(t, err, fmt.Sprintf("invalid frame for send stream %d", firstIncomingUniStream)) require.Empty(t, streamsCreated) // 4. The peer can send STOP_SENDING frames for bidirectional receive streams iniated by itself, // as long as they're below the stream limit. require.ErrorIs(t, handleFrame(m, firstIncomingBidiStream+400), &qerr.TransportError{ErrorCode: qerr.StreamLimitError}, ) require.Empty(t, streamsCreated) require.NoError(t, handleFrame(m, firstIncomingBidiStream)) require.Equal(t, streamsCreated, []protocol.StreamID{firstIncomingBidiStream}) streamsCreated = streamsCreated[:0] // Delayed frames for deleted streams are absorbed. 
require.NoError(t, m.DeleteStream(firstIncomingBidiStream)) require.NoError(t, handleFrame(m, firstIncomingBidiStream)) require.Empty(t, streamsCreated) } func TestStreamsMapClosing(t *testing.T) { mockCtrl := gomock.NewController(t) mockSender := NewMockStreamSender(mockCtrl) m := newStreamsMap( context.Background(), mockSender, func(wire.Frame) {}, func(protocol.StreamID) flowcontrol.StreamFlowController { return mocks.NewMockStreamFlowController(mockCtrl) }, 1, 1, protocol.PerspectiveClient, ) m.CloseWithError(assert.AnError) _, err := m.OpenStream() require.ErrorIs(t, err, assert.AnError) _, err = m.OpenUniStream() require.ErrorIs(t, err, assert.AnError) _, err = m.AcceptStream(context.Background()) require.ErrorIs(t, err, assert.AnError) _, err = m.AcceptUniStream(context.Background()) require.ErrorIs(t, err, assert.AnError) } func TestStreamsMap0RTT(t *testing.T) { mockCtrl := gomock.NewController(t) mockSender := NewMockStreamSender(mockCtrl) fcBidi := mocks.NewMockStreamFlowController(mockCtrl) fcUni := mocks.NewMockStreamFlowController(mockCtrl) fcs := []flowcontrol.StreamFlowController{fcBidi, fcUni} m := newStreamsMap( context.Background(), mockSender, func(wire.Frame) {}, func(protocol.StreamID) flowcontrol.StreamFlowController { fc := fcs[0] fcs = fcs[1:] return fc }, 1, 1, protocol.PerspectiveClient, ) // restored transport parameters m.HandleTransportParameters(&wire.TransportParameters{ MaxBidiStreamNum: 1, MaxUniStreamNum: 1, }) _, err := m.OpenStream() require.NoError(t, err) _, err = m.OpenUniStream() require.NoError(t, err) fcBidi.EXPECT().UpdateSendWindow(protocol.ByteCount(1234)) fcUni.EXPECT().UpdateSendWindow(protocol.ByteCount(4321)) // new transport parameters m.HandleTransportParameters(&wire.TransportParameters{ MaxBidiStreamNum: 1000, InitialMaxStreamDataBidiRemote: 1234, MaxUniStreamNum: 1000, InitialMaxStreamDataUni: 4321, }) } func TestStreamsMap0RTTRejection(t *testing.T) { mockCtrl := gomock.NewController(t) mockSender := 
NewMockStreamSender(mockCtrl) m := newStreamsMap( context.Background(), mockSender, func(wire.Frame) {}, func(protocol.StreamID) flowcontrol.StreamFlowController { fc := mocks.NewMockStreamFlowController(mockCtrl) fc.EXPECT().UpdateHighestReceived(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() return fc }, 1, 1, protocol.PerspectiveClient, ) m.ResetFor0RTT() _, err := m.OpenStream() require.ErrorIs(t, err, Err0RTTRejected) _, err = m.OpenUniStream() require.ErrorIs(t, err, Err0RTTRejected) _, err = m.AcceptStream(context.Background()) require.ErrorIs(t, err, Err0RTTRejected) _, err = m.AcceptUniStream(context.Background()) require.ErrorIs(t, err, Err0RTTRejected) // make sure that we can still get new streams, as the server might be sending us data require.NoError(t, m.HandleStreamFrame(&wire.StreamFrame{StreamID: 3}, monotime.Now())) // now switch to using the new streams map m.UseResetMaps() _, err = m.OpenStream() require.Error(t, err) require.ErrorIs(t, err, &StreamLimitReachedError{}) } quic-go-0.59.0/sys_conn.go000066400000000000000000000103641513066070600153430ustar00rootroot00000000000000package quic import ( "io" "log" "net" "os" "strconv" "strings" "syscall" "time" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" ) type connCapabilities struct { // This connection has the Don't Fragment (DF) bit set. // This means it makes to run DPLPMTUD. DF bool // GSO (Generic Segmentation Offload) supported GSO bool // ECN (Explicit Congestion Notifications) supported ECN bool } // rawConn is a connection that allow reading of a receivedPackeh. type rawConn interface { ReadPacket() (receivedPacket, error) // WritePacket writes a packet on the wire. // gsoSize is the size of a single packet, or 0 to disable GSO. // It is invalid to set gsoSize if capabilities.GSO is not set. 
WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error) LocalAddr() net.Addr SetReadDeadline(time.Time) error io.Closer capabilities() connCapabilities } // OOBCapablePacketConn is a connection that allows the reading of ECN bits from the IP header. // If the PacketConn passed to the [Transport] satisfies this interface, quic-go will use it. // In this case, ReadMsgUDP() will be used instead of ReadFrom() to read packets. type OOBCapablePacketConn interface { net.PacketConn SyscallConn() (syscall.RawConn, error) SetReadBuffer(int) error ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) } var _ OOBCapablePacketConn = &net.UDPConn{} func wrapConn(pc net.PacketConn) (rawConn, error) { if err := setReceiveBuffer(pc); err != nil { if !strings.Contains(err.Error(), "use of closed network connection") { setBufferWarningOnce.Do(func() { if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable { return } log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Buffer-Sizes for details.", err) }) } } if err := setSendBuffer(pc); err != nil { if !strings.Contains(err.Error(), "use of closed network connection") { setBufferWarningOnce.Do(func() { if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable { return } log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Buffer-Sizes for details.", err) }) } } conn, ok := pc.(interface { SyscallConn() (syscall.RawConn, error) }) var supportsDF bool if ok { rawConn, err := conn.SyscallConn() if err != nil { return nil, err } // only set DF on UDP sockets if _, ok := pc.LocalAddr().(*net.UDPAddr); ok { var err error supportsDF, err = setDF(rawConn) if err != nil { return nil, err } } } c, ok := pc.(OOBCapablePacketConn) if !ok { utils.DefaultLogger.Infof("PacketConn is not a net.UDPConn. 
Disabling optimizations possible on UDP connections.") return &basicConn{PacketConn: pc, supportsDF: supportsDF}, nil } return newConn(c, supportsDF) } // The basicConn is the most trivial implementation of a rawConn. // It reads a single packet from the underlying net.PacketConn. // It is used when // * the net.PacketConn is not a OOBCapablePacketConn, and // * when the OS doesn't support OOB. type basicConn struct { net.PacketConn supportsDF bool } var _ rawConn = &basicConn{} func (c *basicConn) ReadPacket() (receivedPacket, error) { buffer := getPacketBuffer() // The packet size should not exceed protocol.MaxPacketBufferSize bytes // If it does, we only read a truncated packet, which will then end up undecryptable buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize] n, addr, err := c.ReadFrom(buffer.Data) if err != nil { return receivedPacket{}, err } return receivedPacket{ remoteAddr: addr, rcvTime: monotime.Now(), data: buffer.Data[:n], buffer: buffer, }, nil } func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16, ecn protocol.ECN) (n int, err error) { if gsoSize != 0 { panic("cannot use GSO with a basicConn") } if ecn != protocol.ECNUnsupported { panic("cannot use ECN with a basicConn") } return c.WriteTo(b, addr) } func (c *basicConn) capabilities() connCapabilities { return connCapabilities{DF: c.supportsDF} } quic-go-0.59.0/sys_conn_buffers.go000066400000000000000000000056211513066070600170570ustar00rootroot00000000000000package quic import ( "errors" "fmt" "net" "syscall" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" ) //go:generate sh -c "echo '// Code generated by go generate. 
DO NOT EDIT.\n// Source: sys_conn_buffers.go\n' > sys_conn_buffers_write.go && sed -e 's/SetReadBuffer/SetWriteBuffer/g' -e 's/setReceiveBuffer/setSendBuffer/g' -e 's/inspectReadBuffer/inspectWriteBuffer/g' -e 's/protocol\\.DesiredReceiveBufferSize/protocol\\.DesiredSendBufferSize/g' -e 's/forceSetReceiveBuffer/forceSetSendBuffer/g' -e 's/receive buffer/send buffer/g' sys_conn_buffers.go | sed '/^\\/\\/go:generate/d' >> sys_conn_buffers_write.go" func setReceiveBuffer(c net.PacketConn) error { conn, ok := c.(interface{ SetReadBuffer(int) error }) if !ok { return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?") } var syscallConn syscall.RawConn if sc, ok := c.(interface { SyscallConn() (syscall.RawConn, error) }); ok { var err error syscallConn, err = sc.SyscallConn() if err != nil { syscallConn = nil } } // The connection has a SetReadBuffer method, but we couldn't obtain a syscall.RawConn. // This shouldn't happen for a net.UDPConn, but is possible if the connection just implements the // net.PacketConn interface and the SetReadBuffer method. // We have no way of checking if increasing the buffer size actually worked. if syscallConn == nil { return conn.SetReadBuffer(protocol.DesiredReceiveBufferSize) } size, err := inspectReadBuffer(syscallConn) if err != nil { return fmt.Errorf("failed to determine receive buffer size: %w", err) } if size >= protocol.DesiredReceiveBufferSize { utils.DefaultLogger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024) return nil } // Ignore the error. We check if we succeeded by querying the buffer size afterward. 
_ = conn.SetReadBuffer(protocol.DesiredReceiveBufferSize) newSize, err := inspectReadBuffer(syscallConn) if newSize < protocol.DesiredReceiveBufferSize { // Try again with RCVBUFFORCE on Linux _ = forceSetReceiveBuffer(syscallConn, protocol.DesiredReceiveBufferSize) newSize, err = inspectReadBuffer(syscallConn) if err != nil { return fmt.Errorf("failed to determine receive buffer size: %w", err) } } if err != nil { return fmt.Errorf("failed to determine receive buffer size: %w", err) } if newSize == size { return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024) } if newSize < protocol.DesiredReceiveBufferSize { return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024) } utils.DefaultLogger.Debugf("Increased receive buffer size to %d kiB", newSize/1024) return nil } quic-go-0.59.0/sys_conn_buffers_write.go000066400000000000000000000046571513066070600203010ustar00rootroot00000000000000// Code generated by go generate. DO NOT EDIT. // Source: sys_conn_buffers.go package quic import ( "errors" "fmt" "net" "syscall" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" ) func setSendBuffer(c net.PacketConn) error { conn, ok := c.(interface{ SetWriteBuffer(int) error }) if !ok { return errors.New("connection doesn't allow setting of send buffer size. Not a *net.UDPConn?") } var syscallConn syscall.RawConn if sc, ok := c.(interface { SyscallConn() (syscall.RawConn, error) }); ok { var err error syscallConn, err = sc.SyscallConn() if err != nil { syscallConn = nil } } // The connection has a SetWriteBuffer method, but we couldn't obtain a syscall.RawConn. // This shouldn't happen for a net.UDPConn, but is possible if the connection just implements the // net.PacketConn interface and the SetWriteBuffer method. 
// We have no way of checking if increasing the buffer size actually worked. if syscallConn == nil { return conn.SetWriteBuffer(protocol.DesiredSendBufferSize) } size, err := inspectWriteBuffer(syscallConn) if err != nil { return fmt.Errorf("failed to determine send buffer size: %w", err) } if size >= protocol.DesiredSendBufferSize { utils.DefaultLogger.Debugf("Conn has send buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredSendBufferSize/1024) return nil } // Ignore the error. We check if we succeeded by querying the buffer size afterward. _ = conn.SetWriteBuffer(protocol.DesiredSendBufferSize) newSize, err := inspectWriteBuffer(syscallConn) if newSize < protocol.DesiredSendBufferSize { // Try again with RCVBUFFORCE on Linux _ = forceSetSendBuffer(syscallConn, protocol.DesiredSendBufferSize) newSize, err = inspectWriteBuffer(syscallConn) if err != nil { return fmt.Errorf("failed to determine send buffer size: %w", err) } } if err != nil { return fmt.Errorf("failed to determine send buffer size: %w", err) } if newSize == size { return fmt.Errorf("failed to increase send buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredSendBufferSize/1024, newSize/1024) } if newSize < protocol.DesiredSendBufferSize { return fmt.Errorf("failed to sufficiently increase send buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredSendBufferSize/1024, newSize/1024) } utils.DefaultLogger.Debugf("Increased send buffer size to %d kiB", newSize/1024) return nil } quic-go-0.59.0/sys_conn_df.go000066400000000000000000000006111513066070600160060ustar00rootroot00000000000000//go:build !linux && !windows && !darwin package quic import ( "syscall" ) func setDF(syscall.RawConn) (bool, error) { // no-op on unsupported platforms return false, nil } func isSendMsgSizeErr(err error) bool { // to be implemented for more specific platforms return false } func isRecvMsgSizeErr(err error) bool { // to be implemented for more specific platforms 
return false } quic-go-0.59.0/sys_conn_df_darwin.go000066400000000000000000000045741513066070600173660ustar00rootroot00000000000000//go:build darwin package quic import ( "errors" "fmt" "strconv" "strings" "syscall" "golang.org/x/sys/unix" ) // for macOS versions, see https://en.wikipedia.org/wiki/Darwin_(operating_system)#Darwin_20_onwards const ( macOSVersion11 = 20 macOSVersion15 = 24 ) func setDF(rawConn syscall.RawConn) (bool, error) { // Setting DF bit is only supported from macOS 11. // https://github.com/chromium/chromium/blob/117.0.5881.2/net/socket/udp_socket_posix.cc#L555 version, err := getMacOSVersion() if err != nil || version < macOSVersion11 { return false, err } var controlErr error var disableDF bool if err := rawConn.Control(func(fd uintptr) { addr, err := unix.Getsockname(int(fd)) if err != nil { controlErr = fmt.Errorf("getsockname: %w", err) return } // Dual-stack sockets are effectively IPv6 sockets (with IPV6_ONLY set to 0). // On macOS, the DF bit on dual-stack sockets is controlled by the IPV6_DONTFRAG option. // See https://datatracker.ietf.org/doc/draft-seemann-tsvwg-udp-fragmentation/ for details. switch addr.(type) { case *unix.SockaddrInet4: controlErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1) case *unix.SockaddrInet6: controlErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1) // Setting the DF bit on dual-stack sockets works since macOS Sequoia. // Disable DF on dual-stack sockets before Sequoia. 
if version < macOSVersion15 { // check if this is a dual-stack socket by reading the IPV6_V6ONLY flag v6only, err := unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY) if err != nil { controlErr = fmt.Errorf("getting IPV6_V6ONLY: %w", err) return } disableDF = v6only == 0 } default: controlErr = fmt.Errorf("unknown address type: %T", addr) } }); err != nil { return false, err } if controlErr != nil { return false, controlErr } return !disableDF, nil } func isSendMsgSizeErr(err error) bool { return errors.Is(err, unix.EMSGSIZE) } func isRecvMsgSizeErr(error) bool { return false } func getMacOSVersion() (int, error) { uname := &unix.Utsname{} if err := unix.Uname(uname); err != nil { return 0, err } release := string(uname.Release[:]) idx := strings.Index(release, ".") if idx == -1 { return 0, nil } version, err := strconv.Atoi(release[:idx]) if err != nil { return 0, err } return version, nil } quic-go-0.59.0/sys_conn_df_darwin_test.go000066400000000000000000000050231513066070600204130ustar00rootroot00000000000000package quic import ( "net" "testing" "github.com/stretchr/testify/require" ) func TestIPFragmentation(t *testing.T) { sink, err := net.ListenUDP("udp", &net.UDPAddr{Port: 0}) require.NoError(t, err) t.Cleanup(func() { sink.Close() }) sinkPort := sink.LocalAddr().(*net.UDPAddr).Port canSendIPv4 := func(conn *net.UDPConn) bool { _, err := conn.WriteTo([]byte("hello"), &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: sinkPort}) return err == nil } canSendIPv6 := func(conn *net.UDPConn) bool { _, err := conn.WriteTo([]byte("hello"), &net.UDPAddr{IP: net.IPv6loopback, Port: sinkPort}) return err == nil } t.Run("udp4", func(t *testing.T) { conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) require.NoError(t, err) defer conn.Close() require.True(t, canSendIPv4(conn)) require.False(t, canSendIPv6(conn)) raw, err := conn.SyscallConn() require.NoError(t, err) canDF, _ := setDF(raw) require.True(t, canDF) }) 
t.Run("udp6", func(t *testing.T) { conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0}) require.NoError(t, err) defer conn.Close() require.False(t, canSendIPv4(conn)) require.True(t, canSendIPv6(conn)) raw, err := conn.SyscallConn() require.NoError(t, err) canDF, _ := setDF(raw) require.True(t, canDF) }) t.Run("udp, dual-stack", func(t *testing.T) { if version, err := getMacOSVersion(); err != nil || version < macOSVersion15 { t.Skipf("skipping on darwin %d", version-9) } conn, err := net.ListenUDP("udp", &net.UDPAddr{Port: 0}) require.NoError(t, err) defer conn.Close() require.True(t, canSendIPv4(conn)) require.True(t, canSendIPv6(conn)) raw, err := conn.SyscallConn() require.NoError(t, err) canDF, _ := setDF(raw) require.True(t, canDF) }) t.Run("udp, listening on IPv4", func(t *testing.T) { conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) require.NoError(t, err) defer conn.Close() require.True(t, canSendIPv4(conn)) require.False(t, canSendIPv6(conn)) raw, err := conn.SyscallConn() require.NoError(t, err) canDF, _ := setDF(raw) require.True(t, canDF) }) t.Run("udp, listening on IPv6", func(t *testing.T) { conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0}) require.NoError(t, err) defer conn.Close() require.False(t, canSendIPv4(conn)) require.True(t, canSendIPv6(conn)) raw, err := conn.SyscallConn() require.NoError(t, err) canDF, _ := setDF(raw) require.True(t, canDF) }) } quic-go-0.59.0/sys_conn_df_linux.go000066400000000000000000000024121513066070600172260ustar00rootroot00000000000000//go:build linux package quic import ( "errors" "syscall" "golang.org/x/sys/unix" "github.com/quic-go/quic-go/internal/utils" ) func setDF(rawConn syscall.RawConn) (bool, error) { // Enabling IP_MTU_DISCOVER will force the kernel to return "sendto: message too long" // and the datagram will not be fragmented var errDFIPv4, errDFIPv6 error if err := rawConn.Control(func(fd uintptr) { errDFIPv4 = 
unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE) errDFIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_PROBE) }); err != nil { return false, err } switch { case errDFIPv4 == nil && errDFIPv6 == nil: utils.DefaultLogger.Debugf("Setting DF for IPv4 and IPv6.") case errDFIPv4 == nil && errDFIPv6 != nil: utils.DefaultLogger.Debugf("Setting DF for IPv4.") case errDFIPv4 != nil && errDFIPv6 == nil: utils.DefaultLogger.Debugf("Setting DF for IPv6.") case errDFIPv4 != nil && errDFIPv6 != nil: return false, errors.New("setting DF failed for both IPv4 and IPv6") } return true, nil } func isSendMsgSizeErr(err error) bool { // https://man7.org/linux/man-pages/man7/udp.7.html return errors.Is(err, unix.EMSGSIZE) } func isRecvMsgSizeErr(error) bool { return false } quic-go-0.59.0/sys_conn_df_windows.go000066400000000000000000000031341513066070600175630ustar00rootroot00000000000000//go:build windows package quic import ( "errors" "syscall" "golang.org/x/sys/windows" "github.com/quic-go/quic-go/internal/utils" ) const ( // https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IP_DONTFRAGMENT.html //nolint:stylecheck IP_DONTFRAGMENT = 14 // https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IPV6_DONTFRAG.html //nolint:stylecheck IPV6_DONTFRAG = 14 ) func setDF(rawConn syscall.RawConn) (bool, error) { var errDFIPv4, errDFIPv6 error if err := rawConn.Control(func(fd uintptr) { errDFIPv4 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_DONTFRAGMENT, 1) errDFIPv6 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_DONTFRAG, 1) }); err != nil { return false, err } switch { case errDFIPv4 == nil && errDFIPv6 == nil: utils.DefaultLogger.Debugf("Setting DF for IPv4 and IPv6.") case errDFIPv4 == nil && errDFIPv6 != nil: utils.DefaultLogger.Debugf("Setting DF for IPv4.") case errDFIPv4 != nil 
&& errDFIPv6 == nil: utils.DefaultLogger.Debugf("Setting DF for IPv6.") case errDFIPv4 != nil && errDFIPv6 != nil: return false, errors.New("setting DF failed for both IPv4 and IPv6") } return true, nil } func isSendMsgSizeErr(err error) bool { // https://docs.microsoft.com/en-us/windows/win32/winsock/windows-sockets-error-codes-2 return errors.Is(err, windows.WSAEMSGSIZE) } func isRecvMsgSizeErr(err error) bool { // https://docs.microsoft.com/en-us/windows/win32/winsock/windows-sockets-error-codes-2 return errors.Is(err, windows.WSAEMSGSIZE) } quic-go-0.59.0/sys_conn_helper_darwin.go000066400000000000000000000016211513066070600202420ustar00rootroot00000000000000//go:build darwin package quic import ( "encoding/binary" "net/netip" "syscall" "golang.org/x/sys/unix" ) const ( msgTypeIPTOS = unix.IP_RECVTOS ipv4PKTINFO = unix.IP_RECVPKTINFO ) const ecnIPv4DataLen = 4 // ReadBatch only returns a single packet on OSX, // see https://godoc.org/golang.org/x/net/ipv4#PacketConn.ReadBatch. const batchSize = 1 func parseIPv4PktInfo(body []byte) (ip netip.Addr, ifIndex uint32, ok bool) { // struct in_pktinfo { // unsigned int ipi_ifindex; /* Interface index */ // struct in_addr ipi_spec_dst; /* Local address */ // struct in_addr ipi_addr; /* Header Destination address */ // }; if len(body) != 12 { return netip.Addr{}, 0, false } return netip.AddrFrom4(*(*[4]byte)(body[8:12])), binary.NativeEndian.Uint32(body), true } func isGSOEnabled(syscall.RawConn) bool { return false } func isECNEnabled() bool { return !isECNDisabledUsingEnv() } quic-go-0.59.0/sys_conn_helper_freebsd.go000066400000000000000000000011361513066070600203710ustar00rootroot00000000000000//go:build freebsd package quic import ( "net/netip" "syscall" "golang.org/x/sys/unix" ) const ( msgTypeIPTOS = unix.IP_RECVTOS ipv4PKTINFO = 0x7 ) const ecnIPv4DataLen = 1 const batchSize = 8 func parseIPv4PktInfo(body []byte) (ip netip.Addr, _ uint32, ok bool) { // struct in_pktinfo { // struct in_addr ipi_addr; /* Header 
Destination address */ // }; if len(body) != 4 { return netip.Addr{}, 0, false } return netip.AddrFrom4(*(*[4]byte)(body)), 0, true } func isGSOEnabled(syscall.RawConn) bool { return false } func isECNEnabled() bool { return !isECNDisabledUsingEnv() } quic-go-0.59.0/sys_conn_helper_linux.go000066400000000000000000000101711513066070600201150ustar00rootroot00000000000000//go:build linux package quic import ( "encoding/binary" "errors" "net/netip" "os" "strconv" "syscall" "unsafe" "golang.org/x/sys/unix" ) const ( msgTypeIPTOS = unix.IP_TOS ipv4PKTINFO = unix.IP_PKTINFO ) const ecnIPv4DataLen = 1 const batchSize = 8 // needs to smaller than MaxUint8 (otherwise the type of oobConn.readPos has to be changed) var kernelVersionMajor int func init() { kernelVersionMajor, _ = kernelVersion() } func forceSetReceiveBuffer(c syscall.RawConn, bytes int) error { var serr error if err := c.Control(func(fd uintptr) { serr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, bytes) }); err != nil { return err } return serr } func forceSetSendBuffer(c syscall.RawConn, bytes int) error { var serr error if err := c.Control(func(fd uintptr) { serr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, bytes) }); err != nil { return err } return serr } func parseIPv4PktInfo(body []byte) (ip netip.Addr, ifIndex uint32, ok bool) { // struct in_pktinfo { // unsigned int ipi_ifindex; /* Interface index */ // struct in_addr ipi_spec_dst; /* Local address */ // struct in_addr ipi_addr; /* Header Destination address */ // }; if len(body) != 12 { return netip.Addr{}, 0, false } return netip.AddrFrom4(*(*[4]byte)(body[8:12])), binary.NativeEndian.Uint32(body), true } // isGSOEnabled tests if the kernel supports GSO. // Sending with GSO might still fail later on, if the interface doesn't support it (see isGSOError). 
func isGSOEnabled(conn syscall.RawConn) bool { if kernelVersionMajor < 5 { return false } disabled, err := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_GSO")) if err == nil && disabled { return false } var serr error if err := conn.Control(func(fd uintptr) { _, serr = unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT) }); err != nil { return false } return serr == nil } func appendUDPSegmentSizeMsg(b []byte, size uint16) []byte { startLen := len(b) const dataLen = 2 // payload is a uint16 b = append(b, make([]byte, unix.CmsgSpace(dataLen))...) h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen])) h.Level = syscall.IPPROTO_UDP h.Type = unix.UDP_SEGMENT h.SetLen(unix.CmsgLen(dataLen)) // UnixRights uses the private `data` method, but I *think* this achieves the same goal. offset := startLen + unix.CmsgSpace(0) *(*uint16)(unsafe.Pointer(&b[offset])) = size return b } func isGSOError(err error) bool { var serr *os.SyscallError if errors.As(err, &serr) { // EIO is returned by udp_send_skb() if the device driver does not have tx checksums enabled, // which is a hard requirement of UDP_SEGMENT. See: // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 return serr.Err == unix.EIO } return false } // The first sendmsg call on a new UDP socket sometimes errors on Linux. // It's not clear why this happens. // See https://github.com/golang/go/issues/63322. 
func isPermissionError(err error) bool { var serr *os.SyscallError if errors.As(err, &serr) { return serr.Syscall == "sendmsg" && serr.Err == unix.EPERM } return false } func isECNEnabled() bool { return kernelVersionMajor >= 5 && !isECNDisabledUsingEnv() } // kernelVersion returns major and minor kernel version numbers, parsed from // the syscall.Uname's Release field, or 0, 0 if the version can't be obtained // or parsed. // // copied from the standard library's internal/syscall/unix/kernel_version_linux.go func kernelVersion() (major, minor int) { var uname syscall.Utsname if err := syscall.Uname(&uname); err != nil { return } var ( values [2]int value, vi int ) for _, c := range uname.Release { if '0' <= c && c <= '9' { value = (value * 10) + int(c-'0') } else { // Note that we're assuming N.N.N here. // If we see anything else, we are likely to mis-parse it. values[vi] = value vi++ if vi >= len(values) { break } value = 0 } } return values[0], values[1] } quic-go-0.59.0/sys_conn_helper_linux_test.go000066400000000000000000000041611513066070600211560ustar00rootroot00000000000000//go:build linux package quic import ( "errors" "net" "os" "testing" "golang.org/x/sys/unix" "github.com/stretchr/testify/require" ) var ( errGSO = &os.SyscallError{Err: unix.EIO} errNotPermitted = &os.SyscallError{Syscall: "sendmsg", Err: unix.EPERM} ) func TestForcingReceiveBufferSize(t *testing.T) { if os.Getuid() != 0 { t.Skip("Must be root to force change the receive buffer size") } c, err := net.ListenPacket("udp", "127.0.0.1:0") require.NoError(t, err) defer c.Close() syscallConn, err := c.(*net.UDPConn).SyscallConn() require.NoError(t, err) const small = 256 << 10 // 256 KB require.NoError(t, forceSetReceiveBuffer(syscallConn, small)) size, err := inspectReadBuffer(syscallConn) require.NoError(t, err) // the kernel doubles this value (to allow space for bookkeeping overhead) require.Equal(t, 2*small, size) const large = 32 << 20 // 32 MB require.NoError(t, 
forceSetReceiveBuffer(syscallConn, large)) size, err = inspectReadBuffer(syscallConn) require.NoError(t, err) // the kernel doubles this value (to allow space for bookkeeping overhead) require.Equal(t, 2*large, size) } func TestForcingSendBufferSize(t *testing.T) { if os.Getuid() != 0 { t.Skip("Must be root to force change the send buffer size") } c, err := net.ListenPacket("udp", "127.0.0.1:0") require.NoError(t, err) defer c.Close() syscallConn, err := c.(*net.UDPConn).SyscallConn() require.NoError(t, err) const small = 256 << 10 // 256 KB require.NoError(t, forceSetSendBuffer(syscallConn, small)) size, err := inspectWriteBuffer(syscallConn) require.NoError(t, err) // the kernel doubles this value (to allow space for bookkeeping overhead) require.Equal(t, 2*small, size) const large = 32 << 20 // 32 MB require.NoError(t, forceSetSendBuffer(syscallConn, large)) size, err = inspectWriteBuffer(syscallConn) require.NoError(t, err) // the kernel doubles this value (to allow space for bookkeeping overhead) require.Equal(t, 2*large, size) } func TestGSOError(t *testing.T) { require.True(t, isGSOError(errGSO)) require.False(t, isGSOError(nil)) require.False(t, isGSOError(errors.New("test"))) } quic-go-0.59.0/sys_conn_helper_nonlinux.go000066400000000000000000000005631513066070600206340ustar00rootroot00000000000000//go:build !linux package quic func forceSetReceiveBuffer(c any, bytes int) error { return nil } func forceSetSendBuffer(c any, bytes int) error { return nil } func appendUDPSegmentSizeMsg([]byte, uint16) []byte { return nil } func isGSOError(error) bool { return false } func isPermissionError(err error) bool { return false } quic-go-0.59.0/sys_conn_helper_nonlinux_test.go000066400000000000000000000002441513066070600216670ustar00rootroot00000000000000//go:build !linux package quic import "errors" var ( errGSO = errors.New("fake GSO error") errNotPermitted = errors.New("fake not permitted error") ) 
quic-go-0.59.0/sys_conn_no_oob.go000066400000000000000000000007031513066070600166720ustar00rootroot00000000000000//go:build !darwin && !linux && !freebsd && !windows package quic import ( "net" "net/netip" ) func newConn(c net.PacketConn, supportsDF bool) (*basicConn, error) { return &basicConn{PacketConn: c, supportsDF: supportsDF}, nil } func inspectReadBuffer(any) (int, error) { return 0, nil } func inspectWriteBuffer(any) (int, error) { return 0, nil } type packetInfo struct { addr netip.Addr } func (i *packetInfo) OOB() []byte { return nil } quic-go-0.59.0/sys_conn_oob.go000066400000000000000000000233621513066070600162040ustar00rootroot00000000000000//go:build darwin || linux || freebsd package quic import ( "encoding/binary" "errors" "log" "net" "net/netip" "os" "strconv" "sync" "syscall" "unsafe" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" "golang.org/x/sys/unix" "github.com/quic-go/quic-go/internal/monotime" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" ) const ( ecnMask = 0x3 oobBufferSize = 128 ) // Contrary to what the naming suggests, the ipv{4,6}.Message is not dependent on the IP version. // They're both just aliases for x/net/internal/socket.Message. // This means we can use this struct to read from a socket that receives both IPv4 and IPv6 messages. 
var _ ipv4.Message = ipv6.Message{} type batchConn interface { ReadBatch(ms []ipv4.Message, flags int) (int, error) } func inspectReadBuffer(c syscall.RawConn) (int, error) { var size int var serr error if err := c.Control(func(fd uintptr) { size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF) }); err != nil { return 0, err } return size, serr } func inspectWriteBuffer(c syscall.RawConn) (int, error) { var size int var serr error if err := c.Control(func(fd uintptr) { size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF) }); err != nil { return 0, err } return size, serr } func isECNDisabledUsingEnv() bool { disabled, err := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_ECN")) return err == nil && disabled } type oobConn struct { OOBCapablePacketConn batchConn batchConn readPos uint8 // Packets received from the kernel, but not yet returned by ReadPacket(). messages []ipv4.Message buffers [batchSize]*packetBuffer cap connCapabilities } var _ rawConn = &oobConn{} func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) { rawConn, err := c.SyscallConn() if err != nil { return nil, err } var needsPacketInfo bool if udpAddr, ok := c.LocalAddr().(*net.UDPAddr); ok && udpAddr.IP.IsUnspecified() { needsPacketInfo = true } // We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection. // Try enabling receiving of ECN and packet info for both IP versions. // We expect at least one of those syscalls to succeed. 
var errECNIPv4, errECNIPv6, errPIIPv4, errPIIPv6 error if err := rawConn.Control(func(fd uintptr) { errECNIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1) errECNIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1) if needsPacketInfo { errPIIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, ipv4PKTINFO, 1) errPIIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1) } }); err != nil { return nil, err } switch { case errECNIPv4 == nil && errECNIPv6 == nil: utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4 and IPv6.") case errECNIPv4 == nil && errECNIPv6 != nil: utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4.") case errECNIPv4 != nil && errECNIPv6 == nil: utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv6.") case errECNIPv4 != nil && errECNIPv6 != nil: return nil, errors.New("activating ECN failed for both IPv4 and IPv6") } if needsPacketInfo { switch { case errPIIPv4 == nil && errPIIPv6 == nil: utils.DefaultLogger.Debugf("Activating reading of packet info for IPv4 and IPv6.") case errPIIPv4 == nil && errPIIPv6 != nil: utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv4.") case errPIIPv4 != nil && errPIIPv6 == nil: utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv6.") case errPIIPv4 != nil && errPIIPv6 != nil: return nil, errors.New("activating packet info failed for both IPv4 and IPv6") } } // Allows callers to pass in a connection that already satisfies batchConn interface // to make use of the optimisation. Otherwise, ipv4.NewPacketConn would unwrap the file descriptor // via SyscallConn(), and read it that way, which might not be what the caller wants. 
var bc batchConn if ibc, ok := c.(batchConn); ok { bc = ibc } else { bc = ipv4.NewPacketConn(c) } msgs := make([]ipv4.Message, batchSize) for i := range msgs { // preallocate the [][]byte msgs[i].Buffers = make([][]byte, 1) } oobConn := &oobConn{ OOBCapablePacketConn: c, batchConn: bc, messages: msgs, readPos: batchSize, cap: connCapabilities{ DF: supportsDF, GSO: isGSOEnabled(rawConn), ECN: isECNEnabled(), }, } for i := 0; i < batchSize; i++ { oobConn.messages[i].OOB = make([]byte, oobBufferSize) } return oobConn, nil } var invalidCmsgOnceV4, invalidCmsgOnceV6 sync.Once func (c *oobConn) ReadPacket() (receivedPacket, error) { if len(c.messages) == int(c.readPos) { // all messages read. Read the next batch of messages. c.messages = c.messages[:batchSize] // replace buffers data buffers up to the packet that has been consumed during the last ReadBatch call for i := uint8(0); i < c.readPos; i++ { buffer := getPacketBuffer() buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize] c.buffers[i] = buffer c.messages[i].Buffers[0] = c.buffers[i].Data } c.readPos = 0 n, err := c.batchConn.ReadBatch(c.messages, 0) if n == 0 || err != nil { return receivedPacket{}, err } c.messages = c.messages[:n] } msg := c.messages[c.readPos] buffer := c.buffers[c.readPos] c.readPos++ data := msg.OOB[:msg.NN] p := receivedPacket{ remoteAddr: msg.Addr, rcvTime: monotime.Now(), data: msg.Buffers[0][:msg.N], buffer: buffer, } for len(data) > 0 { hdr, body, remainder, err := unix.ParseOneSocketControlMessage(data) if err != nil { return receivedPacket{}, err } if hdr.Level == unix.IPPROTO_IP { switch hdr.Type { case msgTypeIPTOS: if len(body) != 1 { return receivedPacket{}, errors.New("invalid IPTOS size") } p.ecn = protocol.ParseECNHeaderBits(body[0] & ecnMask) case ipv4PKTINFO: ip, ifIndex, ok := parseIPv4PktInfo(body) if ok { p.info.addr = ip p.info.ifIndex = ifIndex } else { invalidCmsgOnceV4.Do(func() { log.Printf("Received invalid IPv4 packet info control message: %+x. 
"+ "This should never occur, please open a new issue and include details about the architecture.", body) }) } } } if hdr.Level == unix.IPPROTO_IPV6 { switch hdr.Type { case unix.IPV6_TCLASS: if len(body) != 4 { return receivedPacket{}, errors.New("invalid IPV6_TCLASS size") } bits := uint8(binary.NativeEndian.Uint32(body)) & ecnMask p.ecn = protocol.ParseECNHeaderBits(bits) case unix.IPV6_PKTINFO: // struct in6_pktinfo { // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ // unsigned int ipi6_ifindex; /* send/recv interface index */ // }; if len(body) == 20 { p.info.addr = netip.AddrFrom16(*(*[16]byte)(body[:16])).Unmap() p.info.ifIndex = binary.NativeEndian.Uint32(body[16:]) } else { invalidCmsgOnceV6.Do(func() { log.Printf("Received invalid IPv6 packet info control message: %+x. "+ "This should never occur, please open a new issue and include details about the architecture.", body) }) } } } data = remainder } return p, nil } // WritePacket writes a new packet. func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error) { oob := packetInfoOOB if gsoSize > 0 { if !c.capabilities().GSO { panic("GSO disabled") } oob = appendUDPSegmentSizeMsg(oob, gsoSize) } if ecn != protocol.ECNUnsupported { if !c.capabilities().ECN { panic("tried to send an ECN-marked packet although ECN is disabled") } if remoteUDPAddr, ok := addr.(*net.UDPAddr); ok { if remoteUDPAddr.IP.To4() != nil { oob = appendIPv4ECNMsg(oob, ecn) } else { oob = appendIPv6ECNMsg(oob, ecn) } } } n, _, err := c.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) return n, err } func (c *oobConn) capabilities() connCapabilities { return c.cap } type packetInfo struct { addr netip.Addr ifIndex uint32 } func (info *packetInfo) OOB() []byte { if info == nil { return nil } if info.addr.Is4() { ip := info.addr.As4() // struct in_pktinfo { // unsigned int ipi_ifindex; /* Interface index */ // struct in_addr ipi_spec_dst; /* Local address */ // struct in_addr 
ipi_addr; /* Header Destination address */ // }; cm := ipv4.ControlMessage{ Src: ip[:], IfIndex: int(info.ifIndex), } return cm.Marshal() } else if info.addr.Is6() { ip := info.addr.As16() // struct in6_pktinfo { // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ // unsigned int ipi6_ifindex; /* send/recv interface index */ // }; cm := ipv6.ControlMessage{ Src: ip[:], IfIndex: int(info.ifIndex), } return cm.Marshal() } return nil } func appendIPv4ECNMsg(b []byte, val protocol.ECN) []byte { startLen := len(b) b = append(b, make([]byte, unix.CmsgSpace(ecnIPv4DataLen))...) h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen])) h.Level = syscall.IPPROTO_IP h.Type = unix.IP_TOS h.SetLen(unix.CmsgLen(ecnIPv4DataLen)) // UnixRights uses the private `data` method, but I *think* this achieves the same goal. offset := startLen + unix.CmsgSpace(0) b[offset] = val.ToHeaderBits() return b } func appendIPv6ECNMsg(b []byte, val protocol.ECN) []byte { startLen := len(b) const dataLen = 4 b = append(b, make([]byte, unix.CmsgSpace(dataLen))...) h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen])) h.Level = syscall.IPPROTO_IPV6 h.Type = unix.IPV6_TCLASS h.SetLen(unix.CmsgLen(dataLen)) // UnixRights uses the private `data` method, but I *think* this achieves the same goal. 
offset := startLen + unix.CmsgSpace(0) binary.NativeEndian.PutUint32(b[offset:offset+dataLen], uint32(val.ToHeaderBits())) return b } quic-go-0.59.0/sys_conn_oob_test.go000066400000000000000000000232311513066070600172360ustar00rootroot00000000000000//go:build darwin || linux || freebsd package quic import ( "fmt" "net" "testing" "time" "golang.org/x/net/ipv4" "golang.org/x/sys/unix" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) func isIPv4(ip net.IP) bool { return ip.To4() != nil } func runSysConnServer(t *testing.T, network string, addr *net.UDPAddr) (*net.UDPAddr, <-chan receivedPacket) { t.Helper() udpConn, err := net.ListenUDP(network, addr) require.NoError(t, err) t.Cleanup(func() { udpConn.Close() }) oobConn, err := newConn(udpConn, true) require.NoError(t, err) require.True(t, oobConn.capabilities().DF) packetChan := make(chan receivedPacket, 1) go func() { for { p, err := oobConn.ReadPacket() if err != nil { return } packetChan <- p } }() return udpConn.LocalAddr().(*net.UDPAddr), packetChan } // sendUDPPacketWithECN opens a new UDP socket and sends one packet with the ECN set. // It returns the local address of the socket. 
func sendUDPPacketWithECN(t *testing.T, network string, addr *net.UDPAddr, setECN func(uintptr)) net.Addr { conn, err := net.DialUDP(network, nil, addr) require.NoError(t, err) t.Cleanup(func() { conn.Close() }) rawConn, err := conn.SyscallConn() require.NoError(t, err) require.NoError(t, rawConn.Control(func(fd uintptr) { setECN(fd) })) _, err = conn.Write([]byte("foobar")) require.NoError(t, err) return conn.LocalAddr() } func TestReadECNFlagsIPv4(t *testing.T) { addr, packetChan := runSysConnServer(t, "udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) sentFrom := sendUDPPacketWithECN(t, "udp4", addr, func(fd uintptr) { require.NoError(t, unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_TOS, 2)) }, ) select { case p := <-packetChan: require.WithinDuration(t, time.Now(), p.rcvTime.ToTime(), scaleDuration(20*time.Millisecond)) require.Equal(t, []byte("foobar"), p.data) require.Equal(t, sentFrom, p.remoteAddr) require.Equal(t, protocol.ECT0, p.ecn) case <-time.After(time.Second): t.Fatal("timeout waiting for packet") } } func TestReadECNFlagsIPv6(t *testing.T) { addr, packetChan := runSysConnServer(t, "udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0}) sentFrom := sendUDPPacketWithECN(t, "udp6", addr, func(fd uintptr) { require.NoError(t, unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_TCLASS, 3)) }, ) select { case p := <-packetChan: require.WithinDuration(t, time.Now(), p.rcvTime.ToTime(), scaleDuration(20*time.Millisecond)) require.Equal(t, []byte("foobar"), p.data) require.Equal(t, sentFrom, p.remoteAddr) require.Equal(t, protocol.ECNCE, p.ecn) case <-time.After(time.Second): t.Fatal("timeout waiting for packet") } } func TestReadECNFlagsDualStack(t *testing.T) { addr, packetChan := runSysConnServer(t, "udp", &net.UDPAddr{IP: net.IPv4(0, 0, 0, 0), Port: 0}) // IPv4 sentFrom := sendUDPPacketWithECN(t, "udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: addr.Port}, func(fd uintptr) { require.NoError(t, unix.SetsockoptInt(int(fd), 
unix.IPPROTO_IP, unix.IP_TOS, 3)) }, ) select { case p := <-packetChan: require.True(t, isIPv4(p.remoteAddr.(*net.UDPAddr).IP)) require.Equal(t, sentFrom.String(), p.remoteAddr.String()) require.Equal(t, protocol.ECNCE, p.ecn) case <-time.After(time.Second): t.Fatal("timeout waiting for packet") } // IPv6 sentFrom = sendUDPPacketWithECN(t, "udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: addr.Port}, func(fd uintptr) { require.NoError(t, unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_TCLASS, 1)) }, ) select { case p := <-packetChan: require.Equal(t, sentFrom, p.remoteAddr) require.False(t, isIPv4(p.remoteAddr.(*net.UDPAddr).IP)) require.Equal(t, protocol.ECT1, p.ecn) case <-time.After(time.Second): t.Fatal("timeout waiting for packet") } } func TestSendPacketsWithECNOnIPv4(t *testing.T) { addr, packetChan := runSysConnServer(t, "udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) c, err := net.ListenUDP("udp4", nil) require.NoError(t, err) defer c.Close() for _, val := range []protocol.ECN{protocol.ECNNon, protocol.ECT1, protocol.ECT0, protocol.ECNCE} { _, _, err = c.WriteMsgUDP([]byte("foobar"), appendIPv4ECNMsg([]byte{}, val), addr) require.NoError(t, err) select { case p := <-packetChan: require.Equal(t, []byte("foobar"), p.data) require.Equal(t, val, p.ecn) case <-time.After(time.Second): t.Fatal("timeout waiting for packet") } } } func TestSendPacketsWithECNOnIPv6(t *testing.T) { addr, packetChan := runSysConnServer(t, "udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0}) c, err := net.ListenUDP("udp6", nil) require.NoError(t, err) defer c.Close() for _, val := range []protocol.ECN{protocol.ECNNon, protocol.ECT1, protocol.ECT0, protocol.ECNCE} { _, _, err = c.WriteMsgUDP([]byte("foobar"), appendIPv6ECNMsg([]byte{}, val), addr) require.NoError(t, err) select { case p := <-packetChan: require.Equal(t, []byte("foobar"), p.data) require.Equal(t, val, p.ecn) case <-time.After(time.Second): t.Fatal("timeout waiting for packet") } } } func 
TestSysConnPacketInfoIPv4(t *testing.T) { // need to listen on 0.0.0.0, otherwise we won't get the packet info addr, packetChan := runSysConnServer(t, "udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) conn, err := net.DialUDP("udp4", nil, addr) require.NoError(t, err) defer conn.Close() _, err = conn.Write([]byte("foobar")) require.NoError(t, err) select { case p := <-packetChan: require.WithinDuration(t, time.Now(), p.rcvTime.ToTime(), scaleDuration(50*time.Millisecond)) require.Equal(t, []byte("foobar"), p.data) require.Equal(t, conn.LocalAddr(), p.remoteAddr) require.True(t, p.info.addr.IsValid()) require.True(t, isIPv4(p.info.addr.AsSlice())) require.Equal(t, net.IPv4(127, 0, 0, 1).String(), p.info.addr.String()) case <-time.After(time.Second): t.Fatal("timeout waiting for packet") } } func TestSysConnPacketInfoIPv6(t *testing.T) { // need to listen on ::, otherwise we won't get the packet info addr, packetChan := runSysConnServer(t, "udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) conn, err := net.DialUDP("udp6", nil, addr) require.NoError(t, err) defer conn.Close() _, err = conn.Write([]byte("foobar")) require.NoError(t, err) select { case p := <-packetChan: require.WithinDuration(t, time.Now(), p.rcvTime.ToTime(), scaleDuration(20*time.Millisecond)) require.Equal(t, []byte("foobar"), p.data) require.Equal(t, conn.LocalAddr(), p.remoteAddr) require.NotNil(t, p.info) require.Equal(t, net.IPv6loopback, net.IP(p.info.addr.AsSlice())) case <-time.After(time.Second): t.Fatal("timeout waiting for packet") } } func TestSysConnPacketInfoDualStack(t *testing.T) { addr, packetChan := runSysConnServer(t, "udp", &net.UDPAddr{}) // IPv4 conn4, err := net.DialUDP("udp4", nil, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: addr.Port}) require.NoError(t, err) defer conn4.Close() _, err = conn4.Write([]byte("foobar")) require.NoError(t, err) select { case p := <-packetChan: require.True(t, isIPv4(p.remoteAddr.(*net.UDPAddr).IP)) require.NotNil(t, p.info) require.True(t, 
p.info.addr.Is4()) require.Equal(t, net.IPv4(127, 0, 0, 1).String(), p.info.addr.String()) case <-time.After(time.Second): t.Fatal("timeout waiting for IPv4 packet") } // IPv6 conn6, err := net.DialUDP("udp6", nil, addr) require.NoError(t, err) defer conn6.Close() _, err = conn6.Write([]byte("foobar")) require.NoError(t, err) select { case p := <-packetChan: require.False(t, isIPv4(p.remoteAddr.(*net.UDPAddr).IP)) require.NotNil(t, p.info) require.Equal(t, net.IPv6loopback.String(), p.info.addr.String()) case <-time.After(time.Second): t.Fatal("timeout waiting for IPv6 packet") } } type oobRecordingConn struct { *net.UDPConn oobs [][]byte } func (c *oobRecordingConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) { c.oobs = append(c.oobs, oob) return c.UDPConn.WriteMsgUDP(b, oob, addr) } type mockBatchConn struct { t *testing.T numMsgRead int callCounter int } var _ batchConn = &mockBatchConn{} func (c *mockBatchConn) ReadBatch(ms []ipv4.Message, _ int) (int, error) { require.Len(c.t, ms, batchSize) for i := 0; i < c.numMsgRead; i++ { require.Len(c.t, ms[i].Buffers, 1) require.Len(c.t, ms[i].Buffers[0], protocol.MaxPacketBufferSize) data := []byte(fmt.Sprintf("message %d", c.callCounter*c.numMsgRead+i)) ms[i].Buffers[0] = data ms[i].N = len(data) } c.callCounter++ return c.numMsgRead, nil } func TestReadsMultipleMessagesInOneBatch(t *testing.T) { bc := &mockBatchConn{t: t, numMsgRead: batchSize/2 + 1} udpConn := newUDPConnLocalhost(t) oobConn, err := newConn(udpConn, true) require.NoError(t, err) oobConn.batchConn = bc for i := 0; i < batchSize+1; i++ { p, err := oobConn.ReadPacket() require.NoError(t, err) require.Equal(t, fmt.Sprintf("message %d", i), string(p.data)) } require.Equal(t, 2, bc.callCounter) } func TestSysConnSendGSO(t *testing.T) { if !platformSupportsGSO { t.Skip("GSO not supported on this platform") } udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) require.NoError(t, err) c := 
&oobRecordingConn{UDPConn: udpConn} oobConn, err := newConn(c, true) require.NoError(t, err) require.True(t, oobConn.capabilities().GSO) oob := make([]byte, 0, 123) oobConn.WritePacket([]byte("foobar"), udpConn.LocalAddr(), oob, 3, protocol.ECNCE) require.Len(t, c.oobs, 1) oobMsg := c.oobs[0] require.NotEmpty(t, oobMsg) require.Equal(t, cap(oob), cap(oobMsg)) // check that it appended to oob expected := appendUDPSegmentSizeMsg([]byte{}, 3) // Check that the first control message is the OOB control message. require.Equal(t, expected, oobMsg[:len(expected)]) } quic-go-0.59.0/sys_conn_test.go000066400000000000000000000014761513066070600164060ustar00rootroot00000000000000package quic import ( "net" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) func TestBasicConn(t *testing.T) { mockCtrl := gomock.NewController(t) c := NewMockPacketConn(mockCtrl) addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234} c.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) { data := []byte("foobar") require.Equal(t, protocol.MaxPacketBufferSize, len(b)) return copy(b, data), addr, nil }) conn, err := wrapConn(c) require.NoError(t, err) p, err := conn.ReadPacket() require.NoError(t, err) require.Equal(t, []byte("foobar"), p.data) require.WithinDuration(t, time.Now(), p.rcvTime.ToTime(), scaleDuration(100*time.Millisecond)) require.Equal(t, addr, p.remoteAddr) } quic-go-0.59.0/sys_conn_windows.go000066400000000000000000000016041513066070600171120ustar00rootroot00000000000000//go:build windows package quic import ( "net/netip" "syscall" "golang.org/x/sys/windows" ) func newConn(c OOBCapablePacketConn, supportsDF bool) (*basicConn, error) { return &basicConn{PacketConn: c, supportsDF: supportsDF}, nil } func inspectReadBuffer(c syscall.RawConn) (int, error) { var size int var serr error if err := c.Control(func(fd uintptr) { size, serr = 
windows.GetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF) }); err != nil { return 0, err } return size, serr } func inspectWriteBuffer(c syscall.RawConn) (int, error) { var size int var serr error if err := c.Control(func(fd uintptr) { size, serr = windows.GetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF) }); err != nil { return 0, err } return size, serr } type packetInfo struct { addr netip.Addr } func (i *packetInfo) OOB() []byte { return nil } quic-go-0.59.0/sys_conn_windows_test.go000066400000000000000000000013401513066070600201460ustar00rootroot00000000000000//go:build windows package quic import ( "net" "testing" "github.com/stretchr/testify/require" ) func TestWindowsConn(t *testing.T) { t.Run("IPv4", func(t *testing.T) { udpConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) require.NoError(t, err) conn, err := newConn(udpConn, true) require.NoError(t, err) require.NoError(t, conn.Close()) require.True(t, conn.capabilities().DF) }) t.Run("IPv6", func(t *testing.T) { udpConn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0}) require.NoError(t, err) conn, err := newConn(udpConn, false) require.NoError(t, err) require.NoError(t, conn.Close()) require.False(t, conn.capabilities().DF) }) } quic-go-0.59.0/testutils/000077500000000000000000000000001513066070600152155ustar00rootroot00000000000000quic-go-0.59.0/testutils/events/000077500000000000000000000000001513066070600165215ustar00rootroot00000000000000quic-go-0.59.0/testutils/events/event_recorder.go000066400000000000000000000042261513066070600220620ustar00rootroot00000000000000package events import ( "reflect" "slices" "sync" "time" "github.com/quic-go/quic-go/qlogwriter" ) // Event is a recorded event with the event time. type Event struct { Time time.Time Event qlogwriter.Event } // Trace is a qlog.Trace that returns a qlog recorder. 
type Trace struct { Recorder qlogwriter.Recorder } var _ qlogwriter.Trace = &Trace{} func (t *Trace) AddProducer() qlogwriter.Recorder { return t.Recorder } func (t *Trace) SupportsSchemas(string) bool { return true } // Recorder is a qlog.Recorder that records events. // Events can be retrieved using the Events method. type Recorder struct { mx sync.Mutex events []Event } var _ qlogwriter.Recorder = &Recorder{} // Events returns all recorded events. // If filter is provided, only events of the given type(s) are returned. func (r *Recorder) RecordEvent(ev qlogwriter.Event) { r.mx.Lock() r.events = append(r.events, Event{Time: time.Now(), Event: ev}) r.mx.Unlock() } // Events returns all recorded events, including the event time. // If filter is provided, only events of the given type(s) are returned. func (r *Recorder) Events(filter ...qlogwriter.Event) []qlogwriter.Event { eventsWithTime := r.EventsWithTime(filter...) events := make([]qlogwriter.Event, 0, len(eventsWithTime)) for _, ev := range eventsWithTime { events = append(events, ev.Event) } return events } func (r *Recorder) EventsWithTime(filter ...qlogwriter.Event) []Event { r.mx.Lock() events := r.events r.mx.Unlock() if len(filter) == 0 { return events } // Some events have the same name when serialized, but use different structs. // We therefore need to filter by type, and can't use the event name. filterTypes := make([]reflect.Type, 0, len(filter)) for _, f := range filter { filterTypes = append(filterTypes, reflect.TypeOf(f)) } var filtered []Event for _, ev := range events { eventType := reflect.TypeOf(ev.Event) if slices.Contains(filterTypes, eventType) { filtered = append(filtered, ev) } } return filtered } // Clear clears the recorded events. 
func (r *Recorder) Clear() { r.mx.Lock() r.events = nil r.mx.Unlock() } func (r *Recorder) Close() error { return nil } quic-go-0.59.0/testutils/events/event_recorder_test.go000066400000000000000000000052521513066070600231210ustar00rootroot00000000000000package events import ( "testing" "time" "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/stretchr/testify/require" ) func TestRecorder(t *testing.T) { recorder := &Recorder{} defer recorder.Close() recorder.RecordEvent(qlog.MTUUpdated{Value: 1000}) recorder.RecordEvent(qlog.ALPNInformation{ChosenALPN: "foobar"}) recorder.RecordEvent(qlog.ECNStateUpdated{State: qlog.ECNStateCapable}) recorder.RecordEvent(qlog.MTUUpdated{Value: 1200}) require.Equal(t, []qlogwriter.Event{ qlog.MTUUpdated{Value: 1000}, qlog.ALPNInformation{ChosenALPN: "foobar"}, qlog.ECNStateUpdated{State: qlog.ECNStateCapable}, qlog.MTUUpdated{Value: 1200}, }, recorder.Events(), ) require.Empty(t, recorder.Events(qlog.PacketBuffered{})) require.Equal(t, []qlogwriter.Event{ qlog.MTUUpdated{Value: 1000}, qlog.MTUUpdated{Value: 1200}, }, recorder.Events(qlog.MTUUpdated{}), ) recorder.Clear() require.Empty(t, recorder.Events()) require.Empty(t, recorder.Events(qlog.MTUUpdated{})) } func TestRecorderFilterEventsSameName(t *testing.T) { // some events have the same name when serialized, but use different structs require.Equal(t, qlog.PacketReceived{}.Name(), qlog.VersionNegotiationReceived{}.Name(), ) recorder := &Recorder{} defer recorder.Close() recorder.RecordEvent(qlog.PacketReceived{ Header: qlog.PacketHeader{PacketType: qlog.PacketTypeHandshake}, }) recorder.RecordEvent(qlog.VersionNegotiationReceived{ Header: qlog.PacketHeaderVersionNegotiation{}, SupportedVersions: []qlog.Version{0xdeadbeef, 0xdecafbad}, }) require.Equal(t, []qlogwriter.Event{ qlog.PacketReceived{ Header: qlog.PacketHeader{PacketType: qlog.PacketTypeHandshake}, }, }, 
recorder.Events(qlog.PacketReceived{}), ) require.Equal(t, []qlogwriter.Event{ qlog.VersionNegotiationReceived{ Header: qlog.PacketHeaderVersionNegotiation{}, SupportedVersions: []qlog.Version{0xdeadbeef, 0xdecafbad}, }, }, recorder.Events(qlog.VersionNegotiationReceived{}), ) } func TestRecorderEventsWithTime(t *testing.T) { synctest.Test(t, func(t *testing.T) { recorder := &Recorder{} start := time.Now() recorder.RecordEvent(qlog.MTUUpdated{Value: 1000}) time.Sleep(time.Minute) recorder.RecordEvent(qlog.ECNStateUpdated{State: qlog.ECNStateCapable}) time.Sleep(time.Minute) recorder.RecordEvent(qlog.MTUUpdated{Value: 1200}) require.Equal(t, []Event{ {Time: start, Event: qlog.MTUUpdated{Value: 1000}}, {Time: start.Add(2 * time.Minute), Event: qlog.MTUUpdated{Value: 1200}}, }, recorder.EventsWithTime(qlog.MTUUpdated{}), ) }) } quic-go-0.59.0/testutils/frames.go000066400000000000000000000020261513066070600170210ustar00rootroot00000000000000package testutils import "github.com/quic-go/quic-go/internal/wire" type ( Frame = wire.Frame AckFrame = wire.AckFrame ConnectionCloseFrame = wire.ConnectionCloseFrame CryptoFrame = wire.CryptoFrame DataBlockedFrame = wire.DataBlockedFrame HandshakeDoneFrame = wire.HandshakeDoneFrame MaxDataFrame = wire.MaxDataFrame MaxStreamDataFrame = wire.MaxStreamDataFrame MaxStreamsFrame = wire.MaxStreamsFrame NewConnectionIDFrame = wire.NewConnectionIDFrame NewTokenFrame = wire.NewTokenFrame PathChallengeFrame = wire.PathChallengeFrame PathResponseFrame = wire.PathResponseFrame PingFrame = wire.PingFrame ResetStreamFrame = wire.ResetStreamFrame RetireConnectionIDFrame = wire.RetireConnectionIDFrame StopSendingFrame = wire.StopSendingFrame StreamDataBlockedFrame = wire.StreamDataBlockedFrame StreamFrame = wire.StreamFrame StreamsBlockedFrame = wire.StreamsBlockedFrame ) 
quic-go-0.59.0/testutils/simnet/000077500000000000000000000000001513066070600165145ustar00rootroot00000000000000quic-go-0.59.0/testutils/simnet/README.md000066400000000000000000000012621513066070600177740ustar00rootroot00000000000000# simnet This package is based on @MarcoPolo's [simnet](https://github.com/marcopolo/simnet) package. A small Go library for simulating packet networks in-process. It provides drop-in `net.PacketConn` endpoints connected through configurable virtual links with latency and MTU constraints. Useful for testing networking code without sockets or root privileges. - **Drop-in API**: implements `net.PacketConn` - **Realistic links**: per-direction latency and MTU - **Packet queuing**: priority queue for scheduled packet delivery - **Routers**: perfect delivery, fixed-latency, simple firewall/NAT-like routing - **Deterministic testing**: opt-in `synctest`-based tests for time control quic-go-0.59.0/testutils/simnet/queue.go000066400000000000000000000053461513066070600201770ustar00rootroot00000000000000package simnet import ( "container/heap" "sync" "time" ) // queue is a priority queue that delivers packets at their scheduled delivery time type queue struct { mu sync.Mutex packets packetHeap newPacket chan struct{} closed bool pushCount int } func newQueue() *queue { q := &queue{ newPacket: make(chan struct{}, 1), } heap.Init(&q.packets) return q } // Enqueue adds a packet to the queue func (q *queue) Enqueue(p *packetWithDeliveryTime) { q.mu.Lock() defer q.mu.Unlock() if q.closed { return } q.pushCount++ heap.Push(&q.packets, packetWithDeliveryTimeAndOrder{packetWithDeliveryTime: p, count: q.pushCount}) // Signal that a new packet arrived (non-blocking) select { case q.newPacket <- struct{}{}: default: } } // Dequeue removes and returns the next packet when it's ready for delivery // This blocks until a packet is available AND its delivery time has been reached // Uses a timer that can be reset if a packet with earlier delivery time arrives 
func (q *queue) Dequeue() (*packetWithDeliveryTime, bool) { timer := time.NewTimer(time.Hour) timer.Stop() for { q.mu.Lock() if q.closed { q.mu.Unlock() timer.Stop() return nil, false } if len(q.packets) == 0 { // no packets, wait for one to arrive q.mu.Unlock() <-q.newPacket timer.Stop() continue } earliest := q.packets[0] earliestTime := earliest.DeliveryTime now := time.Now() if now.Before(earliestTime) { // not ready yet, wait until delivery time or new packet waitDuration := earliestTime.Sub(now) timer.Reset(waitDuration) q.mu.Unlock() select { case <-timer.C: continue case <-q.newPacket: // new packet arrived, might have earlier delivery time timer.Stop() continue } } // Packet is ready, remove from queue and return it po := heap.Pop(&q.packets).(packetWithDeliveryTimeAndOrder) p := po.packetWithDeliveryTime q.mu.Unlock() return p, true } } // Close closes the queue func (q *queue) Close() { q.mu.Lock() defer q.mu.Unlock() q.closed = true close(q.newPacket) } type packetWithDeliveryTimeAndOrder struct { count int *packetWithDeliveryTime } // packetHeap implements heap.Interface ordered by packet delivery time. 
type packetHeap []packetWithDeliveryTimeAndOrder func (h packetHeap) Len() int { return len(h) } func (h packetHeap) Less(i, j int) bool { return (h[i].DeliveryTime.Before(h[j].DeliveryTime) || h[i].DeliveryTime.Equal(h[j].DeliveryTime) && h[i].count < h[j].count) } func (h packetHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } func (h *packetHeap) Push(x any) { *h = append(*h, x.(packetWithDeliveryTimeAndOrder)) } func (h *packetHeap) Pop() any { old := *h n := len(old) item := old[n-1] *h = old[:n-1] return item } quic-go-0.59.0/testutils/simnet/queue_test.go000066400000000000000000000040271513066070600212310ustar00rootroot00000000000000package simnet import ( "testing" "time" "github.com/stretchr/testify/require" ) func TestQueue(t *testing.T) { q := newQueue() baseTime := time.Now() // Enqueue 5 packets with different delivery times // Two packets scheduled for the same time (t2) p1 := &packetWithDeliveryTime{ Packet: Packet{Data: []byte("packet1")}, DeliveryTime: baseTime.Add(10 * time.Millisecond), } p2 := &packetWithDeliveryTime{ Packet: Packet{Data: []byte("packet2")}, DeliveryTime: baseTime.Add(20 * time.Millisecond), } p3 := &packetWithDeliveryTime{ Packet: Packet{Data: []byte("packet3")}, DeliveryTime: baseTime.Add(20 * time.Millisecond), // Same time as p2 } p4 := &packetWithDeliveryTime{ Packet: Packet{Data: []byte("packet4")}, DeliveryTime: baseTime.Add(30 * time.Millisecond), } p5 := &packetWithDeliveryTime{ Packet: Packet{Data: []byte("packet5")}, DeliveryTime: baseTime.Add(5 * time.Millisecond), } // Enqueue in non-chronological order q.Enqueue(p1) q.Enqueue(p2) q.Enqueue(p3) q.Enqueue(p4) q.Enqueue(p5) // Dequeue should return packets in order: p5, p1, p2, p3, p4 // p2 and p3 have same time, but p2 was enqueued first received, ok := q.Dequeue() require.True(t, ok) require.Equal(t, "packet5", string(received.Data)) received, ok = q.Dequeue() require.True(t, ok) require.Equal(t, "packet1", string(received.Data)) received, ok = q.Dequeue() 
require.True(t, ok) require.Equal(t, "packet2", string(received.Data)) received, ok = q.Dequeue() require.True(t, ok) require.Equal(t, "packet3", string(received.Data)) received, ok = q.Dequeue() require.True(t, ok) require.Equal(t, "packet4", string(received.Data)) } func TestQueueClose(t *testing.T) { q := newQueue() q.Close() _, ok := q.Dequeue() require.False(t, ok) // enqueue after close should be ignored p := &packetWithDeliveryTime{ Packet: Packet{Data: []byte("packet")}, DeliveryTime: time.Now(), } q.Enqueue(p) // dequeue should still return false _, ok = q.Dequeue() require.False(t, ok) } quic-go-0.59.0/testutils/simnet/router.go000066400000000000000000000052231513066070600203650ustar00rootroot00000000000000package simnet import ( "errors" "net" "net/netip" "sync" "time" ) type ipPortKey struct { ip string port uint16 isUDP bool } func (k *ipPortKey) FromNetAddr(addr net.Addr) error { switch addr := addr.(type) { case *net.UDPAddr: *k = ipPortKey{ ip: string(addr.IP), port: uint16(addr.Port), isUDP: true, } return nil case *net.TCPAddr: *k = ipPortKey{ ip: string(addr.IP), port: uint16(addr.Port), isUDP: false, } return nil default: ip, err := netip.ParseAddrPort(addr.String()) if err != nil { return err } *k = ipPortKey{ ip: string(ip.Addr().AsSlice()), port: ip.Port(), isUDP: addr.Network() == "udp", } return nil } } type addrMap[V any] struct { mu sync.Mutex nodes map[ipPortKey]V } func (m *addrMap[V]) Get(addr net.Addr) (V, bool) { m.mu.Lock() defer m.mu.Unlock() var v V if len(m.nodes) == 0 { return v, false } var k ipPortKey if err := k.FromNetAddr(addr); err != nil { return v, false } v, ok := m.nodes[k] return v, ok } func (m *addrMap[V]) Set(addr net.Addr, v V) error { m.mu.Lock() defer m.mu.Unlock() if m.nodes == nil { m.nodes = make(map[ipPortKey]V) } var k ipPortKey if err := k.FromNetAddr(addr); err != nil { return err } m.nodes[k] = v return nil } func (m *addrMap[V]) Delete(addr net.Addr) error { m.mu.Lock() defer m.mu.Unlock() if m.nodes == 
nil { m.nodes = make(map[ipPortKey]V) } var k ipPortKey if err := k.FromNetAddr(addr); err != nil { return err } delete(m.nodes, k) return nil } // PerfectRouter is a router that has no latency or jitter and can route to // every node type PerfectRouter struct { nodes addrMap[PacketReceiver] } // SendPacket implements Router. func (r *PerfectRouter) SendPacket(p Packet) error { conn, ok := r.nodes.Get(p.To) if !ok { return errors.New("unknown destination") } conn.RecvPacket(p) return nil } func (r *PerfectRouter) AddNode(addr net.Addr, conn PacketReceiver) { r.nodes.Set(addr, conn) } func (r *PerfectRouter) RemoveNode(addr net.Addr) { r.nodes.Delete(addr) } var _ Router = &PerfectRouter{} type DelayedPacketReceiver struct { inner PacketReceiver delay time.Duration } func (r *DelayedPacketReceiver) RecvPacket(p Packet) { time.AfterFunc(r.delay, func() { r.inner.RecvPacket(p) }) } type FixedLatencyRouter struct { PerfectRouter latency time.Duration } func (r *FixedLatencyRouter) SendPacket(p Packet) error { return r.PerfectRouter.SendPacket(p) } func (r *FixedLatencyRouter) AddNode(addr net.Addr, conn PacketReceiver) { r.PerfectRouter.AddNode(addr, &DelayedPacketReceiver{ inner: conn, delay: r.latency, }) } var _ Router = &FixedLatencyRouter{} quic-go-0.59.0/testutils/simnet/simconn.go000066400000000000000000000122441513066070600205140ustar00rootroot00000000000000package simnet import ( "errors" "net" "slices" "sync" "sync/atomic" "time" ) var ErrDeadlineExceeded = errors.New("deadline exceeded") type PacketReceiver interface { RecvPacket(p Packet) } // Router handles routing of packets between simulated connections. // Implementations are responsible for delivering packets to their destinations. type Router interface { SendPacket(p Packet) error AddNode(addr net.Addr, receiver PacketReceiver) } type Packet struct { To net.Addr From net.Addr Data []byte } // SimConn is a simulated network connection that implements net.PacketConn. 
// It provides packet-based communication through a Router for testing and // simulation purposes. All send/recv operations are handled through the // Router's packet delivery mechanism. type SimConn struct { mu sync.Mutex closed bool closedChan chan struct{} deadlineUpdated chan struct{} packetsSent atomic.Uint64 packetsRcvd atomic.Uint64 bytesSent atomic.Int64 bytesRcvd atomic.Int64 router Router myAddr *net.UDPAddr myLocalAddr net.Addr packetsToRead chan Packet // Controls whether to block when receiving packets if our buffer is full. // If false, drops packets. recvBackPressure bool readDeadline time.Time writeDeadline time.Time } var _ net.PacketConn = &SimConn{} // NewSimConn creates a new simulated connection that drops packets if the // receive buffer is full. func NewSimConn(addr *net.UDPAddr, rtr Router) *SimConn { return newSimConn(addr, rtr, false) } // NewBlockingSimConn creates a new simulated connection that blocks if the // receive buffer is full. Does not drop packets. func NewBlockingSimConn(addr *net.UDPAddr, rtr Router) *SimConn { return newSimConn(addr, rtr, true) } func newSimConn(addr *net.UDPAddr, rtr Router, block bool) *SimConn { c := &SimConn{ recvBackPressure: block, router: rtr, myAddr: addr, packetsToRead: make(chan Packet, 32), closedChan: make(chan struct{}), deadlineUpdated: make(chan struct{}, 1), } rtr.AddNode(addr, c) return c } type ConnStats struct { BytesSent int BytesRcvd int PacketsSent int PacketsRcvd int } func (c *SimConn) Stats() ConnStats { return ConnStats{ BytesSent: int(c.bytesSent.Load()), BytesRcvd: int(c.bytesRcvd.Load()), PacketsSent: int(c.packetsSent.Load()), PacketsRcvd: int(c.packetsRcvd.Load()), } } // SetReadBuffer only exists to quell the warning message from quic-go func (c *SimConn) SetReadBuffer(n int) error { return nil } // SetWriteBuffer only exists to quell the warning message from quic-go func (c *SimConn) SetWriteBuffer(n int) error { return nil } func (c *SimConn) RecvPacket(p Packet) { 
c.mu.Lock() if c.closed { c.mu.Unlock() return } c.mu.Unlock() c.packetsRcvd.Add(1) c.bytesRcvd.Add(int64(len(p.Data))) if c.recvBackPressure { select { case c.packetsToRead <- p: case <-c.closedChan: // if the connection is closed, drop the packet return } } else { select { case c.packetsToRead <- p: default: // drop the packet if the channel is full } } } func (c *SimConn) Close() error { c.mu.Lock() defer c.mu.Unlock() if c.closed { return nil } c.closed = true close(c.closedChan) return nil } func (c *SimConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { c.mu.Lock() if c.closed { c.mu.Unlock() return 0, nil, net.ErrClosed } deadline := c.readDeadline c.mu.Unlock() if !deadline.IsZero() && !time.Now().Before(deadline) { return 0, nil, ErrDeadlineExceeded } var pkt Packet var deadlineTimer <-chan time.Time if !deadline.IsZero() { deadlineTimer = time.After(time.Until(deadline)) } select { case pkt = <-c.packetsToRead: case <-c.closedChan: return 0, nil, net.ErrClosed case <-c.deadlineUpdated: return c.ReadFrom(p) case <-deadlineTimer: return 0, nil, ErrDeadlineExceeded } n = copy(p, pkt.Data) // if the provided buffer is not enough to read the whole packet, we drop // the rest of the data. this is similar to what `recvfrom` does on Linux // and macOS. 
return n, pkt.From, nil } func (c *SimConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { c.mu.Lock() if c.closed { c.mu.Unlock() return 0, net.ErrClosed } deadline := c.writeDeadline c.mu.Unlock() if !deadline.IsZero() && !time.Now().Before(deadline) { return 0, ErrDeadlineExceeded } c.packetsSent.Add(1) c.bytesSent.Add(int64(len(p))) pkt := Packet{ From: c.myAddr, To: addr, Data: slices.Clone(p), } return len(p), c.router.SendPacket(pkt) } func (c *SimConn) UnicastAddr() net.Addr { return c.myAddr } func (c *SimConn) LocalAddr() net.Addr { if c.myLocalAddr != nil { return c.myLocalAddr } return c.myAddr } func (c *SimConn) SetDeadline(t time.Time) error { c.mu.Lock() defer c.mu.Unlock() c.readDeadline = t c.writeDeadline = t select { case c.deadlineUpdated <- struct{}{}: default: } return nil } func (c *SimConn) SetReadDeadline(t time.Time) error { c.mu.Lock() defer c.mu.Unlock() c.readDeadline = t select { case c.deadlineUpdated <- struct{}{}: default: } return nil } func (c *SimConn) SetWriteDeadline(t time.Time) error { c.mu.Lock() defer c.mu.Unlock() c.writeDeadline = t return nil } quic-go-0.59.0/testutils/simnet/simconn_test.go000066400000000000000000000113721513066070600215540ustar00rootroot00000000000000package simnet import ( "crypto/rand" "net" "sync" "testing" "time" "github.com/stretchr/testify/require" ) func randomPublicIPv4() net.IP { start: ip := make([]byte, 4) rand.Read(ip[:]) if net.IP(ip).IsPrivate() || net.IP(ip).IsLoopback() || net.IP(ip).IsLinkLocalUnicast() { goto start } return ip } func TestSimConnBasicConnectivity(t *testing.T) { router := &PerfectRouter{} // Create two endpoints addr1 := &net.UDPAddr{IP: randomPublicIPv4(), Port: 1234} addr2 := &net.UDPAddr{IP: randomPublicIPv4(), Port: 1234} conn1 := NewSimConn(addr1, router) conn2 := NewSimConn(addr2, router) // Test sending data from conn1 to conn2 testData := []byte("hello world") n, err := conn1.WriteTo(testData, addr2) require.NoError(t, err) require.Equal(t, 
len(testData), n) // Read data from conn2 buf := make([]byte, 1024) n, addr, err := conn2.ReadFrom(buf) require.NoError(t, err) require.Equal(t, testData, buf[:n]) require.Equal(t, addr1, addr) // Check stats stats1 := conn1.Stats() require.Equal(t, len(testData), stats1.BytesSent) require.Equal(t, 1, stats1.PacketsSent) stats2 := conn2.Stats() require.Equal(t, len(testData), stats2.BytesRcvd) require.Equal(t, 1, stats2.PacketsRcvd) } func TestSimConnDeadlines(t *testing.T) { router := &PerfectRouter{} addr1 := &net.UDPAddr{IP: randomPublicIPv4(), Port: 1234} conn := NewSimConn(addr1, router) t.Run("read deadline", func(t *testing.T) { deadline := time.Now().Add(10 * time.Millisecond) err := conn.SetReadDeadline(deadline) require.NoError(t, err) buf := make([]byte, 1024) _, _, err = conn.ReadFrom(buf) require.ErrorIs(t, err, ErrDeadlineExceeded) }) t.Run("write deadline", func(t *testing.T) { deadline := time.Now().Add(-time.Second) // Already expired err := conn.SetWriteDeadline(deadline) require.NoError(t, err) _, err = conn.WriteTo([]byte("test"), &net.UDPAddr{}) require.ErrorIs(t, err, ErrDeadlineExceeded) }) } func TestSimConnClose(t *testing.T) { router := &PerfectRouter{} addr1 := &net.UDPAddr{IP: randomPublicIPv4(), Port: 1234} conn := NewSimConn(addr1, router) err := conn.Close() require.NoError(t, err) // Verify operations fail after close _, err = conn.WriteTo([]byte("test"), addr1) require.ErrorIs(t, err, net.ErrClosed) buf := make([]byte, 1024) _, _, err = conn.ReadFrom(buf) require.ErrorIs(t, err, net.ErrClosed) // Second close should not error err = conn.Close() require.NoError(t, err) } func TestSimConnDeadlinesWithLatency(t *testing.T) { router := &FixedLatencyRouter{ PerfectRouter: PerfectRouter{}, latency: 100 * time.Millisecond, } addr1 := &net.UDPAddr{IP: randomPublicIPv4(), Port: 1234} addr2 := &net.UDPAddr{IP: randomPublicIPv4(), Port: 1234} conn1 := NewSimConn(addr1, router) conn2 := NewSimConn(addr2, router) reset := func() { 
router.RemoveNode(addr1) router.RemoveNode(addr2) conn1 = NewSimConn(addr1, router) conn2 = NewSimConn(addr2, router) } t.Run("write succeeds within deadline", func(t *testing.T) { deadline := time.Now().Add(200 * time.Millisecond) err := conn1.SetWriteDeadline(deadline) require.NoError(t, err) n, err := conn1.WriteTo([]byte("test"), addr2) require.NoError(t, err) require.Equal(t, 4, n) reset() }) t.Run("write fails after past deadline", func(t *testing.T) { deadline := time.Now().Add(-time.Second) // Already expired err := conn1.SetWriteDeadline(deadline) require.NoError(t, err) _, err = conn1.WriteTo([]byte("test"), addr2) require.ErrorIs(t, err, ErrDeadlineExceeded) reset() }) t.Run("read succeeds within deadline", func(t *testing.T) { // Reset deadline and send a message conn2.SetReadDeadline(time.Time{}) testData := []byte("hello") deadline := time.Now().Add(200 * time.Millisecond) conn1.SetWriteDeadline(deadline) _, err := conn1.WriteTo(testData, addr2) require.NoError(t, err) // Set read deadline and try to read deadline = time.Now().Add(200 * time.Millisecond) err = conn2.SetReadDeadline(deadline) require.NoError(t, err) buf := make([]byte, 1024) n, addr, err := conn2.ReadFrom(buf) require.NoError(t, err) require.Equal(t, addr1, addr) require.Equal(t, testData, buf[:n]) reset() }) t.Run("read fails after deadline", func(t *testing.T) { defer reset() // Set a short deadline deadline := time.Now().Add(50 * time.Millisecond) // Less than router latency err := conn2.SetReadDeadline(deadline) require.NoError(t, err) var wg sync.WaitGroup defer wg.Wait() wg.Add(1) go func() { defer wg.Done() // Send data after setting deadline _, err := conn1.WriteTo([]byte("test"), addr2) require.NoError(t, err) }() // Read should fail due to deadline buf := make([]byte, 1024) _, _, err = conn2.ReadFrom(buf) require.ErrorIs(t, err, ErrDeadlineExceeded) }) } 
quic-go-0.59.0/testutils/simnet/simlink.go000066400000000000000000000066261513066070600205230ustar00rootroot00000000000000package simnet

import (
	"net"
	"sync"
	"time"
)

// packetWithDeliveryTime holds a packet along with its scheduled delivery time
type packetWithDeliveryTime struct {
	Packet
	DeliveryTime time.Time
}

// LinkSettings defines the network characteristics for a simulated link direction
type LinkSettings struct {
	// MTU (Maximum Transmission Unit) specifies the maximum packet size in bytes.
	// Larger packets are silently dropped. If zero, Start sets it to 1400.
	MTU int
}

// SimulatedLink simulates a bidirectional network link with variable latency and MTU constraints.
// The uplink (SendPacket) forwards packets to UploadPacket without delay; the
// downlink (RecvPacket) delays packets by Latency or LatencyFunc before handing
// them to the receiver registered via AddNode.
type SimulatedLink struct {
	// Internal state for lifecycle management
	wg sync.WaitGroup

	// Queues for packet delivery timing
	downstreamQueue *queue
	upstreamQueue   *queue

	// Configuration for link characteristics
	UplinkSettings   LinkSettings
	DownlinkSettings LinkSettings

	// Latency specifies a fixed network delay for downlink packets
	// If both Latency and LatencyFunc are set, LatencyFunc takes precedence
	Latency time.Duration

	// LatencyFunc computes the network delay for each downlink packet
	// This allows variable latency based on packet source/destination
	// If nil, Latency field is used instead
	LatencyFunc func(Packet) time.Duration

	// Packet routing interfaces
	UploadPacket   Router
	downloadPacket PacketReceiver
}

// AddNode sets the receiver for downlink packets. addr is ignored; a later
// call replaces the receiver.
func (l *SimulatedLink) AddNode(addr net.Addr, receiver PacketReceiver) {
	l.downloadPacket = receiver
}

// Start applies default MTUs and launches the uplink and downlink delivery
// goroutines. It panics if no receiver has been registered via AddNode.
func (l *SimulatedLink) Start() {
	if l.downloadPacket == nil {
		panic("SimulatedLink.Start() called without having added a packet receiver")
	}

	// Sane defaults
	if l.DownlinkSettings.MTU == 0 {
		l.DownlinkSettings.MTU = 1400
	}
	if l.UplinkSettings.MTU == 0 {
		l.UplinkSettings.MTU = 1400
	}

	l.downstreamQueue = newQueue()
	l.upstreamQueue = newQueue()

	l.wg.Add(2)
	go l.backgroundDownlink()
	go l.backgroundUplink()
}

// Close stops the delivery goroutines and waits for them to exit. Packets
// that are still queued are discarded. Closing a link that was never started
// is a no-op.
func (l *SimulatedLink) Close() error {
	if l.downstreamQueue == nil || l.upstreamQueue == nil {
		// Start was never called: there are no queues or goroutines to stop.
		return nil
	}
	l.downstreamQueue.Close()
	l.upstreamQueue.Close()
	l.wg.Wait()
	return nil
}

// backgroundDownlink hands due downlink packets to the registered receiver
// until the downstream queue is closed.
func (l *SimulatedLink) backgroundDownlink() {
	defer l.wg.Done()

	for {
		// Dequeue a packet (this will block until packet is ready for delivery)
		// Dequeue() returns false when the queue is closed
		p, ok := l.downstreamQueue.Dequeue()
		if !ok {
			return
		}
		// Deliver the packet
		l.downloadPacket.RecvPacket(p.Packet)
	}
}

// backgroundUplink forwards uplink packets to UploadPacket until the upstream
// queue is closed.
func (l *SimulatedLink) backgroundUplink() {
	defer l.wg.Done()

	for {
		// Dequeue a packet (this will block until packet is ready for delivery)
		// Dequeue() returns false when the queue is closed
		p, ok := l.upstreamQueue.Dequeue()
		if !ok {
			return
		}
		// Routing errors (e.g. an unknown destination) mean the packet is lost,
		// just like on a real network.
		_ = l.UploadPacket.SendPacket(p.Packet)
	}
}

// SendPacket queues p for immediate upstream delivery. Packets larger than the
// uplink MTU are dropped. It always returns nil.
func (l *SimulatedLink) SendPacket(p Packet) error {
	if len(p.Data) > l.UplinkSettings.MTU {
		// Drop packet if it's too large
		return nil
	}

	// Uplink has no latency - packets are delivered immediately
	l.upstreamQueue.Enqueue(&packetWithDeliveryTime{
		Packet:       p,
		DeliveryTime: time.Now(),
	})
	return nil
}

// RecvPacket queues p for downstream delivery after the configured latency.
// Packets larger than the downlink MTU are dropped.
func (l *SimulatedLink) RecvPacket(p Packet) {
	if len(p.Data) > l.DownlinkSettings.MTU {
		// Drop packet if it's too large
		return
	}

	// Calculate delivery time based on downlink latency
	latency := l.Latency
	if l.LatencyFunc != nil {
		latency = l.LatencyFunc(p)
	}

	l.downstreamQueue.Enqueue(&packetWithDeliveryTime{
		Packet:       p,
		DeliveryTime: time.Now().Add(latency),
	})
}
quic-go-0.59.0/testutils/simnet/simlink_test.go000066400000000000000000000100001513066070600215370ustar00rootroot00000000000000package simnet

import (
	"fmt"
	"math"
	"net"
	"sync/atomic"
	"testing"
	"time"

	"github.com/quic-go/quic-go/internal/synctest"
	"github.com/stretchr/testify/require"
)

type testRouter struct {
	onSend func(p Packet)
	onRecv func(p Packet)
}

func (r *testRouter) SendPacket(p Packet) error {
	r.onSend(p)
	return nil
}

func (r *testRouter) RecvPacket(p Packet) {
	r.onRecv(p)
}

func (r *testRouter) AddNode(addr net.Addr, receiver PacketReceiver) {
	r.onRecv = receiver.RecvPacket
}

func TestLatency(t *testing.T) {
	for _, testUpload := range []bool{true, false} {
		t.Run(fmt.Sprintf("testing upload=%t", testUpload), func(t *testing.T) {
			synctest.Test(t, func(t *testing.T) {
				const downlinkLatency = 10 * time.Millisecond
				const MTU = 1400
				linkSettings := LinkSettings{
					MTU: MTU,
				}

				recvStartTimeChan := make(chan time.Time, 1)
				recvStarted := false
				packetHandler := func(p Packet) {
					if !recvStarted {
						recvStarted = true
						recvStartTimeChan <- time.Now()
					}
				}

				router := &testRouter{}
				if testUpload {
					router.onSend = packetHandler
				} else {
					router.onRecv = packetHandler
				}

				link := SimulatedLink{
					UplinkSettings:   linkSettings,
					DownlinkSettings: linkSettings,
					LatencyFunc: func(p Packet) time.Duration {
						return downlinkLatency
					},
					UploadPacket:   router,
					downloadPacket: router,
				}

				link.Start()

				chunk := make([]byte, MTU)
				sendStartTime := time.Now()
				if testUpload {
					_ = link.SendPacket(Packet{Data: chunk})
				} else {
					link.RecvPacket(Packet{Data: chunk})
				}

				// Wait for delayed packets to be sent
				time.Sleep(40 * time.Millisecond)
				link.Close()

				recvStartTime := <-recvStartTimeChan
				observedLatency := recvStartTime.Sub(sendStartTime)

				// Uplink is now instant (no latency), only downlink has latency
				if testUpload {
					// Uplink test: expect near-zero latency
					t.Logf("observed latency: %s (uplink is instant)", observedLatency)
					if observedLatency > 5*time.Millisecond {
						t.Fatalf("observed latency %s is too high for instant uplink", observedLatency)
					}
					return
				}

				// Downlink test: expect configured latency
				expectedLatency := downlinkLatency
				percentErrorLatency := math.Abs(observedLatency.Seconds()-expectedLatency.Seconds()) / expectedLatency.Seconds()
				t.Logf("observed latency: %s, expected latency: %s, percent error: %f", observedLatency, expectedLatency, percentErrorLatency)
				if percentErrorLatency > 0.20 {
					t.Fatalf("observed latency %s is wrong", observedLatency)
				}
			})
		})
	}
}

func TestMTUEnforcement(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		const MTU = 1400
		linkSettings := LinkSettings{
			MTU: MTU,
		}

		// The handler is invoked from both the uplink and the downlink delivery
		// goroutines, possibly at the same time, so the counter must be atomic.
		var packetsReceived atomic.Int32
		packetHandler := func(p Packet) {
			packetsReceived.Add(1)
		}

		router := &testRouter{
			onSend: packetHandler,
			onRecv: packetHandler,
		}

		link := SimulatedLink{
			UplinkSettings:   linkSettings,
			DownlinkSettings: linkSettings,
			UploadPacket:     router,
			downloadPacket:   router,
		}

		link.Start()

		// Send a packet that fits within MTU - should be delivered
		smallPacket := make([]byte, MTU)
		err := link.SendPacket(Packet{Data: smallPacket})
		require.NoError(t, err)

		// Send a packet that exceeds MTU - should be dropped
		largePacket := make([]byte, MTU+1)
		err = link.SendPacket(Packet{Data: largePacket})
		require.NoError(t, err) // SendPacket returns nil even when dropping

		// Receive a packet that fits within MTU - should be delivered
		link.RecvPacket(Packet{Data: smallPacket})

		// Receive a packet that exceeds MTU - should be dropped
		link.RecvPacket(Packet{Data: largePacket})

		// Wait for packets to be processed
		time.Sleep(10 * time.Millisecond)
		link.Close()

		// Only packets within MTU should be received (2 packets: 1 from SendPacket, 1 from RecvPacket)
		if n := packetsReceived.Load(); n != 2 {
			t.Fatalf("expected 2 packets to be received, got %d", n)
		}
	})
}
quic-go-0.59.0/testutils/simnet/simnet.go000066400000000000000000000035261513066070600203500ustar00rootroot00000000000000package simnet

import (
	"errors"
	"fmt"
	"net"
	"time"
)

// Simnet is a simulated network that manages connections between nodes
// with configurable network conditions.
type Simnet struct {
	Router Router

	links []*SimulatedLink
}

// NodeBiDiLinkSettings defines the bidirectional link settings for a network node.
// It specifies separate configurations for downlink (incoming) and uplink (outgoing)
// traffic, allowing asymmetric network conditions to be simulated.
type NodeBiDiLinkSettings struct { // Downlink configures the settings for incoming traffic to this node Downlink LinkSettings // Uplink configures the settings for outgoing traffic from this node Uplink LinkSettings // Latency specifies a fixed network delay for downlink packets only // If both Latency and LatencyFunc are set, LatencyFunc takes precedence Latency time.Duration // LatencyFunc computes the network delay for each downlink packet // This allows variable latency based on packet source/destination // If nil, Latency field is used instead LatencyFunc func(Packet) time.Duration } func (n *Simnet) Start() error { for _, link := range n.links { link.Start() } return nil } func (n *Simnet) Close() error { var errs error for _, link := range n.links { err := link.Close() if err != nil { errs = errors.Join(errs, err) } } if errs != nil { return fmt.Errorf("failed to close some links: %w", errs) } return nil } func (n *Simnet) NewEndpoint(addr *net.UDPAddr, linkSettings NodeBiDiLinkSettings) *SimConn { link := &SimulatedLink{ DownlinkSettings: linkSettings.Downlink, UplinkSettings: linkSettings.Uplink, Latency: linkSettings.Latency, LatencyFunc: linkSettings.LatencyFunc, UploadPacket: n.Router, } c := NewBlockingSimConn(addr, link) n.links = append(n.links, link) n.Router.AddNode(addr, link) return c } quic-go-0.59.0/testutils/simnet/simnet_synctest_test.go000066400000000000000000000031101513066070600233300ustar00rootroot00000000000000package simnet import ( "math" "net" "testing" "time" "github.com/quic-go/quic-go/internal/synctest" "github.com/stretchr/testify/require" ) func newConn(simnet *Simnet, address *net.UDPAddr, linkSettings NodeBiDiLinkSettings) *SimConn { return simnet.NewEndpoint(address, linkSettings) } func TestSimpleSimNet(t *testing.T) { synctest.Test(t, func(t *testing.T) { router := &Simnet{Router: &PerfectRouter{}} const latency = 10 * time.Millisecond linkSettings := NodeBiDiLinkSettings{ Downlink: LinkSettings{}, Uplink: LinkSettings{}, 
Latency: latency, } addressA := net.UDPAddr{ IP: net.ParseIP("1.0.0.1"), Port: 8000, } connA := newConn(router, &addressA, linkSettings) addressB := net.UDPAddr{ IP: net.ParseIP("1.0.0.2"), Port: 8000, } connB := newConn(router, &addressB, linkSettings) router.Start() defer router.Close() start := time.Now() connA.WriteTo([]byte("hello"), &addressB) buf := make([]byte, 1024) n, from, err := connB.ReadFrom(buf) require.NoError(t, err) require.Equal(t, "hello", string(buf[:n])) require.Equal(t, addressA.String(), from.String()) observedLatency := time.Since(start) // Only downlink has latency now (uplink is instant) expectedLatency := latency percentDiff := math.Abs(float64(observedLatency-expectedLatency) / float64(expectedLatency)) t.Logf("observed latency: %v, expected latency: %v, percent diff: %v", observedLatency, expectedLatency, percentDiff) if percentDiff > 0.30 { t.Fatalf("latency is wrong: %v. percent off: %v", observedLatency, percentDiff) } }) } quic-go-0.59.0/testutils/testutils.go000066400000000000000000000062461513066070600176140ustar00rootroot00000000000000// Package testutils contains utilities for simulating packet injection and man-in-the-middle (MITM) attacker tests. // It is not supposed to be used for non-testing purposes. // The API is not guaranteed to be stable. package testutils import ( "fmt" "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" ) // writePacket returns a new raw packet with the specified header and payload func writePacket(hdr *wire.ExtendedHeader, data []byte) []byte { b, err := hdr.Append(nil, hdr.Version) if err != nil { panic(fmt.Sprintf("failed to write header: %s", err)) } return append(b, data...) 
} // packRawPayload returns a new raw payload containing given frames func packRawPayload(version protocol.Version, frames []wire.Frame) []byte { var b []byte for _, cf := range frames { var err error b, err = cf.Append(b, version) if err != nil { panic(err) } } return b } // ComposeInitialPacket returns an Initial packet encrypted under key (the original destination connection ID) // containing specified frames. func ComposeInitialPacket( srcConnID, destConnID, key protocol.ConnectionID, token []byte, frames []wire.Frame, sentBy protocol.Perspective, version protocol.Version, ) []byte { sealer, _ := handshake.NewInitialAEAD(key, sentBy, version) // compose payload var payload []byte if len(frames) == 0 { payload = make([]byte, protocol.MinInitialPacketSize) } else { payload = packRawPayload(version, frames) } // compose Initial header payloadSize := len(payload) const pnLength = protocol.PacketNumberLen4 length := payloadSize + int(pnLength) + sealer.Overhead() hdr := &wire.ExtendedHeader{ Header: wire.Header{ Type: protocol.PacketTypeInitial, Token: token, SrcConnectionID: srcConnID, DestConnectionID: destConnID, Length: protocol.ByteCount(length), Version: version, }, PacketNumberLen: pnLength, PacketNumber: 0x0, } raw := writePacket(hdr, payload) // encrypt payload and header payloadOffset := len(raw) - payloadSize var encrypted []byte encrypted = sealer.Seal(encrypted, payload, hdr.PacketNumber, raw[:payloadOffset]) hdrBytes := raw[0:payloadOffset] encrypted = append(hdrBytes, encrypted...) 
pnOffset := payloadOffset - int(pnLength) // packet number offset sealer.EncryptHeader( encrypted[payloadOffset:payloadOffset+16], // first 16 bytes of payload (sample) &encrypted[0], // first byte of header encrypted[pnOffset:payloadOffset], // packet number bytes ) return encrypted } // ComposeRetryPacket returns a new raw Retry Packet func ComposeRetryPacket( srcConnID protocol.ConnectionID, destConnID protocol.ConnectionID, origDestConnID protocol.ConnectionID, token []byte, version protocol.Version, ) []byte { hdr := &wire.ExtendedHeader{ Header: wire.Header{ Type: protocol.PacketTypeRetry, SrcConnectionID: srcConnID, DestConnectionID: destConnID, Token: token, Version: version, }, } data := writePacket(hdr, nil) return append(data, handshake.GetRetryIntegrityTag(data, origDestConnID, version)[:]...) } quic-go-0.59.0/token_store.go000066400000000000000000000050141513066070600160400ustar00rootroot00000000000000package quic import ( "sync" list "github.com/quic-go/quic-go/internal/utils/linkedlist" ) type singleOriginTokenStore struct { tokens []*ClientToken len int p int } func newSingleOriginTokenStore(size int) *singleOriginTokenStore { return &singleOriginTokenStore{tokens: make([]*ClientToken, size)} } func (s *singleOriginTokenStore) Add(token *ClientToken) { s.tokens[s.p] = token s.p = s.index(s.p + 1) s.len = min(s.len+1, len(s.tokens)) } func (s *singleOriginTokenStore) Pop() *ClientToken { s.p = s.index(s.p - 1) token := s.tokens[s.p] s.tokens[s.p] = nil s.len = max(s.len-1, 0) return token } func (s *singleOriginTokenStore) Len() int { return s.len } func (s *singleOriginTokenStore) index(i int) int { mod := len(s.tokens) return (i + mod) % mod } type lruTokenStoreEntry struct { key string cache *singleOriginTokenStore } type lruTokenStore struct { mutex sync.Mutex m map[string]*list.Element[*lruTokenStoreEntry] q *list.List[*lruTokenStoreEntry] capacity int singleOriginSize int } var _ TokenStore = &lruTokenStore{} // NewLRUTokenStore creates a new 
LRU cache for tokens received by the client. // maxOrigins specifies how many origins this cache is saving tokens for. // tokensPerOrigin specifies the maximum number of tokens per origin. func NewLRUTokenStore(maxOrigins, tokensPerOrigin int) TokenStore { return &lruTokenStore{ m: make(map[string]*list.Element[*lruTokenStoreEntry]), q: list.New[*lruTokenStoreEntry](), capacity: maxOrigins, singleOriginSize: tokensPerOrigin, } } func (s *lruTokenStore) Put(key string, token *ClientToken) { s.mutex.Lock() defer s.mutex.Unlock() if el, ok := s.m[key]; ok { entry := el.Value entry.cache.Add(token) s.q.MoveToFront(el) return } if s.q.Len() < s.capacity { entry := &lruTokenStoreEntry{ key: key, cache: newSingleOriginTokenStore(s.singleOriginSize), } entry.cache.Add(token) s.m[key] = s.q.PushFront(entry) return } elem := s.q.Back() entry := elem.Value delete(s.m, entry.key) entry.key = key entry.cache = newSingleOriginTokenStore(s.singleOriginSize) entry.cache.Add(token) s.q.MoveToFront(elem) s.m[key] = elem } func (s *lruTokenStore) Pop(key string) *ClientToken { s.mutex.Lock() defer s.mutex.Unlock() var token *ClientToken if el, ok := s.m[key]; ok { s.q.MoveToFront(el) cache := el.Value.cache token = cache.Pop() if cache.Len() == 0 { s.q.Remove(el) delete(s.m, key) } } return token } quic-go-0.59.0/token_store_test.go000066400000000000000000000043501513066070600171010ustar00rootroot00000000000000package quic import ( "fmt" "testing" "time" "github.com/stretchr/testify/require" ) func mockToken(num int) *ClientToken { return &ClientToken{data: []byte(fmt.Sprintf("%d", num)), rtt: 1337 * time.Millisecond} } func TestTokenStoreSingleOrigin(t *testing.T) { const origin = "localhost" s := NewLRUTokenStore(1, 3) s.Put(origin, mockToken(1)) s.Put(origin, mockToken(2)) require.Equal(t, mockToken(2), s.Pop(origin)) require.Equal(t, mockToken(1), s.Pop(origin)) require.Nil(t, s.Pop(origin)) // now add more tokens than the cache size s.Put(origin, mockToken(1)) s.Put(origin, 
mockToken(2)) s.Put(origin, mockToken(3)) require.Equal(t, mockToken(3), s.Pop(origin)) s.Put(origin, mockToken(4)) s.Put(origin, mockToken(5)) require.Equal(t, mockToken(5), s.Pop(origin)) require.Equal(t, mockToken(4), s.Pop(origin)) require.Equal(t, mockToken(2), s.Pop(origin)) require.Nil(t, s.Pop(origin)) } func TestTokenStoreMultipleOrigins(t *testing.T) { s := NewLRUTokenStore(3, 4) s.Put("host1", mockToken(1)) s.Put("host2", mockToken(2)) s.Put("host3", mockToken(3)) s.Put("host4", mockToken(4)) require.Nil(t, s.Pop("host1")) require.Equal(t, mockToken(2), s.Pop("host2")) require.Equal(t, mockToken(3), s.Pop("host3")) require.Equal(t, mockToken(4), s.Pop("host4")) } func TestTokenStoreUpdates(t *testing.T) { s := NewLRUTokenStore(3, 4) s.Put("host1", mockToken(1)) s.Put("host2", mockToken(2)) s.Put("host3", mockToken(3)) s.Put("host1", mockToken(11)) // make sure one is evicted s.Put("host4", mockToken(4)) require.Nil(t, s.Pop("host2")) require.Equal(t, mockToken(11), s.Pop("host1")) require.Equal(t, mockToken(1), s.Pop("host1")) require.Equal(t, mockToken(3), s.Pop("host3")) require.Equal(t, mockToken(4), s.Pop("host4")) } func TestTokenStoreEviction(t *testing.T) { s := NewLRUTokenStore(3, 4) s.Put("host1", mockToken(1)) s.Put("host2", mockToken(2)) s.Put("host3", mockToken(3)) require.Equal(t, mockToken(2), s.Pop("host2")) require.Nil(t, s.Pop("host2")) // host2 is now empty and should have been deleted, making space for host4 s.Put("host4", mockToken(4)) require.Equal(t, mockToken(1), s.Pop("host1")) require.Equal(t, mockToken(3), s.Pop("host3")) require.Equal(t, mockToken(4), s.Pop("host4")) } quic-go-0.59.0/transport.go000066400000000000000000000621771513066070600155550ustar00rootroot00000000000000package quic import ( "context" "crypto/rand" "crypto/tls" "errors" "fmt" "net" "sync" "sync/atomic" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" 
"github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/qlogwriter" ) // ErrTransportClosed is returned by the [Transport]'s Listen or Dial method after it was closed. var ErrTransportClosed = &errTransportClosed{} type errTransportClosed struct { err error } func (e *errTransportClosed) Unwrap() []error { return []error{net.ErrClosed, e.err} } func (e *errTransportClosed) Error() string { if e.err == nil { return "quic: transport closed" } return fmt.Sprintf("quic: transport closed: %s", e.err) } func (e *errTransportClosed) Is(target error) bool { _, ok := target.(*errTransportClosed) return ok } var errListenerAlreadySet = errors.New("listener already set") type closePacket struct { payload []byte addr net.Addr info packetInfo } // The Transport is the central point to manage incoming and outgoing QUIC connections. // QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple. // This means that a single UDP socket can be used for listening for incoming connections, as well as // for dialing an arbitrary number of outgoing connections. // A Transport handles a single net.PacketConn, and offers a range of configuration options // compared to the simple helper functions like [Listen] and [Dial] that this package provides. type Transport struct { // A single net.PacketConn can only be handled by one Transport. // Bad things will happen if passed to multiple Transports. // // A number of optimizations will be enabled if the connections implements the OOBCapablePacketConn interface, // as a *net.UDPConn does. // 1. It enables the Don't Fragment (DF) bit on the IP header. // This is required to run DPLPMTUD (Path MTU Discovery, RFC 8899). // 2. It enables reading of the ECN bits from the IP header. // This allows the remote node to speed up its loss detection and recovery. // 3. It uses batched syscalls (recvmmsg) to more efficiently receive packets from the socket. // 4. 
It uses Generic Segmentation Offload (GSO) to efficiently send batches of packets (on Linux). // // After passing the connection to the Transport, it's invalid to call ReadFrom or WriteTo on the connection. Conn net.PacketConn // The length of the connection ID in bytes. // It can be any value between 1 and 20. // Due to the increased risk of collisions, it is not recommended to use connection IDs shorter than 4 bytes. // If unset, a 4 byte connection ID will be used. ConnectionIDLength int // Use for generating new connection IDs. // This allows the application to control of the connection IDs used, // which allows routing / load balancing based on connection IDs. // All Connection IDs returned by the ConnectionIDGenerator MUST // have the same length. ConnectionIDGenerator ConnectionIDGenerator // The StatelessResetKey is used to generate stateless reset tokens. // If no key is configured, sending of stateless resets is disabled. // It is highly recommended to configure a stateless reset key, as stateless resets // allow the peer to quickly recover from crashes and reboots of this node. // See section 10.3 of RFC 9000 for details. StatelessResetKey *StatelessResetKey // The TokenGeneratorKey is used to encrypt session resumption tokens. // If no key is configured, a random key will be generated. // If multiple servers are authoritative for the same domain, they should use the same key, // see section 8.1.3 of RFC 9000 for details. TokenGeneratorKey *TokenGeneratorKey // MaxTokenAge is the maximum age of the resumption token presented during the handshake. // These tokens allow skipping address resumption when resuming a QUIC connection, // and are especially useful when using 0-RTT. // If not set, it defaults to 24 hours. // See section 8.1.3 of RFC 9000 for details. MaxTokenAge time.Duration // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. // This can be useful if version information is exchanged out-of-band. 
// It has no effect for clients. DisableVersionNegotiationPackets bool // VerifySourceAddress decides if a connection attempt originating from unvalidated source // addresses first needs to go through source address validation using QUIC's Retry mechanism, // as described in RFC 9000 section 8.1.2. // Note that the address passed to this callback is unvalidated, and might be spoofed in case // of an attack. // Validating the source address adds one additional network roundtrip to the handshake, // and should therefore only be used if a suspiciously high number of incoming connection is recorded. // For most use cases, wrapping the Allow function of a rate.Limiter will be a reasonable // implementation of this callback (negating its return value). VerifySourceAddress func(net.Addr) bool // ConnContext is called when the server accepts a new connection. To reject a connection return // a non-nil error. // The context is closed when the connection is closed, or when the handshake fails for any reason. // The context returned from the callback is used to derive every other context used during the // lifetime of the connection: // * the context passed to crypto/tls (and used on the tls.ClientHelloInfo) // * the context used in Config.QlogTrace // * the context returned from Conn.Context // * the context returned from SendStream.Context // It is not used for dialed connections. ConnContext func(context.Context, *ClientInfo) (context.Context, error) // A Tracer traces events that don't belong to a single QUIC connection. // Recorder.Close is called when the transport is closed. Tracer qlogwriter.Recorder mutex sync.Mutex handlers map[protocol.ConnectionID]packetHandler resetTokens map[protocol.StatelessResetToken]packetHandler initOnce sync.Once initErr error // If no ConnectionIDGenerator is set, this is the ConnectionIDLength. connIDLen int // Set in init. // If no ConnectionIDGenerator is set, this is set to a default. 
connIDGenerator ConnectionIDGenerator statelessResetter *statelessResetter server *baseServer conn rawConn closeQueue chan closePacket statelessResetQueue chan receivedPacket listening chan struct{} // is closed when listen returns closeErr error createdConn bool isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial readingNonQUICPackets atomic.Bool nonQUICPackets chan receivedPacket logger utils.Logger } // Listen starts listening for incoming QUIC connections. // There can only be a single listener on any net.PacketConn. // Listen may only be called again after the current listener was closed. func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) { s, err := t.createServer(tlsConf, conf, false) if err != nil { return nil, err } return &Listener{baseServer: s}, nil } // ListenEarly starts listening for incoming QUIC connections. // There can only be a single listener on any net.PacketConn. // ListenEarly may only be called again after the current listener was closed. 
func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) { s, err := t.createServer(tlsConf, conf, true) if err != nil { return nil, err } return &EarlyListener{baseServer: s}, nil } func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bool) (*baseServer, error) { if tlsConf == nil { return nil, errors.New("quic: tls.Config not set") } if err := validateConfig(conf); err != nil { return nil, err } t.mutex.Lock() defer t.mutex.Unlock() if t.closeErr != nil { return nil, t.closeErr } if t.server != nil { return nil, errListenerAlreadySet } conf = populateConfig(conf) if err := t.init(false); err != nil { return nil, err } maxTokenAge := t.MaxTokenAge if maxTokenAge == 0 { maxTokenAge = 24 * time.Hour } s := newServer( t.conn, (*packetHandlerMap)(t), t.connIDGenerator, t.statelessResetter, t.ConnContext, tlsConf, conf, t.Tracer, t.closeServer, *t.TokenGeneratorKey, maxTokenAge, t.VerifySourceAddress, t.DisableVersionNegotiationPackets, allow0RTT, ) t.server = s return s, nil } // Dial dials a new connection to a remote host (not using 0-RTT). func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (*Conn, error) { return t.dial(ctx, addr, "", tlsConf, conf, false) } // DialEarly dials a new connection, attempting to use 0-RTT if possible. 
func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (*Conn, error) {
	return t.dial(ctx, addr, "", tlsConf, conf, true)
}

// dial initializes the Transport, validates and populates conf, and clones tlsConf.
// It infers the clone's ServerName from addr/host, then starts a handshake via doDial.
// That handshake uses initial packet number 0 and the first configured version.
// Zero-length connection IDs are only permitted for single-use transports
// (init is called with t.isSingleUse).
func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsConf *tls.Config, conf *Config, use0RTT bool) (*Conn, error) {
	if err := t.init(t.isSingleUse); err != nil {
		return nil, err
	}
	if err := validateConfig(conf); err != nil {
		return nil, err
	}
	conf = populateConfig(conf)
	// Clone so that setting ServerName doesn't mutate the caller's config.
	tlsConf = tlsConf.Clone()
	setTLSConfigServerName(tlsConf, addr, host)
	return t.doDial(ctx,
		newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger),
		tlsConf,
		conf,
		0,
		false,
		use0RTT,
		conf.Versions[0],
	)
}

// doDial creates a client connection, registers it under its source connection ID,
// and runs it in a new goroutine.
// It blocks until one of these happens:
//   - the handshake completes;
//   - 0-RTT becomes possible (only if use0RTT is set);
//   - the connection fails;
//   - ctx is done, in which case the connection is destroyed.
//
// If the connection terminates with errCloseForRecreating, doDial calls itself again.
// The retry uses the next packet number and version and marks the version as negotiated.
func (t *Transport) doDial(
	ctx context.Context,
	sendConn sendConn,
	tlsConf *tls.Config,
	config *Config,
	initialPacketNumber protocol.PacketNumber,
	hasNegotiatedVersion bool,
	use0RTT bool,
	version protocol.Version,
) (*Conn, error) {
	srcConnID, err := t.connIDGenerator.GenerateConnectionID()
	if err != nil {
		return nil, err
	}
	destConnID, err := generateConnectionIDForInitial()
	if err != nil {
		return nil, err
	}

	// The closed-check and the registration in t.handlers happen under the same lock,
	// so a connection is never registered on a Transport that was already closed.
	t.mutex.Lock()
	if t.closeErr != nil {
		t.mutex.Unlock()
		return nil, t.closeErr
	}

	var qlogTrace qlogwriter.Trace
	if config.Tracer != nil {
		qlogTrace = config.Tracer(ctx, true, destConnID)
	}

	logger := utils.DefaultLogger.WithPrefix("client")
	logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", tlsConf.ServerName, sendConn.LocalAddr(), sendConn.RemoteAddr(), srcConnID, destConnID, version)

	// The connection's context is detached from ctx's cancellation:
	// ctx only bounds the dial, not the connection's lifetime.
	conn := newClientConnection(
		context.WithoutCancel(ctx),
		sendConn,
		(*packetHandlerMap)(t),
		destConnID,
		srcConnID,
		t.connIDGenerator,
		t.statelessResetter,
		config,
		tlsConf,
		initialPacketNumber,
		use0RTT,
		hasNegotiatedVersion,
		qlogTrace,
		logger,
		version,
	)
	t.handlers[srcConnID] = conn
	t.mutex.Unlock()

	// The error channel needs to be buffered, as the run loop will continue running
	// after doDial returns (if the handshake is successful).
	// Similarly, the recreateChan needs to be buffered; in case a different case is selected.
	errChan := make(chan error, 1)
	recreateChan := make(chan errCloseForRecreating, 1)
	go func() {
		err := conn.run()
		var recreateErr *errCloseForRecreating
		if errors.As(err, &recreateErr) {
			recreateChan <- *recreateErr
			return
		}
		// A single-use Transport (quic.Dial) doesn't outlive its only connection.
		if t.isSingleUse {
			t.Close()
		}
		errChan <- err
	}()

	// Only set when we're using 0-RTT.
	// Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever.
	var earlyConnChan <-chan struct{}
	if use0RTT {
		earlyConnChan = conn.earlyConnReady()
	}

	select {
	case <-ctx.Done():
		conn.destroy(nil)
		// wait until the Go routine that called Conn.run() returns
		select {
		case <-errChan:
		case <-recreateChan:
		}
		return nil, context.Cause(ctx)
	case params := <-recreateChan:
		return t.doDial(ctx,
			sendConn,
			tlsConf,
			config,
			params.nextPacketNumber,
			true,
			use0RTT,
			params.nextVersion,
		)
	case err := <-errChan:
		return nil, err
	case <-earlyConnChan:
		// ready to send 0-RTT data
		return conn.Conn, nil
	case <-conn.HandshakeComplete():
		// handshake successfully completed
		return conn.Conn, nil
	}
}

// init performs the Transport's one-time setup (guarded by initOnce).
// Only the first call's allowZeroLengthConnIDs has any effect.
// It performs these steps in order:
//   - wraps t.Conn as a rawConn (unless it already is one);
//   - allocates the handler maps and queues;
//   - generates a random TokenGeneratorKey if none was set;
//   - picks the connection ID generator and length;
//   - starts the listen and send-queue goroutines.
//
// Any failure is stored in t.initErr and returned by this and every later call.
func (t *Transport) init(allowZeroLengthConnIDs bool) error {
	t.initOnce.Do(func() {
		var conn rawConn
		if c, ok := t.Conn.(rawConn); ok {
			conn = c
		} else {
			var err error
			conn, err = wrapConn(t.Conn)
			if err != nil {
				t.initErr = err
				return
			}
		}

		t.logger = utils.DefaultLogger // TODO: make this configurable
		t.conn = conn
		t.handlers = make(map[protocol.ConnectionID]packetHandler)
		t.resetTokens = make(map[protocol.StatelessResetToken]packetHandler)
		t.listening = make(chan struct{})

		t.closeQueue = make(chan closePacket, 4)
		t.statelessResetQueue = make(chan receivedPacket, 4)

		if t.TokenGeneratorKey == nil {
			var key TokenGeneratorKey
			if _, err := rand.Read(key[:]); err != nil {
				// NOTE(review): t.listening was already created above but the goroutine
				// that closes it is never started on this path.
				t.initErr = err
				return
			}
			t.TokenGeneratorKey = &key
		}

		if t.ConnectionIDGenerator != nil {
			t.connIDGenerator = t.ConnectionIDGenerator
			t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen()
		} else {
			connIDLen := t.ConnectionIDLength
			// An unset length falls back to the default, unless zero-length IDs are allowed.
			if t.ConnectionIDLength == 0 && !allowZeroLengthConnIDs {
				connIDLen = protocol.DefaultConnectionIDLength
			}
			t.connIDLen = connIDLen
			t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen}
		}
		t.statelessResetter = newStatelessResetter(t.StatelessResetKey)

		go func() {
			defer close(t.listening)
			t.listen(conn)

			if t.createdConn {
				conn.Close()
			}
		}()
		go t.runSendQueue()
	})
	return t.initErr
}

// WriteTo sends a packet on the underlying connection.
func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
	if err := t.init(false); err != nil {
		return 0, err
	}
	return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported)
}

// runSendQueue writes queued CONNECTION_CLOSE retransmissions and stateless resets
// until t.listening is closed. Write errors for close packets are ignored (best effort).
func (t *Transport) runSendQueue() {
	for {
		select {
		case <-t.listening:
			return
		case p := <-t.closeQueue:
			t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0, protocol.ECNUnsupported)
		case p := <-t.statelessResetQueue:
			t.sendStatelessReset(p)
		}
	}
}

// Close stops listening for UDP datagrams on the Transport.Conn.
// It abruptly terminates all existing connections, without sending a CONNECTION_CLOSE
// to the peers. It is the application's responsibility to cleanly terminate existing
// connections prior to calling Close.
//
// If a server was started, it will be closed as well.
// It is not possible to start any new server or dial new connections after that.
func (t *Transport) Close() error { // avoid race condition if the transport is currently being initialized t.init(false) t.close(nil) if t.createdConn { if err := t.Conn.Close(); err != nil { return err } } else if t.conn != nil { t.conn.SetReadDeadline(time.Now()) defer func() { t.conn.SetReadDeadline(time.Time{}) }() } if t.listening != nil { <-t.listening // wait until listening returns } return nil } func (t *Transport) closeServer() { t.mutex.Lock() defer t.mutex.Unlock() t.server = nil if t.isSingleUse { t.closeErr = ErrServerClosed } if len(t.handlers) == 0 { t.maybeStopListening() } } func (t *Transport) close(e error) { t.mutex.Lock() if t.closeErr != nil { t.mutex.Unlock() return } e = &errTransportClosed{err: e} t.closeErr = e server := t.server t.server = nil if server != nil { t.mutex.Unlock() server.close(e, true) t.mutex.Lock() } // Close existing connections var wg sync.WaitGroup for _, handler := range t.handlers { wg.Add(1) go func(handler packetHandler) { handler.destroy(e) wg.Done() }(handler) } t.mutex.Unlock() // closing connections requires releasing transport mutex wg.Wait() if t.Tracer != nil { t.Tracer.Close() } } // only print warnings about the UDP receive buffer size once var setBufferWarningOnce sync.Once func (t *Transport) listen(conn rawConn) { for { p, err := conn.ReadPacket() //nolint:staticcheck // SA1019 ignore this! // TODO: This code is used to ignore wsa errors on Windows. // Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution. // See https://github.com/quic-go/quic-go/issues/1737 for details. if nerr, ok := err.(net.Error); ok && nerr.Temporary() { t.mutex.Lock() closed := t.closeErr != nil t.mutex.Unlock() if closed { return } t.logger.Debugf("Temporary error reading from conn: %w", err) continue } if err != nil { // Windows returns an error when receiving a UDP datagram that doesn't fit into the provided buffer. 
if isRecvMsgSizeErr(err) { continue } t.close(err) return } t.handlePacket(p) } } func (t *Transport) maybeStopListening() { if t.isSingleUse && t.closeErr != nil { t.conn.SetReadDeadline(time.Now()) } } func (t *Transport) handlePacket(p receivedPacket) { if len(p.data) == 0 { return } if !wire.IsPotentialQUICPacket(p.data[0]) && !wire.IsLongHeaderPacket(p.data[0]) { t.handleNonQUICPacket(p) return } connID, err := wire.ParseConnectionID(p.data, t.connIDLen) if err != nil { t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) if t.Tracer != nil { t.Tracer.RecordEvent(qlog.PacketDropped{ Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropHeaderParseError, }) } p.buffer.MaybeRelease() return } // If there's a connection associated with the connection ID, pass the packet there. if handler, ok := (*packetHandlerMap)(t).Get(connID); ok { handler.handlePacket(p) return } // RFC 9000 section 10.3.1 requires that the stateless reset detection logic is run for both // packets that cannot be associated with any connections, and for packets that can't be decrypted. // We deviate from the RFC and ignore the latter: If a packet's connection ID is associated with an // existing connection, it is dropped there if if it can't be decrypted. // Stateless resets use random connection IDs, and at reasonable connection ID lengths collisions are // exceedingly rare. In the unlikely event that a stateless reset is misrouted to an existing connection, // it is to be expected that the next stateless reset will be correctly detected. 
if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset { return } if !wire.IsLongHeaderPacket(p.data[0]) { if statelessResetQueued := t.maybeSendStatelessReset(p); !statelessResetQueued { if t.Tracer != nil { t.Tracer.RecordEvent(qlog.PacketDropped{ Header: qlog.PacketHeader{PacketType: qlog.PacketType1RTT}, Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropUnknownConnectionID, }) } p.buffer.Release() } return } t.mutex.Lock() defer t.mutex.Unlock() if t.server == nil { // no server set t.logger.Debugf("received a packet with an unexpected connection ID %s", connID) if t.Tracer != nil { t.Tracer.RecordEvent(qlog.PacketDropped{ Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropUnknownConnectionID, }) } p.buffer.MaybeRelease() return } t.server.handlePacket(p) } func (t *Transport) maybeSendStatelessReset(p receivedPacket) (statelessResetQueued bool) { if t.StatelessResetKey == nil { return false } // Don't send a stateless reset in response to very small packets. // This includes packets that could be stateless resets. if len(p.data) <= protocol.MinStatelessResetSize { return false } select { case t.statelessResetQueue <- p: return true default: // it's fine to not send a stateless reset when we're busy return false } } func (t *Transport) sendStatelessReset(p receivedPacket) { defer p.buffer.Release() connID, err := wire.ParseConnectionID(p.data, t.connIDLen) if err != nil { t.logger.Errorf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) return } token := t.statelessResetter.GetStatelessResetToken(connID) t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) rand.Read(data) data[0] = (data[0] & 0x7f) | 0x40 data = append(data, token[:]...) 
if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil { t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err) } } func (t *Transport) maybeHandleStatelessReset(data []byte) bool { // stateless resets are always short header packets if wire.IsLongHeaderPacket(data[0]) { return false } if len(data) < 17 /* type byte + 16 bytes for the reset token */ { return false } token := protocol.StatelessResetToken(data[len(data)-16:]) t.mutex.Lock() conn, ok := t.resetTokens[token] t.mutex.Unlock() if ok { t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token) go conn.destroy(&StatelessResetError{}) return true } return false } func (t *Transport) handleNonQUICPacket(p receivedPacket) { // Strictly speaking, this is racy, // but we only care about receiving packets at some point after ReadNonQUICPacket has been called. if !t.readingNonQUICPackets.Load() { return } select { case t.nonQUICPackets <- p: default: if t.Tracer != nil { t.Tracer.RecordEvent(qlog.PacketDropped{ Raw: qlog.RawInfo{Length: int(p.Size())}, Trigger: qlog.PacketDropDOSPrevention, }) } } } const maxQueuedNonQUICPackets = 32 // ReadNonQUICPacket reads non-QUIC packets received on the underlying connection. // The detection logic is very simple: Any packet that has the first and second bit of the packet set to 0. // Note that this is stricter than the detection logic defined in RFC 9443. 
func (t *Transport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error) { if err := t.init(false); err != nil { return 0, nil, err } if !t.readingNonQUICPackets.Load() { t.nonQUICPackets = make(chan receivedPacket, maxQueuedNonQUICPackets) t.readingNonQUICPackets.Store(true) } select { case <-ctx.Done(): return 0, nil, ctx.Err() case p := <-t.nonQUICPackets: n := copy(b, p.data) return n, p.remoteAddr, nil case <-t.listening: return 0, nil, errors.New("closed") } } func setTLSConfigServerName(tlsConf *tls.Config, addr net.Addr, host string) { // If no ServerName is set, infer the ServerName from the host we're connecting to. if tlsConf.ServerName != "" { return } if host == "" { if udpAddr, ok := addr.(*net.UDPAddr); ok { tlsConf.ServerName = udpAddr.IP.String() return } } h, _, err := net.SplitHostPort(host) if err != nil { // This happens if the host doesn't contain a port number. tlsConf.ServerName = host return } tlsConf.ServerName = h } type packetHandlerMap Transport var _ connRunner = &packetHandlerMap{} func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ { h.mutex.Lock() defer h.mutex.Unlock() if _, ok := h.handlers[id]; ok { h.logger.Debugf("Not adding connection ID %s, as it already exists.", id) return false } h.handlers[id] = handler h.logger.Debugf("Adding connection ID %s.", id) return true } func (h *packetHandlerMap) Get(connID protocol.ConnectionID) (packetHandler, bool) { h.mutex.Lock() defer h.mutex.Unlock() handler, ok := h.handlers[connID] return handler, ok } func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) { h.mutex.Lock() h.resetTokens[token] = handler h.mutex.Unlock() } func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) { h.mutex.Lock() delete(h.resetTokens, token) h.mutex.Unlock() } func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, handler 
packetHandler) bool { h.mutex.Lock() defer h.mutex.Unlock() if _, ok := h.handlers[clientDestConnID]; ok { h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) return false } h.handlers[clientDestConnID] = handler h.handlers[newConnID] = handler h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID) return true } func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { h.mutex.Lock() delete(h.handlers, id) h.mutex.Unlock() h.logger.Debugf("Removing connection ID %s.", id) } // ReplaceWithClosed is called when a connection is closed. // Depending on which side closed the connection, we need to: // * remote close: absorb delayed packets // * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, connClosePacket []byte, expiry time.Duration) { var handler packetHandler if connClosePacket != nil { handler = newClosedLocalConn( func(addr net.Addr, info packetInfo) { select { case h.closeQueue <- closePacket{payload: connClosePacket, addr: addr, info: info}: default: // We're backlogged. // Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway. 
} }, h.logger, ) } else { handler = newClosedRemoteConn() } h.mutex.Lock() for _, id := range ids { h.handlers[id] = handler } h.mutex.Unlock() h.logger.Debugf("Replacing connection for connection IDs %s with a closed connection.", ids) time.AfterFunc(expiry, func() { h.mutex.Lock() for _, id := range ids { delete(h.handlers, id) } if len(h.handlers) == 0 { t := (*Transport)(h) t.maybeStopListening() } h.mutex.Unlock() h.logger.Debugf("Removing connection IDs %s for a closed connection after it has been retired.", ids) }) } quic-go-0.59.0/transport_test.go000066400000000000000000000506511513066070600166060ustar00rootroot00000000000000package quic import ( "bytes" "context" "crypto/tls" "errors" "math" "net" "runtime" "strings" "sync/atomic" "syscall" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/synctest" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/qlogwriter" "github.com/quic-go/quic-go/testutils/events" "github.com/quic-go/quic-go/testutils/simnet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type mockPacketConn struct { localAddr net.Addr readErrs chan error } func (c *mockPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { err, ok := <-c.readErrs if !ok { return 0, nil, net.ErrClosed } return 0, nil, err } func (c *mockPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { panic("implement me") } func (c *mockPacketConn) LocalAddr() net.Addr { return c.localAddr } func (c *mockPacketConn) Close() error { close(c.readErrs); return nil } func (c *mockPacketConn) SetDeadline(t time.Time) error { return nil } func (c *mockPacketConn) SetReadDeadline(t time.Time) error { return nil } func (c *mockPacketConn) SetWriteDeadline(t time.Time) error { return nil } type mockPacketHandler struct { packets chan<- 
receivedPacket destruction chan<- error } func (h *mockPacketHandler) handlePacket(p receivedPacket) { h.packets <- p } func (h *mockPacketHandler) destroy(err error) { if h.destruction != nil { h.destruction <- err } } func (h *mockPacketHandler) closeWithTransportError(code qerr.TransportErrorCode) {} func newSimnetLink(t *testing.T, rtt time.Duration) (client, server net.PacketConn, close func()) { t.Helper() n := &simnet.Simnet{Router: &simnet.PerfectRouter{}} settings := simnet.NodeBiDiLinkSettings{Latency: rtt / 2} client = n.NewEndpoint(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 9001}, settings) server = n.NewEndpoint(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 9002}, settings) require.NoError(t, n.Start()) return client, server, func() { require.NoError(t, n.Close()) } } func TestTransportPacketHandling(t *testing.T) { tr := &Transport{Conn: newUDPConnLocalhost(t)} tr.init(true) defer tr.Close() connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) connChan1 := make(chan receivedPacket, 1) conn1 := &mockPacketHandler{packets: connChan1} (*packetHandlerMap)(tr).Add(connID1, conn1) connChan2 := make(chan receivedPacket, 1) conn2 := &mockPacketHandler{packets: connChan2} (*packetHandlerMap)(tr).Add(connID2, conn2) conn := newUDPConnLocalhost(t) _, err := conn.WriteTo(getPacket(t, connID1), tr.Conn.LocalAddr()) require.NoError(t, err) _, err = conn.WriteTo(getPacket(t, connID2), tr.Conn.LocalAddr()) require.NoError(t, err) select { case p := <-connChan1: require.Equal(t, conn.LocalAddr(), p.remoteAddr) connID, err := wire.ParseConnectionID(p.data, 0) require.NoError(t, err) require.Equal(t, connID1, connID) case <-time.After(time.Second): t.Fatal("timeout") } select { case p := <-connChan2: require.Equal(t, conn.LocalAddr(), p.remoteAddr) connID, err := wire.ParseConnectionID(p.data, 0) require.NoError(t, err) require.Equal(t, connID2, connID) case 
<-time.After(time.Second): t.Fatal("timeout") } } func TestTransportAndListenerConcurrentClose(t *testing.T) { tr := &Transport{Conn: newUDPConnLocalhost(t)} ln, err := tr.Listen(&tls.Config{}, nil) require.NoError(t, err) // close transport and listener concurrently lnErrChan := make(chan error, 1) go func() { lnErrChan <- ln.Close() }() require.NoError(t, tr.Close()) select { case err := <-lnErrChan: require.NoError(t, err) case <-time.After(time.Second): t.Fatal("timeout") } } func TestTransportAndDialConcurrentClose(t *testing.T) { server := newUDPConnLocalhost(t) tr := &Transport{Conn: newUDPConnLocalhost(t)} // close transport and dial concurrently errChan := make(chan error, 1) go func() { errChan <- tr.Close() }() ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() _, err := tr.Dial(ctx, server.LocalAddr(), &tls.Config{}, nil) require.Error(t, err) require.ErrorIs(t, err, ErrTransportClosed) require.NotErrorIs(t, err, context.DeadlineExceeded) select { case <-errChan: case <-time.After(time.Second): t.Fatal("timeout") } } func TestTransportErrFromConn(t *testing.T) { t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true") synctest.Test(t, func(t *testing.T) { readErrChan := make(chan error, 2) tr := Transport{ Conn: &mockPacketConn{ readErrs: readErrChan, localAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}, }, } defer tr.Close() tr.init(true) errChan := make(chan error, 1) ph := &mockPacketHandler{destruction: errChan} (*packetHandlerMap)(&tr).Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4}), ph) // temporary errors don't lead to a shutdown... 
var tempErr deadlineError require.True(t, tempErr.Temporary()) readErrChan <- tempErr // don't expect any calls to phm.Close synctest.Wait() // ...but non-temporary errors do readErrChan <- errors.New("read failed") synctest.Wait() select { case err := <-errChan: require.ErrorIs(t, err, ErrTransportClosed) case <-time.After(time.Second): t.Fatal("timeout") } _, err := tr.Listen(&tls.Config{}, nil) require.ErrorIs(t, err, ErrTransportClosed) }) } func TestTransportStatelessResetReceiving(t *testing.T) { tr := &Transport{ Conn: newUDPConnLocalhost(t), ConnectionIDLength: 4, } tr.init(true) defer tr.Close() connID := protocol.ParseConnectionID([]byte{9, 10, 11, 12}) // now send a packet with a connection ID that doesn't exist token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} b, err := wire.AppendShortHeader(nil, connID, 1337, 2, protocol.KeyPhaseOne) require.NoError(t, err) b = append(b, token[:]...) destroyChan := make(chan error, 1) conn1 := &mockPacketHandler{destruction: destroyChan} (*packetHandlerMap)(tr).AddResetToken(token, conn1) conn := newUDPConnLocalhost(t) _, err = conn.WriteTo(b, tr.Conn.LocalAddr()) require.NoError(t, err) select { case err := <-destroyChan: require.ErrorIs(t, err, &qerr.StatelessResetError{}) case <-time.After(time.Second): t.Fatal("timeout") } } func TestTransportStatelessResetSending(t *testing.T) { synctest.Test(t, func(t *testing.T) { const rtt = 10 * time.Millisecond clientConn, serverConn, closeFn := newSimnetLink(t, rtt) defer closeFn() var eventRecorder events.Recorder tr := &Transport{ Conn: serverConn, ConnectionIDLength: 4, StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}, Tracer: &eventRecorder, } tr.init(true) defer tr.Close() connID := protocol.ParseConnectionID([]byte{9, 10, 11, 12}) // now send a packet with a connection ID that doesn't exist b, err := wire.AppendShortHeader(nil, connID, 1337, 2, protocol.KeyPhaseOne) require.NoError(t, err) // no stateless reset sent for 
packets smaller than MinStatelessResetSize smallPacket := append(b, make([]byte, protocol.MinStatelessResetSize-len(b))...) _, err = clientConn.WriteTo(smallPacket, tr.Conn.LocalAddr()) require.NoError(t, err) time.Sleep(rtt) // so that the packet arrives at the server require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Header: qlog.PacketHeader{PacketType: qlog.PacketType1RTT}, Raw: qlog.RawInfo{Length: len(smallPacket)}, Trigger: qlog.PacketDropUnknownConnectionID, }, }, eventRecorder.Events(qlog.PacketDropped{}), ) // but a stateless reset is sent for packets larger than MinStatelessResetSize _, err = clientConn.WriteTo(append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...), tr.Conn.LocalAddr()) require.NoError(t, err) clientConn.SetReadDeadline(time.Now().Add(time.Second)) p := make([]byte, 1024) n, addr, err := clientConn.ReadFrom(p) require.NoError(t, err) require.Equal(t, addr, tr.Conn.LocalAddr()) srt := newStatelessResetter(tr.StatelessResetKey).GetStatelessResetToken(connID) require.Contains(t, string(p[:n]), string(srt[:])) }) } func TestTransportUnparseableQUICPackets(t *testing.T) { synctest.Test(t, func(t *testing.T) { const rtt = 10 * time.Millisecond clientConn, serverConn, closeFn := newSimnetLink(t, rtt) defer closeFn() var eventRecorder events.Recorder tr := &Transport{ Conn: serverConn, ConnectionIDLength: 10, Tracer: &eventRecorder, } require.NoError(t, tr.init(true)) defer tr.Close() _, err := clientConn.WriteTo([]byte{0x40 /* set the QUIC bit */, 1, 2, 3}, tr.Conn.LocalAddr()) require.NoError(t, err) time.Sleep(rtt) // so that the packet arrives at the server require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Raw: qlog.RawInfo{Length: 4}, Trigger: qlog.PacketDropHeaderParseError, }, }, eventRecorder.Events(qlog.PacketDropped{}), ) }) } func TestTransportListening(t *testing.T) { synctest.Test(t, func(t *testing.T) { const rtt = 10 * time.Millisecond clientConn, serverConn, closeFn := newSimnetLink(t, rtt) defer 
closeFn() var eventRecorder events.Recorder tr := &Transport{ Conn: serverConn, ConnectionIDLength: 5, Tracer: &eventRecorder, } require.NoError(t, tr.init(true)) defer tr.Close() data := wire.ComposeVersionNegotiation([]byte{1, 2, 3, 4, 5}, []byte{6, 7, 8, 9, 10}, []protocol.Version{protocol.Version1}) _, err := clientConn.WriteTo(data, tr.Conn.LocalAddr()) require.NoError(t, err) time.Sleep(rtt) // so that the packet arrives at the server require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Raw: qlog.RawInfo{Length: len(data)}, Trigger: qlog.PacketDropUnknownConnectionID, }, }, eventRecorder.Events(qlog.PacketDropped{}), ) eventRecorder.Clear() ln, err := tr.Listen(&tls.Config{}, nil) require.NoError(t, err) _, err = clientConn.WriteTo(data, tr.Conn.LocalAddr()) require.NoError(t, err) time.Sleep(rtt) // so that the packet arrives at the server require.Equal(t, []qlogwriter.Event{ qlog.PacketDropped{ Header: qlog.PacketHeader{PacketType: qlog.PacketTypeVersionNegotiation}, Raw: qlog.RawInfo{Length: len(data)}, Trigger: qlog.PacketDropUnexpectedPacket, }, }, eventRecorder.Events(qlog.PacketDropped{}), ) // only a single listener can be set _, err = tr.Listen(&tls.Config{}, nil) require.Error(t, err) require.ErrorIs(t, err, errListenerAlreadySet) require.NoError(t, ln.Close()) // now it's possible to add a new listener ln, err = tr.Listen(&tls.Config{}, nil) require.NoError(t, err) defer ln.Close() }) } func TestTransportNonQUICPackets(t *testing.T) { synctest.Test(t, func(t *testing.T) { const rtt = 10 * time.Millisecond clientConn, serverConn, closeFn := newSimnetLink(t, rtt) defer closeFn() tr := &Transport{Conn: serverConn} defer tr.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) defer cancel() _, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 1024)) require.Error(t, err) require.ErrorIs(t, err, context.DeadlineExceeded) data := []byte{0 /* don't set the QUIC bit */, 1, 2, 3} _, err = clientConn.WriteTo(data, 
tr.Conn.LocalAddr()) require.NoError(t, err) _, err = clientConn.WriteTo(data, tr.Conn.LocalAddr()) require.NoError(t, err) ctx, cancel = context.WithTimeout(context.Background(), time.Second) defer cancel() b := make([]byte, 1024) n, addr, err := tr.ReadNonQUICPacket(ctx, b) require.NoError(t, err) require.Equal(t, data, b[:n]) require.Equal(t, addr, clientConn.LocalAddr()) // now send a lot of packets without reading them for i := range 2 * maxQueuedNonQUICPackets { data := append([]byte{0 /* don't set the QUIC bit */, uint8(i)}, bytes.Repeat([]byte{uint8(i)}, 1000)...) _, err = clientConn.WriteTo(data, tr.Conn.LocalAddr()) require.NoError(t, err) } time.Sleep(rtt) // so that all packets arrive at the server var received int for { ctx, cancel = context.WithTimeout(context.Background(), 20*time.Millisecond) defer cancel() _, _, err := tr.ReadNonQUICPacket(ctx, b) if errors.Is(err, context.DeadlineExceeded) { break } require.NoError(t, err) received++ } require.Equal(t, received, maxQueuedNonQUICPackets) }) } type faultySyscallConn struct{ net.PacketConn } func (c *faultySyscallConn) SyscallConn() (syscall.RawConn, error) { return nil, assert.AnError } func TestTransportFaultySyscallConn(t *testing.T) { syscallconn := &faultySyscallConn{PacketConn: newUDPConnLocalhost(t)} tr := &Transport{Conn: syscallconn} _, err := tr.Listen(&tls.Config{}, nil) require.Error(t, err) require.ErrorIs(t, err, assert.AnError) } func TestTransportSetTLSConfigServerName(t *testing.T) { for _, tt := range []struct { name string expected string conf *tls.Config host string }{ { name: "uses the value from the config", expected: "foo.bar", conf: &tls.Config{ServerName: "foo.bar"}, host: "baz.foo", }, { name: "uses the hostname", expected: "golang.org", conf: &tls.Config{}, host: "golang.org", }, { name: "removes the port from the hostname", expected: "golang.org", conf: &tls.Config{}, host: "golang.org:1234", }, { name: "uses the IP", expected: "1.3.5.7", conf: &tls.Config{}, host: "", }, 
} { t.Run(tt.name, func(t *testing.T) { setTLSConfigServerName(tt.conf, &net.UDPAddr{IP: net.IPv4(1, 3, 5, 7), Port: 1234}, tt.host) require.Equal(t, tt.expected, tt.conf.ServerName) }) } } func TestTransportDial(t *testing.T) { t.Run("regular", func(t *testing.T) { testTransportDial(t, false) }) t.Run("early", func(t *testing.T) { testTransportDial(t, true) }) } func testTransportDial(t *testing.T, early bool) { originalClientConnConstructor := newClientConnection t.Cleanup(func() { newClientConnection = originalClientConnConstructor }) synctest.Test(t, func(t *testing.T) { _, serverConn, closeFn := newSimnetLink(t, 10*time.Millisecond) defer closeFn() var conn *connTestHooks handshakeChan := make(chan struct{}) blockRun := make(chan struct{}) if early { conn = &connTestHooks{ earlyConnReady: func() <-chan struct{} { return handshakeChan }, handshakeComplete: func() <-chan struct{} { return make(chan struct{}) }, } } else { conn = &connTestHooks{ handshakeComplete: func() <-chan struct{} { return handshakeChan }, } } conn.run = func() error { <-blockRun; return errors.New("done") } defer close(blockRun) newClientConnection = func( _ context.Context, _ sendConn, _ connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, _ *statelessResetter, _ *Config, _ *tls.Config, _ protocol.PacketNumber, _ bool, _ bool, _ qlogwriter.Trace, _ utils.Logger, _ protocol.Version, ) *wrappedConn { return &wrappedConn{testHooks: conn} } tr := &Transport{Conn: serverConn} tr.init(true) defer tr.Close() errChan := make(chan error, 1) go func() { var err error if early { _, err = tr.DialEarly(context.Background(), nil, &tls.Config{}, nil) } else { _, err = tr.Dial(context.Background(), nil, &tls.Config{}, nil) } errChan <- err }() synctest.Wait() select { case <-errChan: t.Fatal("Dial shouldn't have returned") default: } close(handshakeChan) synctest.Wait() select { case err := <-errChan: require.NoError(t, err) default: } }) } func 
TestTransportDialingVersionNegotiation(t *testing.T) { originalClientConnConstructor := newClientConnection t.Cleanup(func() { newClientConnection = originalClientConnConstructor }) conn := &connTestHooks{ handshakeComplete: func() <-chan struct{} { return make(chan struct{}) }, run: func() error { return &errCloseForRecreating{nextPacketNumber: 109, nextVersion: 789} }, } conn2 := &connTestHooks{ handshakeComplete: func() <-chan struct{} { return make(chan struct{}) }, run: func() error { return assert.AnError }, } type connParams struct { pn protocol.PacketNumber hasNegotiatedVersion bool version protocol.Version } connChan := make(chan connParams, 2) var counter int newClientConnection = func( _ context.Context, _ sendConn, _ connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, _ ConnectionIDGenerator, _ *statelessResetter, _ *Config, _ *tls.Config, pn protocol.PacketNumber, _ bool, hasNegotiatedVersion bool, _ qlogwriter.Trace, _ utils.Logger, v protocol.Version, ) *wrappedConn { connChan <- connParams{pn: pn, hasNegotiatedVersion: hasNegotiatedVersion, version: v} if counter == 0 { counter++ return &wrappedConn{testHooks: conn} } return &wrappedConn{testHooks: conn2} } tr := &Transport{Conn: newUDPConnLocalhost(t)} tr.init(true) defer tr.Close() _, err := tr.Dial(context.Background(), nil, &tls.Config{}, nil) require.ErrorIs(t, err, assert.AnError) select { case params := <-connChan: require.Zero(t, params.pn) require.False(t, params.hasNegotiatedVersion) require.Equal(t, protocol.Version1, params.version) case <-time.After(time.Second): t.Fatal("timeout") } select { case params := <-connChan: require.Equal(t, protocol.PacketNumber(109), params.pn) require.True(t, params.hasNegotiatedVersion) require.Equal(t, protocol.Version(789), params.version) case <-time.After(time.Second): t.Fatal("timeout") } } func TestTransportReplaceWithClosed(t *testing.T) { // synctest works slightly differently on Go 1.24, // so we skip the test if 
strings.HasPrefix(runtime.Version(), "go1.24") { t.Skip("skipping on Go 1.24 due to synctest issues") } t.Run("local", func(t *testing.T) { testTransportReplaceWithClosed(t, true) }) t.Run("remote", func(t *testing.T) { testTransportReplaceWithClosed(t, false) }) } func testTransportReplaceWithClosed(t *testing.T, local bool) { synctest.Test(t, func(t *testing.T) { clientConn, serverConn, closeFn := newSimnetLink(t, 10*time.Millisecond) defer closeFn() srk := StatelessResetKey{1, 2, 3, 4} tr := &Transport{ Conn: serverConn, ConnectionIDLength: 4, StatelessResetKey: &srk, } tr.init(true) defer tr.Close() var closePacket []byte if local { closePacket = []byte("foobar") } const expiry = 50 * time.Millisecond handler := &mockPacketHandler{} connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) m := (*packetHandlerMap)(tr) require.True(t, m.Add(connID, handler)) m.ReplaceWithClosed([]protocol.ConnectionID{connID}, closePacket, expiry) p := make([]byte, 100) p[0] = 0x40 // QUIC bit copy(p[1:], connID.Bytes()) var sent atomic.Int64 errChan := make(chan error, 1) stopSending := make(chan struct{}) go func() { defer close(errChan) ticker := time.NewTicker(expiry / 200) timeout := time.NewTimer(time.Second) for { select { case <-stopSending: return case <-timeout.C: errChan <- errors.New("timeout") return case <-ticker.C: } if _, err := clientConn.WriteTo(p, tr.Conn.LocalAddr()); err != nil { errChan <- err return } sent.Add(1) } }() // For locally closed connections, CONNECTION_CLOSE packets are sent with an exponential backoff var received int clientConn.SetReadDeadline(time.Now().Add(time.Hour)) for { b := make([]byte, 100) n, _, err := clientConn.ReadFrom(b) require.NoError(t, err) // at some point, the connection is cleaned up, and we'll receive a stateless reset if !bytes.Equal(b[:n], []byte("foobar")) { require.GreaterOrEqual(t, n, protocol.MinStatelessResetSize) close(stopSending) // stop sending packets break } received++ } select { case err := <-errChan: 
require.NoError(t, err) case <-time.After(time.Second): t.Fatal("timeout") } numSent := sent.Load() if !local { require.Zero(t, received) t.Logf("sent %d packets", numSent) return } t.Logf("sent %d packets, received %d CONNECTION_CLOSE copies", numSent, received) require.Equal(t, int(math.Ceil(math.Log2(float64(numSent)))), received) }) }