openucx-ucc-ec0bc8a/0000775000175000017500000000000015133731560014702 5ustar alastairalastairopenucx-ucc-ec0bc8a/.clang-tidy0000664000175000017500000000056315133731560016742 0ustar alastairalastair--- Checks: > -*, clang-analyzer-*, -clang-analyzer-security.insecureAPI.DeprecatedOrUnsafeBufferHandling, -clang-analyzer-osx.*, -clang-analyzer-optin.osx.*, misc-redundant-expression, misc-static-assert, misc-unused-parameters, readability-redundant-control-flow WarningsAsErrors: '*' HeaderFilterRegex: '' FormatStyle: none ... openucx-ucc-ec0bc8a/tools/0000775000175000017500000000000015133731560016042 5ustar alastairalastairopenucx-ucc-ec0bc8a/tools/info/0000775000175000017500000000000015133731560016775 5ustar alastairalastairopenucx-ucc-ec0bc8a/tools/info/build_info.c0000664000175000017500000000137015133731560021254 0ustar alastairalastair/** * Copyright (c) 2001-2020, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #include "config.h" #include "ucc_info.h" #include "utils/ucc_compiler_def.h" void print_version() { printf("# UCC version=%s revision %s\n", UCC_VERSION_STRING, UCC_GIT_REVISION); printf("# Configured with: %s\n", UCC_CONFIGURE_FLAGS); } void print_build_config() { typedef struct { const char *name; const char *value; } config_var_t; static config_var_t config_vars[] = { #include {NULL, NULL} }; config_var_t *var; for (var = config_vars; var->name != NULL; ++var) { printf("#define %-25s %s\n", var->name, var->value); } } openucx-ucc-ec0bc8a/tools/info/ucc_info.h0000664000175000017500000000051515133731560020734 0ustar alastairalastair/** * Copyright (c) 2001-2020, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #ifndef UCC_INFO_H #define UCC_INFO_H #include "ucc/api/ucc.h" enum { PRINT_VERSION = UCC_BIT(0), PRINT_BUILD_CONFIG = UCC_BIT(1), }; void print_version(); void print_build_config(); #endif openucx-ucc-ec0bc8a/tools/info/ucc_info.c0000664000175000017500000001302415133731560020726 0ustar alastairalastair/** * Copyright (c) 2001-2020, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #include "config.h" #include "ucc_info.h" #include "core/ucc_global_opts.h" #include "utils/ucc_parser.h" #include "utils/ucc_log.h" #include "utils/ucc_datastruct.h" #include "components/tl/ucc_tl.h" #include "components/cl/ucc_cl.h" #include #include static void usage() { printf("Usage: ucc_info [options]\n"); printf("At least one of the following options has to be set:\n"); printf(" -v Show version information\n"); printf(" -b Show build configuration\n"); printf(" -c Show UCC configuration\n"); printf(" -a Show also hidden configuration\n"); printf(" -f Show fully decorated output\n"); printf(" -s Show default components scores\n"); printf(" -A Show collective algorithms available for selection\n"); printf(" -h Show this help message\n"); printf("\n"); } extern ucc_list_link_t ucc_config_global_list; static void print_algorithm_info(ucc_base_coll_alg_info_t *info) { while (info->name) { printf(" %u : %16s : %s\n", info->id, info->name, info->desc); info++; } } static void print_component_algs(ucc_base_coll_alg_info_t **alg_info, const char *component, const char *component_name) { int have_algs = 0; int i; for (i = 0; i < UCC_COLL_TYPE_NUM; i++) { if (alg_info[i]) { have_algs = 1; break; } } if (have_algs) { printf("%s/%s algorithms:\n", component, component_name); for (i = 0; i < UCC_COLL_TYPE_NUM; i++) { if (alg_info[i]) { printf(" %s\n", ucc_coll_type_str((ucc_coll_type_t)UCC_BIT(i))); print_algorithm_info(alg_info[i]); } } printf("\n"); } } int main(int argc, char **argv) { ucc_global_config_t *cfg = &ucc_global_config; 
ucc_config_print_flags_t print_flags; unsigned print_opts; int c, show_scores, show_algs; ucc_lib_h lib; ucc_lib_config_h config; ucc_lib_params_t params; ucc_status_t status; ucc_tl_iface_t * tl; ucc_cl_iface_t * cl; print_flags = (ucc_config_print_flags_t)0; print_opts = 0; show_scores = 0; show_algs = 0; while ((c = getopt(argc, argv, "vbcafhsA")) != -1) { switch (c) { case 'f': print_flags |= (ucc_config_print_flags_t)(UCC_CONFIG_PRINT_CONFIG | UCC_CONFIG_PRINT_HEADER | UCC_CONFIG_PRINT_DOC); break; case 'c': print_flags |= (ucc_config_print_flags_t)UCC_CONFIG_PRINT_CONFIG; break; case 'a': print_flags |= (ucc_config_print_flags_t)UCC_CONFIG_PRINT_HIDDEN; break; case 'v': print_opts |= PRINT_VERSION; break; case 'b': print_opts |= PRINT_BUILD_CONFIG; break; case 's': show_scores = 1; break; case 'A': show_algs = 1; break; case 'h': usage(); return 0; default: usage(); return -1; } } if ((print_opts == 0) && (print_flags == 0) && (!show_scores) && (!show_algs)) { usage(); return -2; } /* need to call ucc_init to force loading of dynamic ucc components */ params.mask = UCC_LIB_PARAM_FIELD_THREAD_MODE; params.thread_mode = UCC_THREAD_SINGLE; if (UCC_OK != ucc_lib_config_read(NULL, NULL, &config)) { return 0; } status = ucc_init(¶ms, config, &lib); ucc_lib_config_release(config); if (UCC_OK != status) { return 0; } if (print_opts & PRINT_VERSION) { print_version(); } if (print_opts & PRINT_BUILD_CONFIG) { print_build_config(); } if (print_flags & UCC_CONFIG_PRINT_CONFIG) { ucc_config_parser_print_all_opts(stdout, "UCC_", print_flags, &ucc_config_global_list); } if (show_scores) { if (cfg->cl_framework.n_components) { printf("Default CLs scores:"); for (c = 0; c < cfg->cl_framework.n_components; c++) { printf(" %s=%d", cfg->cl_framework.components[c]->name, cfg->cl_framework.components[c]->score); } printf("\n"); } if (cfg->tl_framework.n_components) { printf("Default TLs scores:"); for (c = 0; c < cfg->tl_framework.n_components; c++) { printf(" %s=%d", 
cfg->tl_framework.components[c]->name, cfg->tl_framework.components[c]->score); } printf("\n"); } } if (show_algs) { for (c = 0; c < cfg->cl_framework.n_components; c++) { cl = ucc_derived_of(cfg->cl_framework.components[c], ucc_cl_iface_t); print_component_algs(cl->alg_info, "cl", cl->super.name); } for (c = 0; c < cfg->tl_framework.n_components; c++) { tl = ucc_derived_of(cfg->tl_framework.components[c], ucc_tl_iface_t); print_component_algs(tl->alg_info, "tl", tl->super.name); } } ucc_finalize(lib); return 0; } openucx-ucc-ec0bc8a/tools/info/Makefile.am0000664000175000017500000000164115133731560021033 0ustar alastairalastair# # Copyright (c) 2020, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (C) The University of Tennessee and the University of Tennessee Research Foundation. 2016. ALL RIGHTS RESERVED. # # See file LICENSE for terms. # bin_PROGRAMS = ucc_info BUILT_SOURCES = build_config.h DISTCLEANFILES = build_config.h # # Produce a C header file which contains all defined variables from config.h # build_config.h: $(top_builddir)/config.h Makefile $(SED) -nr 's:\s*#define\s+(\w+)(\s+(\w+)|\s+(".*")|\s*)$$:{"\1", UCC_PP_MAKE_STRING(\3\4)},:p' <$(top_builddir)/config.h >$@ ucc_info_SOURCES = \ build_info.c \ ucc_info.c noinst_HEADERS = \ ucc_info.h nodist_ucc_info_SOURCES = \ build_config.h ucc_info_CPPFLAGS = $(AM_CPPFLAGS) $(BASE_CPPFLAGS) -I${UCC_TOP_BUILDDIR}/src ucc_info_CFLAGS = $(BASE_CFLAGS) ucc_info_LDFLAGS = -Wl,--rpath-link=${UCS_LIBDIR} ucc_info_LDADD = \ $(UCC_TOP_BUILDDIR)/src/libucc.la openucx-ucc-ec0bc8a/tools/perf/0000775000175000017500000000000015133731560016776 5ustar alastairalastairopenucx-ucc-ec0bc8a/tools/perf/ucc_pt_op_reduce_strided.cc0000664000175000017500000000544715133731560024337 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include "ucc_pt_coll.h" #include "ucc_perftest.h" #include #include #include ucc_pt_op_reduce_strided::ucc_pt_op_reduce_strided(ucc_datatype_t dt, ucc_memory_type mt, ucc_reduction_op_t op, int nbufs, ucc_pt_comm *communicator, ucc_pt_generator_base *generator) : ucc_pt_coll(communicator, generator) { has_inplace_ = false; has_reduction_ = true; has_range_ = true; has_bw_ = true; if (nbufs == UCC_PT_DEFAULT_N_BUFS) { nbufs = 2; } if (nbufs < 2) { throw std::runtime_error("dt reduce op requires at least 2 bufs"); } data_type = dt; mem_type = mt; reduce_op = op; num_bufs = nbufs; } ucc_status_t ucc_pt_op_reduce_strided::init_args(ucc_pt_test_args_t &test_args) { ucc_ee_executor_task_args_t &args = test_args.executor_args; size_t dt_size = ucc_dt_size(data_type); size_t size = generator->get_src_count() * dt_size; size_t stride = generator->get_src_count() * dt_size; ucc_status_t st; UCCCHECK_GOTO(ucc_pt_alloc(&dst_header, size, mem_type), exit, st); UCCCHECK_GOTO(ucc_pt_alloc(&src_header, size * num_bufs, mem_type), free_dst, st); args.task_type = UCC_EE_EXECUTOR_TASK_REDUCE_STRIDED; args.reduce_strided.dst = dst_header->addr; args.reduce_strided.src1 = src_header->addr; args.reduce_strided.src2 = PTR_OFFSET(src_header->addr, size); args.reduce_strided.n_src2 = num_bufs - 1; args.reduce_strided.stride = stride; args.reduce_strided.count = generator->get_src_count(); args.reduce_strided.dt = data_type; args.reduce_strided.op = reduce_op; args.flags = 0; return UCC_OK; free_dst: ucc_pt_free(dst_header); exit: return st; } float ucc_pt_op_reduce_strided::get_bw(float time_ms, int grsize, ucc_pt_test_args_t test_args) { ucc_ee_executor_task_args_t &args = test_args.executor_args; float S = args.reduce_strided.count * ucc_dt_size(data_type); return (num_bufs + 1) * (S / time_ms) / 1000.0; } void ucc_pt_op_reduce_strided::free_args(ucc_pt_test_args_t &test_args) { ucc_pt_free(src_header); ucc_pt_free(dst_header); } 
openucx-ucc-ec0bc8a/tools/perf/ucc_pt_coll_scatterv.cc0000664000175000017500000000623215133731560023511 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #include "ucc_pt_coll.h" #include "ucc_perftest.h" #include #include #include ucc_pt_coll_scatterv::ucc_pt_coll_scatterv(ucc_datatype_t dt, ucc_memory_type mt, bool is_inplace, bool is_persistent, int root_shift, ucc_pt_comm *communicator, ucc_pt_generator_base *generator) : ucc_pt_coll(communicator, generator) { has_inplace_ = true; has_reduction_ = false; has_range_ = true; has_bw_ = false; root_shift_ = root_shift; coll_args.mask = 0; coll_args.flags = 0; coll_args.coll_type = UCC_COLL_TYPE_SCATTERV; coll_args.src.info_v.datatype = dt; coll_args.src.info_v.mem_type = mt; coll_args.dst.info.datatype = dt; coll_args.dst.info.mem_type = mt; if (is_inplace) { coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } if (is_persistent) { coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; } } ucc_status_t ucc_pt_coll_scatterv::init_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; size_t dt_size = ucc_dt_size(coll_args.dst.info.datatype); ucc_status_t st; bool is_root; coll_args.root = test_args.coll_args.root; args = coll_args; is_root = (comm->get_rank() == args.root); if (is_root || root_shift_) { args.src.info_v.counts = generator->get_src_counts(); args.src.info_v.displacements = generator->get_src_displs(); UCCCHECK_GOTO( ucc_pt_alloc(&src_header, generator->get_src_count() * dt_size, args.src.info_v.mem_type), exit, st); args.src.info_v.buffer = src_header->addr; } if (!is_root || !UCC_IS_INPLACE(args) || root_shift_) { args.dst.info.count = generator->get_dst_count(); st = ucc_pt_alloc(&dst_header, generator->get_dst_count() * dt_size, args.dst.info.mem_type); if (UCC_OK != st) { std::cerr << "UCC perftest 
error: " << ucc_status_string(st) << " in " << STR(_call) << "\n"; if (is_root || root_shift_) { goto free_src; } else { goto exit; } } args.dst.info.buffer = dst_header->addr; } return UCC_OK; free_src: ucc_pt_free(src_header); exit: return st; } void ucc_pt_coll_scatterv::free_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; bool is_root = (comm->get_rank() == args.root); if (!is_root || !UCC_IS_INPLACE(args) || root_shift_) { ucc_pt_free(dst_header); } if (is_root || root_shift_) { ucc_pt_free(src_header); } } openucx-ucc-ec0bc8a/tools/perf/ucc_pt_config.cc0000664000175000017500000004725615133731560022125 0ustar alastairalastair/** * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #include "ucc_pt_config.h" BEGIN_C_DECLS #include "utils/ucc_string.h" #include END_C_DECLS #include ucc_pt_config::ucc_pt_config() { bootstrap.bootstrap = UCC_PT_BOOTSTRAP_MPI; bench.op_type = UCC_PT_OP_TYPE_ALLREDUCE; bench.min_count = 128; bench.max_count = 128; bench.dt = UCC_DT_FLOAT32; bench.mt = UCC_MEMORY_TYPE_HOST; bench.op = UCC_OP_SUM; bench.inplace = false; bench.persistent = false; bench.map_type = UCC_PT_MAP_TYPE_NONE; bench.triggered = false; bench.n_iter_small = 1000; bench.n_warmup_small = 100; bench.n_iter_large = 200; bench.n_warmup_large = 20; bench.large_thresh = 64 * 1024; bench.full_print = false; bench.n_bufs = UCC_PT_DEFAULT_N_BUFS; bench.root = 0; bench.root_shift = 0; bench.mult_factor = 2; bench.seed = 0; comm.mt = bench.mt; } const std::map ucc_pt_reduction_op_map = { {"sum", UCC_OP_SUM}, {"prod", UCC_OP_PROD}, {"min", UCC_OP_MIN}, {"max", UCC_OP_MAX}, {"avg", UCC_OP_AVG}, }; const std::map ucc_pt_op_map = { {"allgather", UCC_PT_OP_TYPE_ALLGATHER}, {"allgatherv", UCC_PT_OP_TYPE_ALLGATHERV}, {"allreduce", UCC_PT_OP_TYPE_ALLREDUCE}, {"alltoall", UCC_PT_OP_TYPE_ALLTOALL}, {"alltoallv", UCC_PT_OP_TYPE_ALLTOALLV}, {"barrier", UCC_PT_OP_TYPE_BARRIER}, {"bcast", 
UCC_PT_OP_TYPE_BCAST}, {"gather", UCC_PT_OP_TYPE_GATHER}, {"gatherv", UCC_PT_OP_TYPE_GATHERV}, {"reduce", UCC_PT_OP_TYPE_REDUCE}, {"reduce_scatter", UCC_PT_OP_TYPE_REDUCE_SCATTER}, {"reduce_scatterv", UCC_PT_OP_TYPE_REDUCE_SCATTERV}, {"scatter", UCC_PT_OP_TYPE_SCATTER}, {"scatterv", UCC_PT_OP_TYPE_SCATTERV}, {"memcpy", UCC_PT_OP_TYPE_MEMCPY}, {"reducedt", UCC_PT_OP_TYPE_REDUCEDT}, {"reducedt_strided", UCC_PT_OP_TYPE_REDUCEDT_STRIDED}, }; const std::map ucc_pt_memtype_map = { {"host", UCC_MEMORY_TYPE_HOST}, {"cuda", UCC_MEMORY_TYPE_CUDA}, {"rocm", UCC_MEMORY_TYPE_ROCM}, {"cuda-mng", UCC_MEMORY_TYPE_CUDA_MANAGED}, }; const std::map ucc_pt_datatype_map = { {"int8", UCC_DT_INT8}, {"uint8", UCC_DT_UINT8}, {"int16", UCC_DT_INT16}, {"uint16", UCC_DT_UINT16}, {"float16", UCC_DT_FLOAT16}, {"bfloat16", UCC_DT_BFLOAT16}, {"int32", UCC_DT_INT32}, {"uint32", UCC_DT_UINT32}, {"float32", UCC_DT_FLOAT32}, {"float32_complex", UCC_DT_FLOAT32_COMPLEX}, {"int64", UCC_DT_INT64}, {"uint64", UCC_DT_UINT64}, {"float64", UCC_DT_FLOAT64}, {"float64_complex", UCC_DT_FLOAT64_COMPLEX}, {"int128", UCC_DT_INT128}, {"uint128", UCC_DT_UINT128}, {"float128", UCC_DT_FLOAT64}, {"float128_complex", UCC_DT_FLOAT128_COMPLEX}, }; const std::map ucc_pt_map_type_map = { {"none", UCC_PT_MAP_TYPE_NONE}, {"local", UCC_PT_MAP_TYPE_LOCAL}, {"global", UCC_PT_MAP_TYPE_GLOBAL}, }; ucc_status_t ucc_pt_config::process_args(int argc, char *argv[]) { int c; ucc_status_t st; int option_index = 0; static struct option long_options[] = { {"gen", required_argument, 0, 0}, {"seed", required_argument, 0, 0}, {0, 0, 0, 0}}; // Reset getopt state optind = 1; while (1) { c = getopt_long(argc, argv, "c:b:e:d:f:m:n:w:o:N:r:S:M:iphFT", long_options, &option_index); if (c == -1) break; if (c == 0) { // long option if (strcmp(long_options[option_index].name, "gen") == 0) { std::string gen_arg(optarg); if (gen_arg.rfind("exp:", 0) == 0) { bench.gen.type = UCC_PT_GEN_TYPE_EXP; auto min_pos = gen_arg.find("min=", 4); if (min_pos == 
std::string::npos) { std::cerr << "Invalid format for --gen exp:min=N[@max=M]" << std::endl; return UCC_ERR_INVALID_PARAM; } auto at_pos = gen_arg.find("@", min_pos); if (at_pos != std::string::npos) { auto max_pos = gen_arg.find("max=", at_pos); if (max_pos == std::string::npos) { std::cerr << "Invalid format for --gen exp:min=N@max=M" << std::endl; return UCC_ERR_INVALID_PARAM; } try { ucc_status_t st = ucc_str_to_memunits( gen_arg .substr(min_pos + 4, at_pos - (min_pos + 4)) .c_str(), (void *)&bench.gen.exp.min); if (st != UCC_OK) { std::cerr << "Failed to parse min value" << std::endl; return st; } st = ucc_str_to_memunits( gen_arg.substr(max_pos + 4).c_str(), (void *)&bench.gen.exp.max); if (st != UCC_OK) { std::cerr << "Failed to parse max value" << std::endl; return st; } } catch (const std::exception& e) { std::cerr << "Invalid values in --gen exp:min=N@max=M" << std::endl; return UCC_ERR_INVALID_PARAM; } } else { try { ucc_status_t st = ucc_str_to_memunits( gen_arg.substr(min_pos + 4).c_str(), (void *)&bench.gen.exp.min); if (st != UCC_OK) { std::cerr << "Failed to parse min value" << std::endl; return st; } bench.gen.exp.max = bench.gen.exp.min; } catch (const std::exception& e) { std::cerr << "Invalid value in --gen exp:min=N" << std::endl; return UCC_ERR_INVALID_PARAM; } } bench.min_count = bench.gen.exp.min; bench.max_count = bench.gen.exp.max; } else if (gen_arg.rfind("file:", 0) == 0) { bench.gen.type = UCC_PT_GEN_TYPE_FILE; auto name_pos = gen_arg.find("name=", 5); if (name_pos == std::string::npos) { std::cerr << "Invalid format for --gen file:name=filename[@nrep=N]" << std::endl; return UCC_ERR_INVALID_PARAM; } auto at_pos = gen_arg.find("@", name_pos); if (at_pos != std::string::npos) { bench.gen.file_name = gen_arg.substr(name_pos + 5, at_pos - (name_pos + 5)); auto nrep_str = gen_arg.substr(at_pos + 1); if (nrep_str.rfind("nrep=", 0) == 0) { try { bench.gen.nrep = std::stoull(nrep_str.substr(5)); } catch (const std::exception& e) { std::cerr << 
"Invalid nrep value in --gen file:name=filename@nrep=N" << std::endl; return UCC_ERR_INVALID_PARAM; } } else { std::cerr << "Invalid format for --gen file:name=filename@nrep=N" << std::endl; return UCC_ERR_INVALID_PARAM; } } else { bench.gen.file_name = gen_arg.substr(name_pos + 5); bench.gen.nrep = 1; // Default value if nrep is not specified } } else if (gen_arg.rfind("matrix:", 0) == 0) { bench.gen.type = UCC_PT_GEN_TYPE_TRAFFIC_MATRIX; /* Defaults (all parameters optional) */ bench.gen.nrep = 1; bench.gen.matrix.kind = 0; bench.gen.matrix.token_size_KB_mean = 16; bench.gen.matrix .token_size_KB_std = bench.gen.matrix .token_size_KB_mean; bench.gen.matrix.num_tokens = 2048; bench.gen.matrix.tgt_group_size_mean = 8; auto find_param = [&](const std::string &key, std::string &out) -> bool { auto pos = gen_arg.find(key + "="); if (pos == std::string::npos) { return false; } pos += key.size() + 1; auto end = gen_arg.find("@", pos); if (end == std::string::npos) { end = gen_arg.size(); } out = gen_arg.substr(pos, end - pos); return true; }; std::string kind_str; if (find_param("kind", kind_str)) { std::string ks = kind_str; std::transform( ks.begin(), ks.end(), ks.begin(), ::tolower); if (ks == "0" || ks.find("normal") != std::string::npos) { bench.gen.matrix.kind = 0; } else if ( ks == "1" || ks.find("biased") != std::string::npos) { bench.gen.matrix.kind = 1; } else if ( ks == "2" || ks.find("random_tgt_group") != std::string::npos) { bench.gen.matrix.kind = 2; } else if ( ks == "3" || ks.find("random_tgt_group_random_msg_size") != std::string::npos) { bench.gen.matrix.kind = 3; } else { std::cerr << "Invalid kind value in --gen matrix: " "accepts 0,1,2,3 or names {normal, " "biased, random_tgt_group, " "random_tgt_group_random_msg_size}" << std::endl; return UCC_ERR_INVALID_PARAM; } } std::string nrep_str; if (find_param("nrep", nrep_str)) { try { bench.gen.nrep = std::stoull(nrep_str); } catch (const std::exception &) { std::cerr << "Invalid nrep value in --gen 
matrix" << std::endl; return UCC_ERR_INVALID_PARAM; } } std::string token_size_str; if (find_param("token_size", token_size_str)) { try { bench.gen.matrix.token_size_KB_mean = std::stoull( token_size_str); } catch (const std::exception &) { std::cerr << "Invalid token_size value in --gen matrix" << std::endl; return UCC_ERR_INVALID_PARAM; } bench.gen.matrix .token_size_KB_std = bench.gen.matrix .token_size_KB_mean; } std::string num_tokens_str; if (find_param("num_tokens", num_tokens_str)) { try { bench.gen.matrix.num_tokens = std::stoull( num_tokens_str); } catch (const std::exception &) { std::cerr << "Invalid num_tokens value in --gen matrix" << std::endl; return UCC_ERR_INVALID_PARAM; } } std::string tgt_group_size_str; if (find_param("tgt_group_size", tgt_group_size_str)) { try { bench.gen.matrix.tgt_group_size_mean = std::stoull( tgt_group_size_str); } catch (const std::exception &) { std::cerr << "Invalid tgt_group_size value in " "--gen matrix" << std::endl; return UCC_ERR_INVALID_PARAM; } } } else { std::cerr << "Invalid value for --gen. 
Use exp:min=N[@max=M] or " "file:name=filename[@nrep=N] or " "matrix:kind=mat_kind[@nrep=N@token_size=M@num_" "tokens=K]" << std::endl; return UCC_ERR_INVALID_PARAM; } } else if (strcmp(long_options[option_index].name, "seed") == 0) { std::string seed_str(optarg); try { bench.seed = std::stoull(seed_str); } catch (const std::exception &) { std::cerr << "Invalid seed value in --seed" << std::endl; return UCC_ERR_INVALID_PARAM; } } else { std::cerr << "Unknown long option" << std::endl; return UCC_ERR_INVALID_PARAM; } continue; } switch (c) { case 'c': if (ucc_pt_op_map.count(optarg) == 0) { std::cerr << "invalid operation: " << optarg << std::endl; return UCC_ERR_INVALID_PARAM; } bench.op_type = ucc_pt_op_map.at(optarg); break; case 'o': if (ucc_pt_reduction_op_map.count(optarg) == 0) { std::cerr << "invalid reduction operation: " << optarg << std::endl; return UCC_ERR_INVALID_PARAM; } bench.op = ucc_pt_reduction_op_map.at(optarg); break; case 'm': if (ucc_pt_memtype_map.count(optarg) == 0) { std::cerr << "invalid memory type: " << optarg <> bench.root; break; case 'S': std::stringstream(optarg) >> bench.root_shift; break; case 'n': std::stringstream(optarg) >> bench.n_iter_small; bench.n_iter_large = bench.n_iter_small; break; case 'w': std::stringstream(optarg) >> bench.n_warmup_small; bench.n_warmup_large = bench.n_warmup_small; break; case 'f': std::stringstream(optarg) >> bench.mult_factor; break; case 'N': std::stringstream(optarg) >> bench.n_bufs; break; case 'i': bench.inplace = true; break; case 'p': bench.persistent = true; break; case 'M': if (ucc_pt_map_type_map.count(optarg) == 0) { std::cerr << "invalid map type: " << optarg << std::endl; return UCC_ERR_INVALID_PARAM; } bench.map_type = ucc_pt_map_type_map.at(optarg); break; case 'T': bench.triggered = true; break; case 'F': bench.full_print = true; break; case 'h': default: print_help(); std::exit(0); } } return UCC_OK; } void ucc_pt_config::print_help() { std::cout << "Usage: ucc_perftest 
[options]"<: Collective type"<: Min number of elements"<: Max number of elements"<: datatype"<: reduction operation type"<: root for rooted collectives"<: memory type"<: number of iterations"<: number of warmup iterations"<: multiplication factor between sizes. Default : 2."<: number of buffers"<: root shift for rooted collectives"<: Pattern generator (exponential or file-based or matrix-based)" << std::endl; std::cout << " --seed : seed for the random distributions" << std::endl; std::cout << " -h: show this help message"< extern "C" { #include #include } ucc_status_t ucc_pt_alloc(ucc_mc_buffer_header_t **h_ptr, size_t len, ucc_memory_type_t mem_type); ucc_status_t ucc_pt_free(ucc_mc_buffer_header_t *h_ptr); typedef union { ucc_coll_args_t coll_args; ucc_ee_executor_task_args_t executor_args; } ucc_pt_test_args_t; class ucc_pt_coll { protected: bool has_inplace_; bool has_reduction_; bool has_range_; bool has_bw_; int root_shift_; ucc_pt_comm *comm; ucc_pt_generator_base *generator; ucc_coll_args_t coll_args; ucc_ee_executor_task_args_t executor_args; ucc_mc_buffer_header_t *dst_header; ucc_mc_buffer_header_t *src_header; ucc_mem_map_mem_h src_memh; ucc_mem_map_mem_h dst_memh; ucc_mem_map_mem_h *dst_memh_global; ucc_mem_map_mem_h *src_memh_global; public: ucc_pt_coll(ucc_pt_comm *communicator, ucc_pt_generator_base *generator) { this->comm = communicator; this->generator = generator; src_header = nullptr; dst_header = nullptr; src_memh = nullptr; dst_memh = nullptr; dst_memh_global = nullptr; src_memh_global = nullptr; } virtual ucc_status_t init_args(ucc_pt_test_args_t &args) = 0; virtual void free_args(ucc_pt_test_args_t &args) {} virtual float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) { return 0.0; } bool has_reduction(); bool has_inplace(); bool has_range(); bool has_bw(); virtual ~ucc_pt_coll() {}; }; class ucc_pt_coll_allgather: public ucc_pt_coll { public: ucc_pt_coll_allgather(ucc_datatype_t dt, ucc_memory_type mt, bool is_inplace, bool 
is_persistent, ucc_pt_map_type_t map_type, ucc_pt_comm *communicator, ucc_pt_generator_base *generator); ucc_status_t init_args(ucc_pt_test_args_t &args) override; float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override; ~ucc_pt_coll_allgather(); }; class ucc_pt_coll_allgatherv: public ucc_pt_coll { public: ucc_pt_coll_allgatherv(ucc_datatype_t dt, ucc_memory_type mt, bool is_inplace, bool is_persistent, ucc_pt_comm *communicator, ucc_pt_generator_base *generator); ucc_status_t init_args(ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; }; class ucc_pt_coll_allreduce: public ucc_pt_coll { public: ucc_pt_coll_allreduce(ucc_datatype_t dt, ucc_memory_type mt, ucc_reduction_op_t op, bool is_inplace, bool is_persistent, ucc_pt_comm *communicator, ucc_pt_generator_base *generator); ucc_status_t init_args(ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override; }; class ucc_pt_coll_alltoall: public ucc_pt_coll { public: ucc_pt_coll_alltoall(ucc_datatype_t dt, ucc_memory_type mt, bool is_inplace, bool is_persistent, ucc_pt_map_type_t map_type, ucc_pt_comm *communicator, ucc_pt_generator_base *generator); ucc_status_t init_args(ucc_pt_test_args_t &args) override; float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override; ~ucc_pt_coll_alltoall(); }; class ucc_pt_coll_alltoallv: public ucc_pt_coll { public: ucc_pt_coll_alltoallv(ucc_datatype_t dt, ucc_memory_type mt, bool is_inplace, bool is_persistent, ucc_pt_comm *communicator, ucc_pt_generator_base *generator); ucc_status_t init_args(ucc_pt_test_args_t &args) override; float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override; ~ucc_pt_coll_alltoallv(); }; class ucc_pt_coll_barrier: public ucc_pt_coll { public: ucc_pt_coll_barrier(ucc_pt_comm *communicator, ucc_pt_generator_base *generator); ucc_status_t init_args(ucc_pt_test_args_t 
&args) override; void free_args(ucc_pt_test_args_t &args) override; }; class ucc_pt_coll_bcast: public ucc_pt_coll { public: ucc_pt_coll_bcast(ucc_datatype_t dt, ucc_memory_type mt, int root_shift, bool is_persistent, ucc_pt_comm *communicator, ucc_pt_generator_base *generator); ucc_status_t init_args(ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override; }; class ucc_pt_coll_gather: public ucc_pt_coll { public: ucc_pt_coll_gather(ucc_datatype_t dt, ucc_memory_type mt, bool is_inplace, bool is_persistent, int root_shift, ucc_pt_comm *communicator, ucc_pt_generator_base *generator); ucc_status_t init_args(ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override; }; class ucc_pt_coll_gatherv: public ucc_pt_coll { public: ucc_pt_coll_gatherv(ucc_datatype_t dt, ucc_memory_type mt, bool is_inplace, bool is_persistent, int root_shift, ucc_pt_comm *communicator, ucc_pt_generator_base *generator); ucc_status_t init_args(ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; }; class ucc_pt_coll_reduce: public ucc_pt_coll { public: ucc_pt_coll_reduce(ucc_datatype_t dt, ucc_memory_type mt, ucc_reduction_op_t op, bool is_inplace, bool is_persistent, int root_shift, ucc_pt_comm *communicator, ucc_pt_generator_base *generator); ucc_status_t init_args(ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override; }; class ucc_pt_coll_reduce_scatter: public ucc_pt_coll { public: ucc_pt_coll_reduce_scatter(ucc_datatype_t dt, ucc_memory_type mt, ucc_reduction_op_t op, bool is_inplace, bool is_persistent, ucc_pt_comm *communicator, ucc_pt_generator_base *generator); ucc_status_t init_args(ucc_pt_test_args_t &args) override; void 
free_args(ucc_pt_test_args_t &args) override; float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override; }; class ucc_pt_coll_reduce_scatterv: public ucc_pt_coll { public: ucc_pt_coll_reduce_scatterv(ucc_datatype_t dt, ucc_memory_type mt, ucc_reduction_op_t op, bool is_inplace, bool is_persistent, ucc_pt_comm *communicator, ucc_pt_generator_base *generator); ucc_status_t init_args(ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; }; class ucc_pt_coll_scatter: public ucc_pt_coll { public: ucc_pt_coll_scatter(ucc_datatype_t dt, ucc_memory_type mt, bool is_inplace, bool is_persistent, int root_shift, ucc_pt_comm *communicator, ucc_pt_generator_base *generator); ucc_status_t init_args(ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override; }; class ucc_pt_coll_scatterv: public ucc_pt_coll { public: ucc_pt_coll_scatterv(ucc_datatype_t dt, ucc_memory_type mt, bool is_inplace, bool is_persistent, int root_shift, ucc_pt_comm *communicator, ucc_pt_generator_base *generator); ucc_status_t init_args(ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; }; class ucc_pt_op_memcpy: public ucc_pt_coll { ucc_memory_type_t mem_type; ucc_datatype_t data_type; int num_bufs; public: ucc_pt_op_memcpy(ucc_datatype_t dt, ucc_memory_type mt, int nbufs, ucc_pt_comm *communicator, ucc_pt_generator_base *generator); ucc_status_t init_args(ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override; }; class ucc_pt_op_reduce: public ucc_pt_coll { ucc_memory_type_t mem_type; ucc_datatype_t data_type; ucc_reduction_op_t reduce_op; int num_bufs; public: ucc_pt_op_reduce(ucc_datatype_t dt, ucc_memory_type mt, ucc_reduction_op_t op, int nbufs, ucc_pt_comm *communicator, ucc_pt_generator_base *generator); 
ucc_status_t init_args(ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override; }; class ucc_pt_op_reduce_strided: public ucc_pt_coll { ucc_memory_type_t mem_type; ucc_datatype_t data_type; ucc_reduction_op_t reduce_op; int num_bufs; public: ucc_pt_op_reduce_strided(ucc_datatype_t dt, ucc_memory_type mt, ucc_reduction_op_t op, int nbufs, ucc_pt_comm *communicator, ucc_pt_generator_base *generator); ucc_status_t init_args(ucc_pt_test_args_t &args) override; void free_args(ucc_pt_test_args_t &args) override; float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override; }; #endif openucx-ucc-ec0bc8a/tools/perf/ucc_pt_coll_reduce_scatter.cc0000664000175000017500000000641115133731560024651 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #include "ucc_pt_coll.h" #include "ucc_perftest.h" #include #include #include ucc_pt_coll_reduce_scatter::ucc_pt_coll_reduce_scatter(ucc_datatype_t dt, ucc_memory_type mt, ucc_reduction_op_t op, bool is_inplace, bool is_persistent, ucc_pt_comm *communicator, ucc_pt_generator_base *generator) : ucc_pt_coll(communicator, generator) { has_inplace_ = true; has_reduction_ = true; has_range_ = true; has_bw_ = true; root_shift_ = 0; coll_args.mask = 0; coll_args.flags = 0; coll_args.coll_type = UCC_COLL_TYPE_REDUCE_SCATTER; coll_args.op = op; coll_args.src.info.datatype = dt; coll_args.src.info.mem_type = mt; coll_args.dst.info.datatype = dt; coll_args.dst.info.mem_type = mt; if (is_inplace) { coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } if (is_persistent) { coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; } } ucc_status_t ucc_pt_coll_reduce_scatter::init_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; size_t 
dt_size = ucc_dt_size(coll_args.src.info.datatype); ucc_status_t st; args = coll_args; src_header = nullptr; dst_header = nullptr; if (UCC_IS_INPLACE(args)) { args.src.info.count = 0; args.dst.info.count = generator->get_dst_count(); } else { args.src.info.count = generator->get_src_count(); args.dst.info.count = generator->get_dst_count(); } UCCCHECK_GOTO(ucc_pt_alloc(&dst_header, generator->get_dst_count() * dt_size, args.dst.info.mem_type), exit, st); args.dst.info.buffer = dst_header->addr; if (args.src.info.count != 0) { UCCCHECK_GOTO(ucc_pt_alloc(&src_header, generator->get_src_count() * dt_size, args.src.info.mem_type), free_dst, st); args.src.info.buffer = src_header->addr; } return UCC_OK; free_dst: ucc_pt_free(dst_header); exit: return st; } void ucc_pt_coll_reduce_scatter::free_args(ucc_pt_test_args_t &test_args) { if (dst_header) { ucc_pt_free(dst_header); dst_header = nullptr; } if (src_header) { ucc_pt_free(src_header); src_header = nullptr; } } float ucc_pt_coll_reduce_scatter::get_bw(float time_ms, int grsize, ucc_pt_test_args_t test_args) { ucc_coll_args_t &args = test_args.coll_args; float N = grsize; size_t count = UCC_IS_INPLACE(args) ? args.dst.info.count : args.src.info.count; float S = count * ucc_dt_size(args.dst.info.datatype); return (S / time_ms) * ((N - 1) / N) / 1000.0; } openucx-ucc-ec0bc8a/tools/perf/ucc_pt_coll.cc0000664000175000017500000000517715133731560021605 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include "ucc_pt_coll.h" #include "ucc_pt_cuda.h" #include "utils/ucc_malloc.h" ucc_status_t ucc_pt_alloc(ucc_mc_buffer_header_t **h_ptr, size_t len, ucc_memory_type_t mem_type) { ucc_status_t status; int cuda_st; switch (mem_type) { case UCC_MEMORY_TYPE_CUDA: *h_ptr = new ucc_mc_buffer_header_t; (*h_ptr)->mt = UCC_MEMORY_TYPE_CUDA; cuda_st = ucc_pt_cudaMalloc(&((*h_ptr)->addr), len); if (cuda_st != 0) { return UCC_ERR_NO_MEMORY; } cuda_st = ucc_pt_cudaMemset((*h_ptr)->addr, 0, len); if (cuda_st != 0) { ucc_pt_cudaFree((*h_ptr)->addr); delete *h_ptr; return UCC_ERR_NO_MEMORY; } return UCC_OK; case UCC_MEMORY_TYPE_CUDA_MANAGED: *h_ptr = new ucc_mc_buffer_header_t; (*h_ptr)->mt = UCC_MEMORY_TYPE_CUDA_MANAGED; cuda_st = ucc_pt_cudaMallocManaged(&((*h_ptr)->addr), len); if (cuda_st != 0) { return UCC_ERR_NO_MEMORY; } cuda_st = ucc_pt_cudaMemset((*h_ptr)->addr, 0, len); if (cuda_st != 0) { ucc_pt_cudaFree((*h_ptr)->addr); delete *h_ptr; return UCC_ERR_NO_MEMORY; } return UCC_OK; case UCC_MEMORY_TYPE_HOST: *h_ptr = new ucc_mc_buffer_header_t; (*h_ptr)->mt = UCC_MEMORY_TYPE_HOST; (*h_ptr)->addr = ucc_malloc(len, "perftest data"); if (!((*h_ptr)->addr)) { return UCC_ERR_NO_MEMORY; } memset((*h_ptr)->addr, 0, len); return UCC_OK; default: break; } status = ucc_mc_alloc(h_ptr, len, mem_type); if (status != UCC_OK) { return status; } status = ucc_mc_memset((*h_ptr)->addr, 0, len, mem_type); if (status != UCC_OK) { ucc_mc_free(*h_ptr); return status; } return UCC_OK; } ucc_status_t ucc_pt_free(ucc_mc_buffer_header_t *h_ptr) { switch (h_ptr->mt) { case UCC_MEMORY_TYPE_CUDA: case UCC_MEMORY_TYPE_CUDA_MANAGED: ucc_pt_cudaFree(h_ptr->addr); delete h_ptr; return UCC_OK; case UCC_MEMORY_TYPE_HOST: ucc_free(h_ptr->addr); delete h_ptr; return UCC_OK; default: break; } return ucc_mc_free(h_ptr); } bool ucc_pt_coll::has_reduction() { return has_reduction_; } bool ucc_pt_coll::has_inplace() { return has_inplace_; } bool ucc_pt_coll::has_range() { return has_range_; } bool 
ucc_pt_coll::has_bw() { return has_bw_; } openucx-ucc-ec0bc8a/tools/perf/ucc_pt_coll_reduce_scatterv.cc0000664000175000017500000000647015133731560025044 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #include "ucc_pt_coll.h" #include "ucc_perftest.h" #include #include #include ucc_pt_coll_reduce_scatterv::ucc_pt_coll_reduce_scatterv(ucc_datatype_t dt, ucc_memory_type mt, ucc_reduction_op_t op, bool is_inplace, bool is_persistent, ucc_pt_comm *communicator, ucc_pt_generator_base *generator) : ucc_pt_coll(communicator, generator) { has_inplace_ = true; has_reduction_ = true; has_range_ = true; has_bw_ = false; root_shift_ = 0; coll_args.mask = 0; coll_args.flags = 0; coll_args.coll_type = UCC_COLL_TYPE_REDUCE_SCATTERV; coll_args.op = op; coll_args.src.info.datatype = dt; coll_args.src.info.mem_type = mt; coll_args.dst.info_v.datatype = dt; coll_args.dst.info_v.mem_type = mt; if (is_inplace) { coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } if (is_persistent) { coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; } } ucc_status_t ucc_pt_coll_reduce_scatterv::init_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; int tsize = comm->get_size(); size_t dt_size = ucc_dt_size(coll_args.dst.info_v.datatype); ucc_count_t *counts; ucc_aint_t *displs; ucc_status_t st; size_t size_src, size_dst; args = coll_args; src_header = nullptr; dst_header = nullptr; args.dst.info_v.counts = nullptr; args.dst.info_v.displacements = nullptr; if (UCC_IS_INPLACE(args)) { size_src = 0; size_dst = tsize * generator->get_src_count() * dt_size; } else { size_src = tsize * generator->get_src_count() * dt_size; size_dst = generator->get_dst_count() * dt_size; } counts = generator->get_dst_counts(); displs = generator->get_dst_displs(); UCCCHECK_GOTO(ucc_pt_alloc(&dst_header, size_dst, 
args.dst.info_v.mem_type), exit, st); args.dst.info_v.buffer = dst_header->addr; if (!UCC_IS_INPLACE(args)) { args.src.info.count = generator->get_src_count(); UCCCHECK_GOTO( ucc_pt_alloc(&src_header, size_src, args.src.info.mem_type), free_dst, st); args.src.info.buffer = src_header->addr; } args.dst.info_v.counts = counts; args.dst.info_v.displacements = displs; return UCC_OK; free_dst: ucc_pt_free(dst_header); exit: src_header = nullptr; dst_header = nullptr; args.dst.info_v.counts = nullptr; args.dst.info_v.displacements = nullptr; return st; } void ucc_pt_coll_reduce_scatterv::free_args(ucc_pt_test_args_t &test_args) { if (dst_header) { ucc_pt_free(dst_header); dst_header = nullptr; } if (src_header) { ucc_pt_free(src_header); src_header = nullptr; } } openucx-ucc-ec0bc8a/tools/perf/ucc_pt_bootstrap_mpi.h0000664000175000017500000000072715133731560023374 0ustar alastairalastair/** * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #ifndef UCC_PT_BOOTSTRAP_MPI_H #define UCC_PT_BOOTSTRAP_MPI_H #include #include "ucc_pt_bootstrap.h" class ucc_pt_bootstrap_mpi: public ucc_pt_bootstrap { public: ucc_pt_bootstrap_mpi(); ~ucc_pt_bootstrap_mpi(); int get_rank() override; int get_size() override; protected: int rank; int size; int ppn; }; #endif openucx-ucc-ec0bc8a/tools/perf/ucc_pt_rocm.cc0000664000175000017500000000232615133731560021605 0ustar alastairalastair/** * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (C) Advanced Micro Devices, Inc. 2022. ALL RIGHTS RESERVED. * * See file LICENSE for terms. */ #include "ucc_pt_rocm.h" #include #include #include ucc_pt_rocm_iface_t ucc_pt_rocm_iface = { .available = 0, }; #define LOAD_ROCM_SYM(_sym, _pt_sym) ({ \ void *h = dlsym(handle, _sym); \ if (dlerror() != NULL) { \ return; \ } \ ucc_pt_rocm_iface. 
_pt_sym = \ reinterpret_cast(h); \ }) void ucc_pt_rocm_init(void) { void *handle; handle = dlopen ("libamdhip64.so", RTLD_LAZY); if (!handle) { return; } LOAD_ROCM_SYM("hipGetDeviceCount", getDeviceCount); LOAD_ROCM_SYM("hipSetDevice", setDevice); LOAD_ROCM_SYM("hipGetErrorString", getErrorString); ucc_pt_rocm_iface.available = 1; } openucx-ucc-ec0bc8a/tools/perf/generator/0000775000175000017500000000000015133731560020764 5ustar alastairalastairopenucx-ucc-ec0bc8a/tools/perf/generator/ucc_pt_generator_file.cc0000664000175000017500000002277115133731560025626 0ustar alastairalastair#include "ucc_pt_generator.h" #include "utils/ini.h" #include #include #include #include ucc_pt_generator_file::ucc_pt_generator_file(const std::string &file_path, uint32_t gsize, uint32_t rank, ucc_pt_op_type_t type, size_t nrepeats) { input_file = file_path; comm_size = gsize; rank_id = rank; op_type = type; current_pattern = 0; current_rep = 0; nrep = nrepeats; // Open and validate the INI file FILE* file = fopen(file_path.c_str(), "r"); if (!file) { throw std::runtime_error("Failed to open pattern file: " + file_path); } struct counts_state_t { std::string counts_accum; bool in_counts = false; } counts_state; ini_handler handler = [](void* user, const char* section, const char* name, const char* value) -> int { auto* self = static_cast(user); auto* state = static_cast(self->counts_state_ptr); if (strcmp(section, "collective") == 0 && strcmp(name, "type") == 0) { if (state->in_counts && !state->counts_accum.empty()) { std::stringstream ss(state->counts_accum); std::string item; std::vector pattern; while (std::getline(ss, item, ',')) { item.erase(0, item.find_first_not_of(" \t\n\r")); item.erase(item.find_last_not_of(" \t\n\r") + 1); if (!item.empty()) { pattern.push_back(std::stoull(item)); } } self->pattern_counts.push_back(pattern); state->counts_accum.clear(); state->in_counts = false; } if (strcmp(value, "ALLTOALLV") != 0) { throw std::runtime_error("Unsupported collective type: " 
+ std::string(value) + ". Only ALLTOALLV is supported."); } } else if (strcmp(section, "collective") == 0 && strcmp(name, "counts") == 0) { std::string line = value; line.erase(0, line.find_first_not_of(" \t\n\r")); line.erase(line.find_last_not_of(" \t\n\r") + 1); size_t open = line.find('{'); if (open != std::string::npos) { state->in_counts = true; state->counts_accum.clear(); line = line.substr(open + 1); line.erase(0, line.find_first_not_of(" \t\n\r")); line.erase(line.find_last_not_of(" \t\n\r") + 1); } size_t close = line.find('}'); if (close != std::string::npos) { std::string before_brace = line.substr(0, close); before_brace.erase(0, before_brace.find_first_not_of(" \t\n\r")); before_brace.erase(before_brace.find_last_not_of(" \t\n\r") + 1); if (!before_brace.empty()) { if (!state->counts_accum.empty()) state->counts_accum += ","; state->counts_accum += before_brace; } if (state->in_counts && !state->counts_accum.empty()) { std::stringstream ss(state->counts_accum); std::string item; std::vector pattern; while (std::getline(ss, item, ',')) { item.erase(0, item.find_first_not_of(" \t\n\r")); item.erase(item.find_last_not_of(" \t\n\r") + 1); if (!item.empty()) { pattern.push_back(std::stoull(item)); } } self->pattern_counts.push_back(pattern); state->counts_accum.clear(); } state->in_counts = false; return 1; } if (!line.empty()) { if (!state->counts_accum.empty()) state->counts_accum += ","; state->counts_accum += line; } } return 1; }; this->counts_state_ptr = &counts_state; if (ucc_ini_parse_file(file, handler, this) < 0) { fclose(file); throw std::runtime_error("Failed to parse pattern file: " + file_path); } this->counts_state_ptr = nullptr; fclose(file); // After parsing, if there is any accumulated pattern, push it if (!counts_state.counts_accum.empty()) { std::stringstream ss(counts_state.counts_accum); std::string item; std::vector pattern; while (std::getline(ss, item, ',')) { item.erase(0, item.find_first_not_of(" \t\n\r")); 
item.erase(item.find_last_not_of(" \t\n\r") + 1); if (!item.empty()) { pattern.push_back(std::stoull(item)); } } pattern_counts.push_back(pattern); } if (pattern_counts.empty()) { throw std::runtime_error("No collective patterns found in file: " + file_path); } for (size_t i = 0; i < pattern_counts.size(); i++) { if (pattern_counts[i].size() != comm_size * comm_size) { throw std::runtime_error("Pattern size (" + std::to_string(pattern_counts[i].size()) + ") is not equal to comm_size*comm_size (" + std::to_string(comm_size * comm_size) + "). " "Please check the pattern file or the comm_size."); } } // Initialize arrays for counts and displacements src_counts.resize(comm_size); src_displs.resize(comm_size); dst_counts.resize(comm_size); dst_displs.resize(comm_size); } bool ucc_pt_generator_file::has_next() { return current_rep < nrep; } void ucc_pt_generator_file::next() { current_pattern++; if (current_pattern >= pattern_counts.size()) { current_pattern = 0; current_rep++; } if (has_next()) { setup_counts_displs(); } } void ucc_pt_generator_file::reset() { current_pattern = 0; current_rep = 0; setup_counts_displs(); } size_t ucc_pt_generator_file::get_src_count() { size_t total = 0; for (int i = 0; i < comm_size; i++) { total += src_counts[i]; } return total; } size_t ucc_pt_generator_file::get_dst_count() { size_t total = 0; for (int i = 0; i < comm_size; i++) { total += dst_counts[i]; } return total; } ucc_count_t *ucc_pt_generator_file::get_src_counts() { return (ucc_count_t *)src_counts.data(); } ucc_aint_t *ucc_pt_generator_file::get_src_displs() { return (ucc_aint_t *)src_displs.data(); } ucc_count_t *ucc_pt_generator_file::get_dst_counts() { return (ucc_count_t *)dst_counts.data(); } ucc_aint_t *ucc_pt_generator_file::get_dst_displs() { return (ucc_aint_t *)dst_displs.data(); } void ucc_pt_generator_file::setup_counts_displs() { const auto& counts = pattern_counts[current_pattern]; if (counts.size() < comm_size * comm_size) { throw std::runtime_error("Pattern 
size (" + std::to_string(counts.size()) + ") is less than comm_size*comm_size (" + std::to_string(comm_size * comm_size) + ")"); } for (int i = 0; i < comm_size; i++) { src_counts[i] = counts[rank_id * comm_size + i]; } size_t displ = 0; for (int i = 0; i < comm_size; i++) { src_displs[i] = displ; displ += src_counts[i]; } for (int i = 0; i < comm_size; i++) { dst_counts[i] = counts[i * comm_size + rank_id]; } displ = 0; for (int i = 0; i < comm_size; i++) { dst_displs[i] = displ; displ += dst_counts[i]; } } size_t ucc_pt_generator_file::get_src_count_max() { size_t max_src_count = 0; for (size_t i = 0; i < pattern_counts.size(); i++) { const auto& counts = pattern_counts[i]; size_t total_src = 0; for (int j = 0; j < comm_size; j++) { total_src += counts[rank_id * comm_size + j]; } if (total_src > max_src_count) { max_src_count = total_src; } } return max_src_count; } size_t ucc_pt_generator_file::get_dst_count_max() { size_t max_dst_count = 0; for (size_t i = 0; i < pattern_counts.size(); i++) { const auto& counts = pattern_counts[i]; size_t total_dst = 0; for (int j = 0; j < comm_size; j++) { total_dst += counts[j * comm_size + rank_id]; } if (total_dst > max_dst_count) { max_dst_count = total_dst; } } return max_dst_count; } size_t ucc_pt_generator_file::get_count_max() { const auto &matrix = pattern_counts[current_pattern]; size_t max_count = 0; size_t cur_row_col; for (int i = 0; i < comm_size; i++) { cur_row_col = 0; for (int j = 0; j < comm_size; j++) { cur_row_col += matrix[i * comm_size + j]; } if (cur_row_col > max_count) { max_count = cur_row_col; } } for (int i = 0; i < comm_size; i++) { cur_row_col = 0; for (int j = 0; j < comm_size; j++) { cur_row_col += matrix[j * comm_size + i]; } if (cur_row_col > max_count) { max_count = cur_row_col; } } return max_count; }openucx-ucc-ec0bc8a/tools/perf/generator/ucc_pt_generator.h0000664000175000017500000001172215133731560024463 0ustar alastairalastair/** * Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. * * See file LICENSE for terms. */ #ifndef UCC_PT_GENERATOR_H #define UCC_PT_GENERATOR_H #include "ucc_pt_config.h" #include #include #include #include #include class ucc_pt_generator_base { public: virtual bool has_next() = 0; virtual void next() = 0; virtual size_t get_src_count() = 0; // src buffer count virtual size_t get_dst_count() = 0; // dst buffer count virtual size_t get_count_max() = 0; // max (src_count, dst_count) across all ranks virtual ucc_count_t *get_src_counts() = 0; virtual ucc_aint_t *get_src_displs() = 0; virtual ucc_count_t *get_dst_counts() = 0; virtual ucc_aint_t *get_dst_displs() = 0; virtual size_t get_src_count_max() = 0; // max src buffer count across iterations virtual size_t get_dst_count_max() = 0; // max dst buffer count across iterations virtual void reset() = 0; virtual ~ucc_pt_generator_base() {} }; class ucc_pt_generator_exponential : public ucc_pt_generator_base { private: uint32_t comm_size; size_t min_count; size_t max_count; size_t mult_factor; size_t current_count; std::vector src_counts; std::vector src_displs; std::vector dst_counts; std::vector dst_displs; ucc_pt_op_type_t op_type; public: ucc_pt_generator_exponential(size_t min, size_t max, size_t factor, uint32_t gsize, ucc_pt_op_type_t type); bool has_next() override; void next() override; void reset() override; size_t get_src_count() override; size_t get_dst_count() override; ucc_count_t *get_src_counts() override; ucc_aint_t *get_src_displs() override; ucc_count_t *get_dst_counts() override; ucc_aint_t *get_dst_displs() override; size_t get_src_count_max() override; size_t get_dst_count_max() override; size_t get_count_max() override; }; class ucc_pt_generator_file : public ucc_pt_generator_base { private: uint32_t comm_size; uint32_t rank_id; std::string input_file; size_t nrep; size_t current_pattern; size_t current_rep; std::vector> pattern_counts; // Store counts for each pattern std::vector src_counts; std::vector src_displs; std::vector 
dst_counts; std::vector dst_displs; ucc_pt_op_type_t op_type; void* counts_state_ptr = nullptr; void setup_counts_displs(); public: ucc_pt_generator_file(const std::string &file_path, uint32_t gsize, uint32_t rank, ucc_pt_op_type_t type, size_t nrep); bool has_next() override; void next() override; void reset() override; size_t get_src_count() override; size_t get_dst_count() override; ucc_count_t *get_src_counts() override; ucc_aint_t *get_src_displs() override; ucc_count_t *get_dst_counts() override; ucc_aint_t *get_dst_displs() override; size_t get_src_count_max() override; size_t get_dst_count_max() override; size_t get_count_max() override; }; class ucc_pt_generator_traffic_matrix : public ucc_pt_generator_base { private: uint32_t comm_size; uint32_t rank_id; int kind; int token_size_KB_mean; int tgt_group_size_mean; int num_tokens; int tgt_group_size_std; int token_size_KB_std; int num_hl_ranks; double bias_factor; size_t nrep; size_t current_pattern; size_t current_rep; size_t dt_size; std::mt19937_64 rng_; std::vector> pattern_counts; // Store counts for each pattern. 
vector of vectors of counts (#vectors = #matrices) std::vector src_counts; std::vector src_displs; std::vector dst_counts; std::vector dst_displs; std::vector> traffic_matrix; void *counts_state_ptr = nullptr; void setup_counts_displs(); ucc_pt_op_type_t op_type; public: ucc_pt_generator_traffic_matrix( int kind, uint32_t gsize, uint32_t rank, ucc_datatype_t dtype, ucc_pt_op_type_t type, size_t nrep, int token_size_KB_mean, int num_tokens, int tgt_group_size_mean, uint64_t seed); bool has_next() override; void next() override; void reset() override; size_t get_src_count() override; size_t get_dst_count() override; ucc_count_t *get_src_counts() override; ucc_aint_t *get_src_displs() override; ucc_count_t *get_dst_counts() override; ucc_aint_t *get_dst_displs() override; size_t get_src_count_max() override; size_t get_dst_count_max() override; size_t get_count_max() override; }; #endif openucx-ucc-ec0bc8a/tools/perf/generator/ucc_pt_generator_traffic_matrix.cc0000664000175000017500000004013115133731560027677 0ustar alastairalastair#include "ucc_pt_generator.h" #include #include #include #include #include #include #include #include #include #include #include template void random_choice( const std::vector &data, size_t size, std::vector &result_vec, std::mt19937_64 &rng, const std::vector &weights = {}) { if (data.empty()) { throw std::runtime_error("Cannot pick from an empty vector."); } result_vec.clear(); result_vec.reserve(size); std::vector final_weights; size_t N = data.size(); if (!weights.empty()) { if (weights.size() != N) { throw std::runtime_error( "Weights vector size must match data vector size."); } final_weights = weights; } else { final_weights.assign(N, 1.0); } std::discrete_distribution distribution( final_weights.begin(), final_weights.end()); for (size_t i = 0; i < size; ++i) { int index = distribution(rng); result_vec.push_back(data.at(index)); } return; } std::vector> create_a2aV_traffic_matrix( int num_ranks, int token_size_KB_mean, int 
tgt_group_size_mean, int num_tokens, size_t dt_size, std::mt19937_64 &rng, bool add_bias = false, double bias_factor = 2, int num_hl_ranks = 2) { // Create a random a2aV traffic matrix where each rank sends token_size_KB_mean messages // to a random group of tgt_group_size_mean other ranks. // If add_bias is true, the bias_factor is used to increase the probability of sending messages to the higher level // ranks. If num_hl_ranks is greater than 0, the num_hl_ranks highest level ranks will be used to send messages to // the lower level ranks. The traffic matrix is returned as a matrix of size num_ranks x num_ranks. std::vector bias_indices(num_hl_ranks); std::vector possible_targets(num_ranks); std::vector> traffic_matrix( num_ranks, std::vector(num_ranks, 0)); // matrix of size num_ranks x num_ranks for (int i = 0; i < num_hl_ranks; i++) { bias_indices[i] = std::uniform_int_distribution( 0, num_ranks - 1)(rng); } for (int src_rank = 0; src_rank < num_ranks; src_rank++) { // Choose random target ranks, excluding self possible_targets.clear(); for (int i = 0; i < num_ranks; i++) { if (i != src_rank) { possible_targets.push_back(i); } } for (int token = 0; token < num_tokens; token++) { std::vector target_ranks(tgt_group_size_mean); if (add_bias) { // Create biased probabilities for target selection std::vector probabilities( possible_targets.size(), 1.0 / possible_targets.size()); for (int i = 0; i < num_hl_ranks; i++) { int bias_rank = bias_indices[i]; auto it = std::find( possible_targets.begin(), possible_targets.end(), bias_rank); if (it != possible_targets.end()) { int bias_index = std::distance( possible_targets.begin(), it); probabilities[bias_index] *= bias_factor; } } double sum = std::accumulate( probabilities.begin(), probabilities.end(), 0.0); for (int i = 0; i < probabilities.size(); i++) { probabilities[i] = probabilities[i] / sum; } random_choice( possible_targets, tgt_group_size_mean, target_ranks, rng, probabilities); } else { random_choice( 
possible_targets, tgt_group_size_mean, target_ranks, rng); } for (int i = 0; i < target_ranks.size(); i++) { int target_rank = target_ranks[i]; traffic_matrix [src_rank] [target_rank] += (token_size_KB_mean * (1000 / dt_size)); } } } return traffic_matrix; } std::vector> create_random_tgt_group_a2aV_traffic_matrix( int num_ranks, int token_size_KB_mean, int tgt_group_size_mean, int tgt_group_size_std, int num_tokens, size_t dt_size, std::mt19937_64 &rng) { // Create a random a2aV traffic matrix where each rank sends token_size_KB_mean messages // to a random group of ranks with random size. // The traffic matrix is returned as a matrix of size num_ranks x num_ranks. std::vector> traffic_matrix( num_ranks, std::vector(num_ranks, 0)); for (int src_rank = 0; src_rank < num_ranks; src_rank++) { std::vector possible_targets; possible_targets.reserve(num_ranks - 1); for (int i = 0; i < num_ranks; i++) { if (i != src_rank) { possible_targets.push_back(i); } } for (int token = 0; token < num_tokens; token++) { std::normal_distribution distribution( tgt_group_size_mean, tgt_group_size_std); double normal_sample = distribution(rng); int tgt_group_size; tgt_group_size = std::max(1, static_cast(normal_sample)); tgt_group_size = std::min( tgt_group_size, num_ranks - 1); // Cap at available targets std::vector target_ranks(tgt_group_size); random_choice(possible_targets, tgt_group_size, target_ranks, rng); for (int i = 0; i < target_ranks.size(); i++) { traffic_matrix [src_rank] [target_ranks [i]] += (token_size_KB_mean * (1000 / dt_size)); } } } return traffic_matrix; } std::vector> create_random_tgt_group_random_msg_size_a2aV_traffic_matrix( int num_ranks, int token_size_KB_mean, int token_size_KB_std, int tgt_group_size_mean, int tgt_group_size_std, int num_tokens, size_t dt_size, std::mt19937_64 &rng) { // Create a random a2aV traffic matrix where each rank sends a random message size // to a random group of tgt_group_size_mean other ranks. 
// The traffic matrix is returned as a matrix of size num_ranks x num_ranks. std::vector> traffic_matrix( num_ranks, std::vector(num_ranks, 0)); std::normal_distribution distribution_tgt_group_size( tgt_group_size_mean, tgt_group_size_std); std::normal_distribution distribution_token_size( token_size_KB_mean, token_size_KB_std); for (int src_rank = 0; src_rank < num_ranks; src_rank++) { std::vector possible_targets(num_ranks - 1); possible_targets.clear(); for (int i = 0; i < num_ranks; i++) { if (i != src_rank) { possible_targets.push_back(i); } } for (int token = 0; token < num_tokens; token++) { int normal_sample_tgt_group_size = static_cast( distribution_tgt_group_size(rng)); int tgt_group_size; tgt_group_size = std::max(1, normal_sample_tgt_group_size); tgt_group_size = std::min( tgt_group_size, num_ranks - 1); // Cap at available targets std::vector target_ranks(tgt_group_size); random_choice(possible_targets, tgt_group_size, target_ranks, rng); for (int i = 0; i < target_ranks.size(); i++) { int normal_sample_token_size = static_cast( distribution_token_size(rng)); traffic_matrix [src_rank] [target_ranks[i]] += std::max(0, normal_sample_token_size) * (1000 / dt_size); } } } return traffic_matrix; } void print_result( std::vector> traffic_matrix, bool print_full_result) { // Print the traffic matrix // The first dimension is the source rank. // The second dimension is the target rank. 
for (int src_rank = 0; src_rank < traffic_matrix.size(); src_rank++) { for (int tgt_rank = 0; tgt_rank < traffic_matrix[0].size(); tgt_rank++) { std::cout << traffic_matrix[src_rank][tgt_rank] << " "; } std::cout << std::endl; } } ucc_pt_generator_traffic_matrix::ucc_pt_generator_traffic_matrix( int kind, uint32_t gsize, uint32_t rank, ucc_datatype_t dtype, ucc_pt_op_type_t type, size_t nrepeats, int token_size_KB_mean_, int num_tokens_, int tgt_group_size_mean_, uint64_t seed) { comm_size = gsize; rank_id = rank; op_type = type; current_pattern = 0; current_rep = 0; nrep = nrepeats; rng_ = std::mt19937_64(seed); token_size_KB_mean = token_size_KB_mean_; tgt_group_size_mean = tgt_group_size_mean_; num_tokens = num_tokens_; bias_factor = 2; num_hl_ranks = 2; tgt_group_size_std = 1; token_size_KB_std = 1; dt_size = ucc_dt_size(dtype); if (kind == 0) { traffic_matrix = create_a2aV_traffic_matrix( comm_size, token_size_KB_mean, tgt_group_size_mean, num_tokens, dt_size, rng_, false, bias_factor, num_hl_ranks); } else if (kind == 1) { traffic_matrix = create_a2aV_traffic_matrix( comm_size, token_size_KB_mean, tgt_group_size_mean, num_tokens, dt_size, rng_, true, bias_factor, num_hl_ranks); } else if (kind == 2) { traffic_matrix = create_random_tgt_group_a2aV_traffic_matrix( comm_size, token_size_KB_mean, tgt_group_size_mean, tgt_group_size_std, num_tokens, dt_size, rng_); } else if (kind == 3) { traffic_matrix = create_random_tgt_group_random_msg_size_a2aV_traffic_matrix( comm_size, token_size_KB_mean, token_size_KB_std, tgt_group_size_mean, tgt_group_size_std, num_tokens, dt_size, rng_); } // print_result(traffic_matrix, false); pattern_counts.reserve(traffic_matrix.size()); std::vector pattern; if (!traffic_matrix.empty()) { pattern.reserve(comm_size * comm_size); } if (traffic_matrix[0].size() != comm_size || traffic_matrix.size() != comm_size) { throw std::runtime_error( "Matrix size (" + std::to_string(traffic_matrix[0].size()) + "x" + 
std::to_string(traffic_matrix.size()) + ") is not equal to comm_size*comm_size (" + std::to_string(comm_size * comm_size) + "). " "Please check the traffic_matrix."); } for (const auto &row : traffic_matrix) { pattern.insert(pattern.end(), row.begin(), row.end()); } pattern_counts.push_back(pattern); if (pattern_counts.empty()) { throw std::runtime_error( "No collective patterns provided in traffic_matrix."); } // Initialize arrays for counts and displacements src_counts.resize(comm_size); src_displs.resize(comm_size); dst_counts.resize(comm_size); dst_displs.resize(comm_size); } bool ucc_pt_generator_traffic_matrix::has_next() { return current_rep < nrep; } void ucc_pt_generator_traffic_matrix::next() { current_pattern++; if (current_pattern >= pattern_counts.size()) { current_pattern = 0; current_rep++; } if (has_next()) { setup_counts_displs(); } } void ucc_pt_generator_traffic_matrix::reset() { current_pattern = 0; current_rep = 0; setup_counts_displs(); } size_t ucc_pt_generator_traffic_matrix::get_src_count() { size_t total = 0; for (int i = 0; i < comm_size; i++) { total += src_counts[i]; } return total; } size_t ucc_pt_generator_traffic_matrix::get_dst_count() { size_t total = 0; for (int i = 0; i < comm_size; i++) { total += dst_counts[i]; } return total; } ucc_count_t *ucc_pt_generator_traffic_matrix::get_src_counts() { return (ucc_count_t *)src_counts.data(); } ucc_aint_t *ucc_pt_generator_traffic_matrix::get_src_displs() { return (ucc_aint_t *)src_displs.data(); } ucc_count_t *ucc_pt_generator_traffic_matrix::get_dst_counts() { return (ucc_count_t *)dst_counts.data(); } ucc_aint_t *ucc_pt_generator_traffic_matrix::get_dst_displs() { return (ucc_aint_t *)dst_displs.data(); } void ucc_pt_generator_traffic_matrix::setup_counts_displs() { const auto &counts = pattern_counts[current_pattern]; if (counts.size() < comm_size * comm_size) { throw std::runtime_error( "Pattern size (" + std::to_string(counts.size()) + ") is less than comm_size*comm_size (" + 
std::to_string(comm_size * comm_size) + ")"); } for (int i = 0; i < comm_size; i++) { src_counts[i] = counts[rank_id * comm_size + i]; } size_t displ = 0; std::unordered_map size_to_displ; for (int i = 0; i < comm_size; i++) { const uint32_t msg_size = src_counts[i]; if (msg_size == 0) { src_displs[i] = 0; continue; } auto current_displ = size_to_displ.find(msg_size); if (current_displ != size_to_displ.end()) { // Reuse the same buffer region for equal-sized messages src_displs[i] = current_displ->second; } else { src_displs[i] = displ; size_to_displ[msg_size] = displ; displ += msg_size; } } for (int i = 0; i < comm_size; i++) { dst_counts[i] = counts[i * comm_size + rank_id]; } displ = 0; for (int i = 0; i < comm_size; i++) { dst_displs[i] = displ; displ += dst_counts[i]; } } size_t ucc_pt_generator_traffic_matrix::get_src_count_max() { size_t max_src_count = 0; for (size_t i = 0; i < pattern_counts.size(); i++) { const auto &counts = pattern_counts[i]; size_t total_src = 0; for (int j = 0; j < comm_size; j++) { total_src += counts[rank_id * comm_size + j]; } if (total_src > max_src_count) { max_src_count = total_src; } } return max_src_count; } size_t ucc_pt_generator_traffic_matrix::get_dst_count_max() { size_t max_dst_count = 0; for (size_t i = 0; i < pattern_counts.size(); i++) { const auto &counts = pattern_counts[i]; size_t total_dst = 0; for (int j = 0; j < comm_size; j++) { total_dst += counts[j * comm_size + rank_id]; } if (total_dst > max_dst_count) { max_dst_count = total_dst; } } return max_dst_count; } size_t ucc_pt_generator_traffic_matrix::get_count_max() { const auto &matrix = pattern_counts[current_pattern]; size_t max_count = 0; size_t cur_row_col; for (int i = 0; i < comm_size; i++) { cur_row_col = 0; for (int j = 0; j < comm_size; j++) { cur_row_col += matrix[i * comm_size + j]; } if (cur_row_col > max_count) { max_count = cur_row_col; } } for (int i = 0; i < comm_size; i++) { cur_row_col = 0; for (int j = 0; j < comm_size; j++) { cur_row_col 
+= matrix[j * comm_size + i]; } if (cur_row_col > max_count) { max_count = cur_row_col; } } return max_count; } openucx-ucc-ec0bc8a/tools/perf/generator/ucc_pt_generator_exp.cc0000664000175000017500000002044215133731560025474 0ustar alastairalastair#include "ucc_pt_generator.h" ucc_pt_generator_exponential::ucc_pt_generator_exponential(size_t min, size_t max, size_t factor, uint32_t gsize, ucc_pt_op_type_t type) { min_count = min; max_count = max; mult_factor = factor; current_count = min; comm_size = gsize; op_type = type; } bool ucc_pt_generator_exponential::has_next() { return current_count <= max_count; } void ucc_pt_generator_exponential::next() { if (!has_next()) { return; } current_count *= mult_factor; } void ucc_pt_generator_exponential::reset() { current_count = min_count; } size_t ucc_pt_generator_exponential::get_src_count() { switch (op_type) { case UCC_PT_OP_TYPE_ALLGATHER: case UCC_PT_OP_TYPE_ALLGATHERV: case UCC_PT_OP_TYPE_ALLREDUCE: case UCC_PT_OP_TYPE_BCAST: case UCC_PT_OP_TYPE_GATHER: case UCC_PT_OP_TYPE_GATHERV: case UCC_PT_OP_TYPE_REDUCE: case UCC_PT_OP_TYPE_MEMCPY: case UCC_PT_OP_TYPE_REDUCEDT: case UCC_PT_OP_TYPE_REDUCEDT_STRIDED: return current_count; case UCC_PT_OP_TYPE_ALLTOALL: case UCC_PT_OP_TYPE_ALLTOALLV: case UCC_PT_OP_TYPE_REDUCE_SCATTER: case UCC_PT_OP_TYPE_REDUCE_SCATTERV: case UCC_PT_OP_TYPE_SCATTER: case UCC_PT_OP_TYPE_SCATTERV: return current_count * comm_size; case UCC_PT_OP_TYPE_BARRIER: case UCC_PT_OP_TYPE_FANIN: case UCC_PT_OP_TYPE_FANOUT: return 0; default: throw std::runtime_error("Operation type not supported"); } } size_t ucc_pt_generator_exponential::get_dst_count() { switch (op_type) { case UCC_PT_OP_TYPE_ALLGATHER: case UCC_PT_OP_TYPE_ALLGATHERV: case UCC_PT_OP_TYPE_GATHER: case UCC_PT_OP_TYPE_GATHERV: case UCC_PT_OP_TYPE_ALLTOALL: case UCC_PT_OP_TYPE_ALLTOALLV: return current_count * comm_size; case UCC_PT_OP_TYPE_ALLREDUCE: case UCC_PT_OP_TYPE_REDUCE: case UCC_PT_OP_TYPE_MEMCPY: case UCC_PT_OP_TYPE_REDUCEDT: case 
UCC_PT_OP_TYPE_REDUCEDT_STRIDED: case UCC_PT_OP_TYPE_REDUCE_SCATTER: case UCC_PT_OP_TYPE_REDUCE_SCATTERV: case UCC_PT_OP_TYPE_SCATTER: case UCC_PT_OP_TYPE_SCATTERV: return current_count; case UCC_PT_OP_TYPE_BARRIER: case UCC_PT_OP_TYPE_BCAST: case UCC_PT_OP_TYPE_FANIN: case UCC_PT_OP_TYPE_FANOUT: return 0; default: throw std::runtime_error("Operation type not supported"); } } size_t *ucc_pt_generator_exponential::get_src_counts() { switch (op_type) { case UCC_PT_OP_TYPE_ALLTOALLV: case UCC_PT_OP_TYPE_SCATTERV: src_counts = std::vector(comm_size, current_count); return (ucc_count_t *)src_counts.data(); case UCC_PT_OP_TYPE_ALLGATHER: case UCC_PT_OP_TYPE_ALLGATHERV: case UCC_PT_OP_TYPE_ALLREDUCE: case UCC_PT_OP_TYPE_ALLTOALL: case UCC_PT_OP_TYPE_BARRIER: case UCC_PT_OP_TYPE_BCAST: case UCC_PT_OP_TYPE_FANIN: case UCC_PT_OP_TYPE_FANOUT: case UCC_PT_OP_TYPE_GATHER: case UCC_PT_OP_TYPE_GATHERV: case UCC_PT_OP_TYPE_REDUCE: case UCC_PT_OP_TYPE_REDUCE_SCATTER: case UCC_PT_OP_TYPE_REDUCE_SCATTERV: case UCC_PT_OP_TYPE_SCATTER: case UCC_PT_OP_TYPE_MEMCPY: case UCC_PT_OP_TYPE_REDUCEDT: case UCC_PT_OP_TYPE_REDUCEDT_STRIDED: default: throw std::runtime_error("Operation type not supported"); } } size_t *ucc_pt_generator_exponential::get_src_displs() { switch (op_type) { case UCC_PT_OP_TYPE_ALLTOALLV: case UCC_PT_OP_TYPE_SCATTERV: src_displs = std::vector(comm_size, 0); for (size_t i = 0; i < comm_size; i++) { src_displs[i] = current_count * i; } return (ucc_aint_t *)src_displs.data(); case UCC_PT_OP_TYPE_ALLGATHER: case UCC_PT_OP_TYPE_ALLGATHERV: case UCC_PT_OP_TYPE_ALLREDUCE: case UCC_PT_OP_TYPE_ALLTOALL: case UCC_PT_OP_TYPE_BARRIER: case UCC_PT_OP_TYPE_BCAST: case UCC_PT_OP_TYPE_FANIN: case UCC_PT_OP_TYPE_FANOUT: case UCC_PT_OP_TYPE_GATHER: case UCC_PT_OP_TYPE_GATHERV: case UCC_PT_OP_TYPE_REDUCE: case UCC_PT_OP_TYPE_REDUCE_SCATTER: case UCC_PT_OP_TYPE_REDUCE_SCATTERV: case UCC_PT_OP_TYPE_SCATTER: case UCC_PT_OP_TYPE_MEMCPY: case UCC_PT_OP_TYPE_REDUCEDT: case 
UCC_PT_OP_TYPE_REDUCEDT_STRIDED: default: throw std::runtime_error("Operation type not supported"); } } size_t *ucc_pt_generator_exponential::get_dst_counts() { switch (op_type) { case UCC_PT_OP_TYPE_ALLGATHERV: case UCC_PT_OP_TYPE_ALLTOALLV: case UCC_PT_OP_TYPE_GATHERV: case UCC_PT_OP_TYPE_REDUCE_SCATTERV: dst_counts = std::vector(comm_size, current_count); return (ucc_count_t *)dst_counts.data(); case UCC_PT_OP_TYPE_ALLGATHER: case UCC_PT_OP_TYPE_ALLTOALL: case UCC_PT_OP_TYPE_ALLREDUCE: case UCC_PT_OP_TYPE_BARRIER: case UCC_PT_OP_TYPE_BCAST: case UCC_PT_OP_TYPE_FANIN: case UCC_PT_OP_TYPE_FANOUT: case UCC_PT_OP_TYPE_GATHER: case UCC_PT_OP_TYPE_REDUCE: case UCC_PT_OP_TYPE_REDUCE_SCATTER: case UCC_PT_OP_TYPE_SCATTER: case UCC_PT_OP_TYPE_MEMCPY: case UCC_PT_OP_TYPE_REDUCEDT: case UCC_PT_OP_TYPE_REDUCEDT_STRIDED: case UCC_PT_OP_TYPE_SCATTERV: default: throw std::runtime_error("Operation type not supported"); } } size_t *ucc_pt_generator_exponential::get_dst_displs() { switch (op_type) { case UCC_PT_OP_TYPE_ALLGATHERV: case UCC_PT_OP_TYPE_ALLTOALLV: case UCC_PT_OP_TYPE_GATHERV: case UCC_PT_OP_TYPE_REDUCE_SCATTERV: dst_displs = std::vector(comm_size, 0); for (size_t i = 0; i < comm_size; i++) { dst_displs[i] = current_count * i; } return (ucc_aint_t *)dst_displs.data(); case UCC_PT_OP_TYPE_ALLGATHER: case UCC_PT_OP_TYPE_ALLTOALL: case UCC_PT_OP_TYPE_ALLREDUCE: case UCC_PT_OP_TYPE_BARRIER: case UCC_PT_OP_TYPE_BCAST: case UCC_PT_OP_TYPE_FANIN: case UCC_PT_OP_TYPE_FANOUT: case UCC_PT_OP_TYPE_GATHER: case UCC_PT_OP_TYPE_REDUCE: case UCC_PT_OP_TYPE_REDUCE_SCATTER: case UCC_PT_OP_TYPE_SCATTER: case UCC_PT_OP_TYPE_MEMCPY: case UCC_PT_OP_TYPE_REDUCEDT: case UCC_PT_OP_TYPE_REDUCEDT_STRIDED: case UCC_PT_OP_TYPE_SCATTERV: default: throw std::runtime_error("Operation type not supported"); } } size_t ucc_pt_generator_exponential::get_src_count_max() { switch (op_type) { case UCC_PT_OP_TYPE_ALLGATHER: case UCC_PT_OP_TYPE_ALLGATHERV: case UCC_PT_OP_TYPE_ALLREDUCE: case 
UCC_PT_OP_TYPE_BCAST: case UCC_PT_OP_TYPE_GATHER: case UCC_PT_OP_TYPE_GATHERV: case UCC_PT_OP_TYPE_REDUCE: case UCC_PT_OP_TYPE_MEMCPY: case UCC_PT_OP_TYPE_REDUCEDT: case UCC_PT_OP_TYPE_REDUCEDT_STRIDED: return max_count; case UCC_PT_OP_TYPE_ALLTOALL: case UCC_PT_OP_TYPE_ALLTOALLV: case UCC_PT_OP_TYPE_REDUCE_SCATTER: case UCC_PT_OP_TYPE_REDUCE_SCATTERV: case UCC_PT_OP_TYPE_SCATTER: case UCC_PT_OP_TYPE_SCATTERV: return max_count * comm_size; case UCC_PT_OP_TYPE_BARRIER: case UCC_PT_OP_TYPE_FANIN: case UCC_PT_OP_TYPE_FANOUT: return 0; default: throw std::runtime_error("Operation type not supported"); } } size_t ucc_pt_generator_exponential::get_dst_count_max() { switch (op_type) { case UCC_PT_OP_TYPE_ALLGATHER: case UCC_PT_OP_TYPE_ALLGATHERV: case UCC_PT_OP_TYPE_GATHER: case UCC_PT_OP_TYPE_GATHERV: case UCC_PT_OP_TYPE_ALLTOALL: case UCC_PT_OP_TYPE_ALLTOALLV: return max_count * comm_size; case UCC_PT_OP_TYPE_ALLREDUCE: case UCC_PT_OP_TYPE_REDUCE: case UCC_PT_OP_TYPE_MEMCPY: case UCC_PT_OP_TYPE_REDUCEDT: case UCC_PT_OP_TYPE_REDUCEDT_STRIDED: case UCC_PT_OP_TYPE_REDUCE_SCATTER: case UCC_PT_OP_TYPE_REDUCE_SCATTERV: case UCC_PT_OP_TYPE_SCATTER: case UCC_PT_OP_TYPE_SCATTERV: return max_count; case UCC_PT_OP_TYPE_BARRIER: case UCC_PT_OP_TYPE_BCAST: case UCC_PT_OP_TYPE_FANIN: case UCC_PT_OP_TYPE_FANOUT: return 0; default: throw std::runtime_error("Operation type not supported"); } } size_t ucc_pt_generator_exponential::get_count_max() { return std::max(get_src_count(), get_dst_count()); }openucx-ucc-ec0bc8a/tools/perf/generator/examples/0000775000175000017500000000000015133731560022602 5ustar alastairalastairopenucx-ucc-ec0bc8a/tools/perf/generator/examples/alltoallv_4.ptini0000664000175000017500000000063615133731560026071 0ustar alastairalastair[format] version = 1.0 description = UCC Collective Pattern File [collective] type = ALLTOALLV team_size = 4 datatype = INT32 flags = COUNT_32BIT counts = { 4902912,2774016,1523,34623 6607872,1327104,523,12356 52134,63422,1777,22356 
1234,1234,1234,1234 } [collective] type = ALLTOALLV team_size = 4 datatype = INT32 flags = COUNT_32BIT counts = { 1,1,1,1 2,2,2,2 3,3,3,3 4,4,4,4 }openucx-ucc-ec0bc8a/tools/perf/generator/examples/alltoallv_16.ptini0000664000175000017500000000557615133731560026164 0ustar alastairalastair[format] version = 1.0 description = UCC Collective Pattern File [collective] type = ALLTOALLV team_size = 16 datatype = INT32 flags = COUNT_32BIT counts = { 4902912,2774016,3142656,6534144,5953536,4147200,1677312,2156544,4617216,7898112,1271808,7907328,5935104,1852416,3538944,3585024 6607872,1327104,2055168,7861248,2967552,6976512,2138112,6294528,7179264,3161088,3723264,3548160,3843072,6478848,506880,3022848 4331520,1907712,4866048,5336064,8285184,4792320,1419264,2608128,5428224,5382144,3244032,4856832,6681600,3972096,1548288,3474432 4276224,2230272,3981312,5815296,7151616,4414464,1363968,2110464,6976512,4939776,4055040,5280768,4091904,2801664,2976768,4368384 3363840,2267136,2386944,3972096,8193024,4442112,2571264,4967424,5354496,6294528,4478976,3741696,5907456,2589696,4856832,3234816 7446528,857088,4967424,3557376,1843200,4202496,1981440,9234432,7059456,2617344,4856832,2617344,2396160,6018048,6976512,1428480 4469760,2027520,4119552,5594112,8220672,4405248,1262592,1944576,5308416,5925888,3760128,6156288,6359040,4091904,1290240,3631104 5289984,1115136,2211840,5566464,4267008,6008832,2221056,8386560,6893568,3585024,3649536,5584896,3944448,5548032,1723392,4654080 6644736,1502208,4202496,5833728,2386944,5446656,2949120,6792192,5879808,3815424,3234816,4515840,3612672,6349824,2294784,3133440 4294656,3087360,2903040,4405248,6359040,8211456,2608128,3059712,3225600,8773632,2036736,3852288,2884608,5677056,949248,8137728 4405248,3050496,3456000,5769216,3631104,4801536,2681856,5492736,3621888,6451200,3806208,4764672,10506240,2727936,2110464,3391488 11621376,1059840,8672256,4147200,5962752,4930560,2119680,1787904,5114880,1419264,4663296,1622016,3262464,6137856,654336,7944192 
2423808,3446784,2700288,5041152,5944320,7160832,3511296,4681728,4856832,7096320,2331648,3686400,2423808,5142528,1253376,6174720 5760000,1539072,11593728,3843072,4644864,3483648,1041408,2976768,3972096,3852288,2709504,7391232,4064256,5234688,3115008,5677056 4635648,2617344,2774016,6110208,2350080,6331392,3059712,2709504,4773888,7317504,5354496,2515968,9741312,3465216,903168,3437568 4460544,2304000,3299328,5548032,4921344,5179392,3022848,3658752,4672512,6405120,4303872,5216256,7870464,2460672,2239488,3640320 } [collective] type = ALLTOALLV team_size = 16 datatype = INT32 flags = COUNT_32BIT counts = { 1,1,2,2,3,3,4,4,0,0,0,0,0,0,0,0 2,2,3,3,4,4,5,5,0,0,0,0,0,0,0,0 3,3,4,4,5,5,6,6,0,0,0,0,0,0,0,0 4,4,5,5,6,6,7,7,0,0,0,0,0,0,0,0 5,5,6,6,7,7,8,8,0,0,0,0,0,0,0,0 6,6,7,7,8,8,9,9,0,0,0,0,0,0,0,0 7,7,8,8,9,9,1,1,0,0,0,0,0,0,0,0 8,8,9,9,1,1,1,1,0,0,0,0,0,0,0,0 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 }openucx-ucc-ec0bc8a/tools/perf/ucc_pt_benchmark.cc0000664000175000017500000004100315133731560022572 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include #include "ucc_pt_benchmark.h" #include "components/mc/ucc_mc.h" #include "ucc_perftest.h" #include "utils/ucc_coll_utils.h" #include "core/ucc_ee.h" #include "ucc_pt_coll.h" #include "generator/ucc_pt_generator.h" ucc_pt_benchmark::ucc_pt_benchmark(ucc_pt_benchmark_config cfg, ucc_pt_comm *communicator): config(cfg), comm(communicator) { /* RNG seed is passed to generators for isolated reproducibility */ if (cfg.gen.type == UCC_PT_GEN_TYPE_EXP) { generator = new ucc_pt_generator_exponential(cfg.min_count, cfg.max_count, 2, communicator->get_size(), cfg.op_type); } else if (cfg.gen.type == UCC_PT_GEN_TYPE_FILE) { generator = new ucc_pt_generator_file(cfg.gen.file_name, communicator->get_size(), communicator->get_rank(), cfg.op_type, cfg.gen.nrep); if (cfg.op_type != UCC_PT_OP_TYPE_ALLTOALLV) { throw std::runtime_error("Only ALLTOALLV is supported for file generator"); } } else if (cfg.gen.type == UCC_PT_GEN_TYPE_TRAFFIC_MATRIX) { if (cfg.op_type != UCC_PT_OP_TYPE_ALLTOALLV) { throw std::runtime_error( "Only ALLTOALLV is supported for traffic matrix generator"); } generator = new ucc_pt_generator_traffic_matrix( cfg.gen.matrix.kind, communicator->get_size(), communicator->get_rank(), cfg.dt, cfg.op_type, cfg.gen.nrep, cfg.gen.matrix.token_size_KB_mean, cfg.gen.matrix.num_tokens, cfg.gen.matrix.tgt_group_size_mean, cfg.seed); } else { /* assuming that the generator type is UCC_PT_GEN_TYPE_EXP */ generator = new ucc_pt_generator_exponential(cfg.min_count, cfg.max_count, 2, communicator->get_size(), cfg.op_type); } switch (cfg.op_type) { case UCC_PT_OP_TYPE_ALLGATHER: coll = new ucc_pt_coll_allgather(cfg.dt, cfg.mt, cfg.inplace, cfg.persistent, cfg.map_type, comm, generator); break; case UCC_PT_OP_TYPE_ALLGATHERV: coll = new ucc_pt_coll_allgatherv(cfg.dt, cfg.mt, cfg.inplace, cfg.persistent, comm, generator); break; case UCC_PT_OP_TYPE_ALLREDUCE: coll = new ucc_pt_coll_allreduce(cfg.dt, cfg.mt, cfg.op, cfg.inplace, cfg.persistent, comm, generator); break; 
case UCC_PT_OP_TYPE_ALLTOALL: coll = new ucc_pt_coll_alltoall(cfg.dt, cfg.mt, cfg.inplace, cfg.persistent, cfg.map_type, comm, generator); break; case UCC_PT_OP_TYPE_ALLTOALLV: coll = new ucc_pt_coll_alltoallv(cfg.dt, cfg.mt, cfg.inplace, cfg.persistent, comm, generator); break; case UCC_PT_OP_TYPE_BARRIER: coll = new ucc_pt_coll_barrier(comm, generator); break; case UCC_PT_OP_TYPE_BCAST: coll = new ucc_pt_coll_bcast(cfg.dt, cfg.mt, cfg.root_shift, cfg.persistent, comm, generator); break; case UCC_PT_OP_TYPE_GATHER: coll = new ucc_pt_coll_gather(cfg.dt, cfg.mt, cfg.inplace, cfg.persistent, cfg.root_shift, comm, generator); break; case UCC_PT_OP_TYPE_GATHERV: coll = new ucc_pt_coll_gatherv(cfg.dt, cfg.mt, cfg.inplace, cfg.persistent, cfg.root_shift, comm, generator); break; case UCC_PT_OP_TYPE_REDUCE: coll = new ucc_pt_coll_reduce(cfg.dt, cfg.mt, cfg.op, cfg.inplace, cfg.persistent, cfg.root_shift, comm, generator); break; case UCC_PT_OP_TYPE_REDUCE_SCATTER: coll = new ucc_pt_coll_reduce_scatter(cfg.dt, cfg.mt, cfg.op, cfg.inplace, cfg.persistent, comm, generator); break; case UCC_PT_OP_TYPE_REDUCE_SCATTERV: coll = new ucc_pt_coll_reduce_scatterv(cfg.dt, cfg.mt, cfg.op, cfg.inplace, cfg.persistent, comm, generator); break; case UCC_PT_OP_TYPE_SCATTER: coll = new ucc_pt_coll_scatter(cfg.dt, cfg.mt, cfg.inplace, cfg.persistent, cfg.root_shift, comm, generator); break; case UCC_PT_OP_TYPE_SCATTERV: coll = new ucc_pt_coll_scatterv(cfg.dt, cfg.mt, cfg.inplace, cfg.persistent, cfg.root_shift, comm, generator); break; case UCC_PT_OP_TYPE_MEMCPY: coll = new ucc_pt_op_memcpy(cfg.dt, cfg.mt, cfg.n_bufs, comm, generator); break; case UCC_PT_OP_TYPE_REDUCEDT: coll = new ucc_pt_op_reduce(cfg.dt, cfg.mt, cfg.op, cfg.n_bufs, comm, generator); break; case UCC_PT_OP_TYPE_REDUCEDT_STRIDED: coll = new ucc_pt_op_reduce_strided(cfg.dt, cfg.mt, cfg.op, cfg.n_bufs, comm, generator); break; default: throw std::runtime_error("not supported collective"); } } ucc_status_t 
ucc_pt_benchmark::run_bench() noexcept { ucc_status_t st; ucc_pt_test_args_t args; double time; double time_min, time_max, time_avg; double total_time = 0; generator->reset(); print_header(); for (generator->reset(); generator->has_next(); generator->next()) { int iter = config.n_iter_small; int warmup = config.n_warmup_small; if (generator->get_count_max() >= config.large_thresh) { iter = config.n_iter_large; warmup = config.n_warmup_large; } args.coll_args.root = config.root; UCCCHECK_GOTO(coll->init_args(args), exit_err, st); if ((uint64_t)config.op_type < (uint64_t)UCC_COLL_TYPE_LAST) { UCCCHECK_GOTO(run_single_coll_test(args.coll_args, warmup, iter, time), free_coll, st); } else { UCCCHECK_GOTO(run_single_executor_test(args.executor_args, warmup, iter, time), free_coll, st); } comm->allreduce(&time, &time_min, 1, UCC_OP_MIN, UCC_DT_FLOAT64); comm->allreduce(&time, &time_max, 1, UCC_OP_MAX, UCC_DT_FLOAT64); comm->allreduce(&time, &time_avg, 1, UCC_OP_SUM, UCC_DT_FLOAT64); time_avg /= comm->get_size(); total_time += time_max; print_time(generator->get_src_count(), args, time_avg, time_min, time_max); coll->free_args(args); if (!coll->has_range()) { /* exit here since collective doesn't have count argument */ break; } } if (comm->get_rank() == 0) { std::cout << "Total time: " << total_time / 1000 << " ms" << std::endl; } return UCC_OK; free_coll: coll->free_args(args); exit_err: return st; } static inline double get_time_us(void) { struct timeval t; gettimeofday(&t, NULL); return t.tv_sec * 1e6 + t.tv_usec; } ucc_status_t ucc_pt_benchmark::run_single_coll_test(ucc_coll_args_t args, int nwarmup, int niter, double &time) noexcept { const bool triggered = config.triggered; const bool persistent = config.persistent; ucc_team_h team = comm->get_team(); ucc_context_h ctx = comm->get_context(); ucc_status_t st = UCC_OK; ucc_coll_req_h req; ucc_ee_h ee; ucc_ev_t comp_ev, *post_ev; UCCCHECK_GOTO(comm->barrier(), exit_err, st); time = 0; if (triggered) { try { ee = 
comm->get_ee(); } catch(std::exception &e) { std::cerr << e.what() << std::endl; return UCC_ERR_NO_MESSAGE; } /* dummy event, for benchmark purposes no real event required */ comp_ev.ev_type = UCC_EVENT_COMPUTE_COMPLETE; comp_ev.ev_context = nullptr; comp_ev.ev_context_size = 0; } if (persistent) { UCCCHECK_GOTO(ucc_collective_init(&args, &req, team), exit_err, st); } args.root = config.root % comm->get_size(); for (int i = 0; i < nwarmup + niter; i++) { double s = get_time_us(); if (!persistent) { UCCCHECK_GOTO(ucc_collective_init(&args, &req, team), exit_err, st); } if (triggered) { comp_ev.req = req; UCCCHECK_GOTO(ucc_collective_triggered_post(ee, &comp_ev), free_req, st); UCCCHECK_GOTO(ucc_ee_get_event(ee, &post_ev), free_req, st); ucc_assert(post_ev->ev_type == UCC_EVENT_COLLECTIVE_POST); UCCCHECK_GOTO(ucc_ee_ack_event(ee, post_ev), free_req, st); } else { UCCCHECK_GOTO(ucc_collective_post(req), free_req, st); } st = ucc_collective_test(req); while (st > 0) { UCCCHECK_GOTO(ucc_context_progress(ctx), free_req, st); st = ucc_collective_test(req); } if (!persistent) { ucc_collective_finalize(req); } double f = get_time_us(); if (st != UCC_OK) { goto exit_err; } if (i >= nwarmup) { time += f - s; } args.root = (args.root + config.root_shift) % comm->get_size(); UCCCHECK_GOTO(comm->barrier(), exit_err, st); } if (persistent) { ucc_collective_finalize(req); } if (niter != 0) { time /= niter; } return UCC_OK; free_req: ucc_collective_finalize(req); exit_err: return st; } ucc_status_t ucc_pt_benchmark::run_single_executor_test(ucc_ee_executor_task_args_t args, int nwarmup, int niter, double &time) noexcept { const bool triggered = config.triggered; ucc_ee_executor_t *executor = comm->get_executor(); ucc_status_t st = UCC_OK; ucc_ee_h ee; ucc_ee_executor_task_t *task; time = 0; if (triggered) { try { ee = comm->get_ee(); } catch(std::exception &e) { std::cerr << e.what() << std::endl; return UCC_ERR_NO_MESSAGE; } UCCCHECK_GOTO(ucc_ee_executor_start(executor, 
ee->ee_context), exit_err, st); } else { UCCCHECK_GOTO(ucc_ee_executor_start(executor, nullptr), exit_err, st); } for (int i = 0; i < nwarmup + niter; i++) { double s = get_time_us(); UCCCHECK_GOTO(ucc_ee_executor_task_post(executor, &args, &task), stop_exec, st); st = ucc_ee_executor_task_test(task); while (st > 0) { st = ucc_ee_executor_task_test(task); } ucc_ee_executor_task_finalize(task); double f = get_time_us(); if (st != UCC_OK) { goto exit_err; } if (i >= nwarmup) { time += f - s; } } UCCCHECK_GOTO(ucc_ee_executor_stop(executor), exit_err, st); if (niter != 0) { time /= niter; } return UCC_OK; stop_exec: ucc_ee_executor_stop(executor); exit_err: return st; } void ucc_pt_benchmark::print_header() { if (comm->get_rank() == 0) { std::ios iostate(nullptr); iostate.copyfmt(std::cout); std::cout << std::left << std::setw(24) << "Collective: " << ucc_pt_op_type_str(config.op_type) << std::endl; std::cout << std::left << std::setw(24) << "Memory type: " << ucc_memory_type_names[config.mt] << std::endl; std::cout << std::left << std::setw(24) << "Datatype: " << ucc_datatype_str(config.dt) << std::endl; std::cout << std::left << std::setw(24) << "Reduction: " << (coll->has_reduction() ? ucc_reduction_op_str(config.op): "N/A") << std::endl; std::cout << std::left << std::setw(24) << "Inplace: " << (coll->has_inplace() ? 
std::to_string(config.inplace): "N/A") << std::endl; std::cout << std::left << std::setw(24) << "Warmup:" << std::endl << std::left << std::setw(24) << " small" << config.n_warmup_small << std::endl << std::left << std::setw(24) << " large" << config.n_warmup_large << std::endl; std::cout << std::left << std::setw(24) << "Iterations:" << std::endl << std::left << std::setw(24) << " small" << config.n_iter_small << std::endl << std::left << std::setw(24) << " large" << config.n_iter_large << std::endl; std::cout.copyfmt(iostate); std::cout << std::endl; std::cout << std::setw(12) << "Count" << std::setw(12) << "Size" << std::setw(24) << "Time, us"; if (config.full_print) { std::cout << std::setw(42) << "Bus Bandwidth, GB/s"; } std::cout << std::endl; std::cout << std::setw(36) << "avg" << std::setw(12) << "min" << std::setw(12) << "max"; if (config.full_print) { std::cout << std::setw(12) << "avg" << std::setw(12) << "max" << std::setw(12) << "min"; } std::cout << std::endl; } } void ucc_pt_benchmark::print_time(size_t count, ucc_pt_test_args_t args, double time_avg, double time_min, double time_max) { size_t size = count * ucc_dt_size(config.dt); int gsize = comm->get_size(); if (comm->get_rank() == 0) { std::ios iostate(nullptr); iostate.copyfmt(std::cout); std::cout << std::setprecision(2) << std::fixed; std::cout << std::setw(12) << (coll->has_range() ? std::to_string(count): "N/A") << std::setw(12) << (coll->has_range() ? 
std::to_string(size): "N/A") << std::setw(12) << time_avg << std::setw(12) << time_min << std::setw(12) << time_max; if (config.full_print) { if (!coll->has_bw()) { std::cout << std::setw(12) << "N/A" << std::setw(12) << "N/A" << std::setw(12) << "N/A"; } else { if (config.op_type == UCC_PT_OP_TYPE_GATHER || config.op_type == UCC_PT_OP_TYPE_SCATTER) { std::cout << std::setw(12) << "N/A" << std::setw(12) << "N/A" << std::setw(12) << coll->get_bw(time_max, gsize, args); } else { std::cout << std::setw(12) << coll->get_bw(time_avg, gsize, args) << std::setw(12) << coll->get_bw(time_min, gsize, args) << std::setw(12) << coll->get_bw(time_max, gsize, args); } } } std::cout << std::endl; std::cout.copyfmt(iostate); } } ucc_pt_benchmark::~ucc_pt_benchmark() { delete coll; delete generator; } openucx-ucc-ec0bc8a/tools/perf/ucc_pt_coll_reduce.cc0000664000175000017500000000644415133731560023132 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include "ucc_pt_coll.h" #include "ucc_perftest.h" #include #include #include ucc_pt_coll_reduce::ucc_pt_coll_reduce(ucc_datatype_t dt, ucc_memory_type mt, ucc_reduction_op_t op, bool is_inplace, bool is_persistent, int root_shift, ucc_pt_comm *communicator, ucc_pt_generator_base *generator) : ucc_pt_coll(communicator, generator) { has_inplace_ = true; has_reduction_ = true; has_range_ = true; has_bw_ = true; root_shift_ = root_shift; coll_args.mask = 0; coll_args.flags = 0; coll_args.coll_type = UCC_COLL_TYPE_REDUCE; coll_args.op = op; coll_args.src.info.datatype = dt; coll_args.src.info.mem_type = mt; coll_args.dst.info.datatype = dt; coll_args.dst.info.mem_type = mt; if (is_inplace) { coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } if (is_persistent) { coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; } } ucc_status_t ucc_pt_coll_reduce::init_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; size_t dt_size = ucc_dt_size(coll_args.src.info.datatype); ucc_status_t st_src, st_dst; coll_args.root = test_args.coll_args.root; args = coll_args; args.src.info.count = generator->get_src_count(); args.dst.info.count = generator->get_dst_count(); bool is_root = (comm->get_rank() == args.root); if (is_root || root_shift_) { UCCCHECK_GOTO(ucc_pt_alloc(&dst_header, generator->get_dst_count() * dt_size, args.dst.info.mem_type), exit, st_dst); args.dst.info.buffer = dst_header->addr; } if (!is_root || !UCC_IS_INPLACE(args) || root_shift_) { UCCCHECK_GOTO(ucc_pt_alloc(&src_header, generator->get_src_count() * dt_size, args.src.info.mem_type), free_dst, st_src); args.src.info.buffer = src_header->addr; } return UCC_OK; free_dst: if ((is_root || root_shift_) && st_dst == UCC_OK) { ucc_pt_free(dst_header); } return st_src; exit: return st_dst; } void ucc_pt_coll_reduce::free_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; 
bool is_root = (comm->get_rank() == args.root); if (!is_root || !UCC_IS_INPLACE(args) || root_shift_) { ucc_pt_free(src_header); } if (is_root || root_shift_) { ucc_pt_free(dst_header); } } float ucc_pt_coll_reduce::get_bw(float time_ms, int grsize, ucc_pt_test_args_t test_args) { ucc_coll_args_t &args = test_args.coll_args; float S = args.src.info.count * ucc_dt_size(args.src.info.datatype); return S / time_ms / 1000.0; } openucx-ucc-ec0bc8a/tools/perf/ucc_pt_config.h0000664000175000017500000000772515133731560021764 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #ifndef UCC_PT_CONFIG_H #define UCC_PT_CONFIG_H #include "utils/ucc_log.h" #include #include #include #include #include #include #include #define UCC_PT_DEFAULT_N_BUFS 0 enum ucc_pt_bootstrap_type_t { UCC_PT_BOOTSTRAP_MPI, UCC_PT_BOOTSTRAP_UCX }; struct ucc_pt_bootstrap_config { ucc_pt_bootstrap_type_t bootstrap; }; struct ucc_pt_comm_config { ucc_memory_type_t mt; }; typedef enum { UCC_PT_OP_TYPE_ALLGATHER = UCC_COLL_TYPE_ALLGATHER, UCC_PT_OP_TYPE_ALLGATHERV = UCC_COLL_TYPE_ALLGATHERV, UCC_PT_OP_TYPE_ALLREDUCE = UCC_COLL_TYPE_ALLREDUCE, UCC_PT_OP_TYPE_ALLTOALL = UCC_COLL_TYPE_ALLTOALL, UCC_PT_OP_TYPE_ALLTOALLV = UCC_COLL_TYPE_ALLTOALLV, UCC_PT_OP_TYPE_BARRIER = UCC_COLL_TYPE_BARRIER, UCC_PT_OP_TYPE_BCAST = UCC_COLL_TYPE_BCAST, UCC_PT_OP_TYPE_FANIN = UCC_COLL_TYPE_FANIN, UCC_PT_OP_TYPE_FANOUT = UCC_COLL_TYPE_FANOUT, UCC_PT_OP_TYPE_GATHER = UCC_COLL_TYPE_GATHER, UCC_PT_OP_TYPE_GATHERV = UCC_COLL_TYPE_GATHERV, UCC_PT_OP_TYPE_REDUCE = UCC_COLL_TYPE_REDUCE, UCC_PT_OP_TYPE_REDUCE_SCATTER = UCC_COLL_TYPE_REDUCE_SCATTER, UCC_PT_OP_TYPE_REDUCE_SCATTERV = UCC_COLL_TYPE_REDUCE_SCATTERV, UCC_PT_OP_TYPE_SCATTER = UCC_COLL_TYPE_SCATTER, UCC_PT_OP_TYPE_SCATTERV = UCC_COLL_TYPE_SCATTERV, UCC_PT_OP_TYPE_MEMCPY = UCC_COLL_TYPE_LAST + 1, UCC_PT_OP_TYPE_REDUCEDT, UCC_PT_OP_TYPE_REDUCEDT_STRIDED, UCC_PT_OP_TYPE_LAST } 
ucc_pt_op_type_t; typedef enum { UCC_PT_MAP_TYPE_NONE, UCC_PT_MAP_TYPE_LOCAL, UCC_PT_MAP_TYPE_GLOBAL, UCC_PT_MAP_TYPE_LAST } ucc_pt_map_type_t; typedef enum { UCC_PT_GEN_TYPE_EXP, UCC_PT_GEN_TYPE_FILE, UCC_PT_GEN_TYPE_TRAFFIC_MATRIX } ucc_pt_gen_type_t; static inline const char* ucc_pt_op_type_str(ucc_pt_op_type_t op) { if ((uint64_t)op < (uint64_t)UCC_COLL_TYPE_LAST) { return ucc_coll_type_str((ucc_coll_type_t)op); } switch(op) { case UCC_PT_OP_TYPE_MEMCPY: return "Memcpy"; case UCC_PT_OP_TYPE_REDUCEDT: return "Reduce DT"; case UCC_PT_OP_TYPE_REDUCEDT_STRIDED: return "Reduce DT strided"; default: break; } return NULL; } struct ucc_pt_gen_traffic_matrix_config { int kind; int token_size_KB_mean; int token_size_KB_std; int tgt_group_size_mean; int tgt_group_size_std; int num_tokens; int num_hl_ranks; double bias_factor; }; struct ucc_pt_gen_config { ucc_pt_gen_type_t type; size_t nrep; // Number of repetitions for file/matrix-based generation std::string file_name; union { struct { size_t min; size_t max; } exp; ucc_pt_gen_traffic_matrix_config matrix; }; }; struct ucc_pt_benchmark_config { ucc_pt_op_type_t op_type; size_t min_count; size_t max_count; ucc_datatype_t dt; ucc_memory_type_t mt; ucc_reduction_op_t op; ucc_pt_map_type_t map_type; bool inplace; bool persistent; bool triggered; size_t large_thresh; int n_iter_small; int n_warmup_small; int n_iter_large; int n_warmup_large; int n_bufs; bool full_print; int root; int root_shift; int mult_factor; uint64_t seed; ucc_pt_gen_config gen; }; struct ucc_pt_config { ucc_pt_bootstrap_config bootstrap; ucc_pt_comm_config comm; ucc_pt_benchmark_config bench; ucc_pt_config(); ucc_status_t process_args(int argc, char *argv[]); void print_help(); }; #endif openucx-ucc-ec0bc8a/tools/perf/ucc_pt_op_memcpy.cc0000664000175000017500000000604215133731560022634 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include "ucc_pt_coll.h" #include "ucc_perftest.h" #include #include #include ucc_pt_op_memcpy::ucc_pt_op_memcpy(ucc_datatype_t dt, ucc_memory_type mt, int nbufs, ucc_pt_comm *communicator, ucc_pt_generator_base *generator) : ucc_pt_coll(communicator, generator) { has_inplace_ = false; has_reduction_ = false; has_range_ = true; has_bw_ = true; if (nbufs == UCC_PT_DEFAULT_N_BUFS) { nbufs = 1; } if (nbufs > UCC_EE_EXECUTOR_MULTI_OP_NUM_BUFS) { throw std::runtime_error("max supported number of copy buffer is " STR(UCC_EE_EXECUTOR_MULTI_OP_NUM_BUFS)); } data_type = dt; mem_type = mt; num_bufs = nbufs; } ucc_status_t ucc_pt_op_memcpy::init_args(ucc_pt_test_args_t &test_args) { ucc_ee_executor_task_args_t &args = test_args.executor_args; size_t dt_size = ucc_dt_size(data_type); size_t size = generator->get_src_count() * dt_size; ucc_status_t st; int i; UCCCHECK_GOTO(ucc_pt_alloc(&dst_header, num_bufs * size, mem_type), exit, st); UCCCHECK_GOTO(ucc_pt_alloc(&src_header, num_bufs * size, mem_type), free_dst, st); if (num_bufs == 1) { args.task_type = UCC_EE_EXECUTOR_TASK_COPY; args.copy.dst = dst_header->addr; args.copy.src = src_header->addr; args.copy.len = size; } else { args.task_type = UCC_EE_EXECUTOR_TASK_COPY_MULTI; args.copy_multi.num_vectors = num_bufs; for (i = 0; i < num_bufs; i++) { args.copy_multi.src[i] = PTR_OFFSET(src_header->addr, size * i); args.copy_multi.dst[i] = PTR_OFFSET(dst_header->addr, size * i); args.copy_multi.counts[i] = generator->get_src_count(); } } return UCC_OK; free_dst: ucc_pt_free(dst_header); exit: return st; } float ucc_pt_op_memcpy::get_bw(float time_ms, int grsize, ucc_pt_test_args_t test_args) { ucc_ee_executor_task_args_t &args = test_args.executor_args; float S; int i; if (args.task_type == UCC_EE_EXECUTOR_TASK_COPY) { S = args.copy.len; } else { S = 0; for (i = 0; i < args.copy_multi.num_vectors; i++) { S += args.copy_multi.counts[i]; } } return 2 * (S / time_ms) / 1000.0; } void ucc_pt_op_memcpy::free_args(ucc_pt_test_args_t 
&test_args) { ucc_pt_free(src_header); ucc_pt_free(dst_header); } openucx-ucc-ec0bc8a/tools/perf/ucc_perftest.cc0000664000175000017500000000232515133731560021775 0ustar alastairalastair#include #include "ucc_pt_comm.h" #include "ucc_pt_config.h" #include "ucc_pt_coll.h" #include "ucc_pt_cuda.h" #include "ucc_pt_rocm.h" #include "ucc_pt_benchmark.h" int main(int argc, char *argv[]) { ucc_pt_config pt_config; ucc_pt_comm *comm; ucc_pt_benchmark *bench; ucc_status_t st; pt_config.process_args(argc, argv); ucc_pt_cuda_init(); ucc_pt_rocm_init(); try { comm = new ucc_pt_comm(pt_config.comm); } catch(std::exception &e) { std::cerr << e.what() << std::endl; std::exit(1); } st = comm->init(); if (st != UCC_OK) { delete comm; std::exit(1); } try { bench = new ucc_pt_benchmark(pt_config.bench, comm); } catch(std::exception &e) { std::cerr << e.what() << std::endl; comm->finalize(); delete comm; std::exit(1); } st = bench->run_bench(); if (st != UCC_OK) { std::cerr << "Benchmark failed with status " << st << " " << ucc_status_string(st) << std::endl; delete bench; comm->finalize(); delete comm; std::exit(1); } delete bench; comm->finalize(); delete comm; return 0; } openucx-ucc-ec0bc8a/tools/perf/ucc_pt_coll_allreduce.cc0000664000175000017500000000563215133731560023621 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include "ucc_pt_coll.h" #include "ucc_perftest.h" #include #include #include ucc_pt_coll_allreduce::ucc_pt_coll_allreduce(ucc_datatype_t dt, ucc_memory_type mt, ucc_reduction_op_t op, bool is_inplace, bool is_persistent, ucc_pt_comm *communicator, ucc_pt_generator_base *generator) : ucc_pt_coll(communicator, generator) { has_inplace_ = true; has_reduction_ = true; has_range_ = true; has_bw_ = true; root_shift_ = 0; coll_args.mask = 0; coll_args.flags = 0; coll_args.coll_type = UCC_COLL_TYPE_ALLREDUCE; coll_args.op = op; coll_args.src.info.datatype = dt; coll_args.dst.info.datatype = dt; coll_args.src.info.mem_type = mt; coll_args.dst.info.mem_type = mt; if (is_inplace) { coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } if (is_persistent) { coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; } } ucc_status_t ucc_pt_coll_allreduce::init_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; size_t dt_size = ucc_dt_size(coll_args.src.info.datatype); ucc_status_t st = UCC_OK; args = coll_args; args.src.info.count = generator->get_src_count(); args.dst.info.count = generator->get_dst_count(); UCCCHECK_GOTO(ucc_pt_alloc(&dst_header, generator->get_dst_count() * dt_size, args.dst.info.mem_type), exit, st); args.dst.info.buffer = dst_header->addr; if (!UCC_IS_INPLACE(args)) { UCCCHECK_GOTO(ucc_pt_alloc(&src_header, generator->get_src_count() * dt_size, args.src.info.mem_type), free_dst, st); args.src.info.buffer = src_header->addr; } return UCC_OK; free_dst: ucc_pt_free(dst_header); exit: return st; } void ucc_pt_coll_allreduce::free_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; if (!UCC_IS_INPLACE(args)) { ucc_pt_free(src_header); } ucc_pt_free(dst_header); } float ucc_pt_coll_allreduce::get_bw(float time_ms, int grsize, ucc_pt_test_args_t test_args) { ucc_coll_args_t &args = test_args.coll_args; float N = grsize; 
float S = args.src.info.count * ucc_dt_size(args.src.info.datatype); return (S / time_ms) * (2 * (N - 1) / N) / 1000.0; } openucx-ucc-ec0bc8a/tools/perf/ucc_pt_coll_scatter.cc0000664000175000017500000000645215133731560023327 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #include "ucc_pt_coll.h" #include "ucc_perftest.h" #include #include #include ucc_pt_coll_scatter::ucc_pt_coll_scatter(ucc_datatype_t dt, ucc_memory_type mt, bool is_inplace, bool is_persistent, int root_shift, ucc_pt_comm *communicator, ucc_pt_generator_base *generator) : ucc_pt_coll(communicator, generator) { has_inplace_ = true; has_reduction_ = false; has_range_ = true; has_bw_ = true; root_shift_ = root_shift; coll_args.mask = 0; coll_args.flags = 0; coll_args.coll_type = UCC_COLL_TYPE_SCATTER; coll_args.src.info.datatype = dt; coll_args.src.info.mem_type = mt; coll_args.dst.info.datatype = dt; coll_args.dst.info.mem_type = mt; if (is_inplace) { coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } if (is_persistent) { coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; } } ucc_status_t ucc_pt_coll_scatter::init_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; size_t dt_size = ucc_dt_size(coll_args.dst.info.datatype); ucc_status_t st_src = UCC_OK, st_dst = UCC_OK; bool is_root; coll_args.root = test_args.coll_args.root; args = coll_args; args.dst.info.count = generator->get_dst_count(); is_root = (comm->get_rank() == args.root); if (is_root || root_shift_) { args.src.info.count = generator->get_src_count(); UCCCHECK_GOTO( ucc_pt_alloc(&src_header, generator->get_src_count() * dt_size, args.src.info.mem_type), exit, st_src); args.src.info.buffer = src_header->addr; } if (!is_root || !UCC_IS_INPLACE(args) || root_shift_) { UCCCHECK_GOTO(ucc_pt_alloc(&dst_header, 
generator->get_dst_count() * dt_size, args.dst.info.mem_type), free_src, st_dst); args.dst.info.buffer = dst_header->addr; } return UCC_OK; free_src: if ((is_root || root_shift_) && st_src == UCC_OK) { ucc_pt_free(src_header); } return st_dst; exit: return st_src; } float ucc_pt_coll_scatter::get_bw(float time_ms, int grsize, ucc_pt_test_args_t test_args) { ucc_coll_args_t &args = test_args.coll_args; float S = args.dst.info.count * ucc_dt_size(args.dst.info.datatype); float N = grsize - 1; return (S * N) / time_ms / 1000.0; } void ucc_pt_coll_scatter::free_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; bool is_root = (comm->get_rank() == args.root); if (!is_root || !UCC_IS_INPLACE(args) || root_shift_) { ucc_pt_free(dst_header); } if (is_root || root_shift_) { ucc_pt_free(src_header); } } openucx-ucc-ec0bc8a/tools/perf/ucc_pt_coll_alltoall.cc0000664000175000017500000002035315133731560023462 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include "ucc_pt_coll.h" #include "ucc_perftest.h" #include #include #include ucc_pt_coll_alltoall::ucc_pt_coll_alltoall(ucc_datatype_t dt, ucc_memory_type mt, bool is_inplace, bool is_persistent, ucc_pt_map_type_t map_type, ucc_pt_comm *communicator, ucc_pt_generator_base *generator) : ucc_pt_coll(communicator, generator) { size_t src_count_size = generator->get_src_count_max() * ucc_dt_size(dt); size_t dst_count_size = generator->get_dst_count_max() * ucc_dt_size(dt); ucc_status_t st; has_inplace_ = true; has_reduction_ = false; has_range_ = true; has_bw_ = true; root_shift_ = 0; UCCCHECK_GOTO(ucc_pt_alloc(&dst_header, dst_count_size, mt), exit, st); if (!is_inplace) { UCCCHECK_GOTO(ucc_pt_alloc(&src_header, src_count_size, mt), exit, st); } coll_args.mask = 0; coll_args.flags = 0; coll_args.coll_type = UCC_COLL_TYPE_ALLTOALL; coll_args.src.info.datatype = dt; coll_args.src.info.mem_type = mt; coll_args.dst.info.datatype = dt; coll_args.dst.info.mem_type = mt; coll_args.dst.info.buffer = dst_header->addr; if (!is_inplace) { coll_args.src.info.buffer = src_header->addr; } if (map_type == UCC_PT_MAP_TYPE_LOCAL) { ucc_context_h ctx = comm->get_context(); ucc_mem_map_t segments[1]; ucc_mem_map_params_t mem_map_params; size_t dst_memh_size, src_memh_size; mem_map_params.n_segments = 1; mem_map_params.segments = segments; mem_map_params.segments[0].address = dst_header->addr; mem_map_params.segments[0].len = dst_count_size; UCCCHECK_GOTO(ucc_mem_map(ctx, UCC_MEM_MAP_MODE_EXPORT, &mem_map_params, &dst_memh_size, &dst_memh), exit, st); coll_args.dst_memh.local_memh = dst_memh; coll_args.mask |= UCC_COLL_ARGS_FIELD_MEM_MAP_DST_MEMH; if (!is_inplace) { mem_map_params.segments[0].address = src_header->addr; mem_map_params.segments[0].len = src_count_size; UCCCHECK_GOTO(ucc_mem_map(ctx, UCC_MEM_MAP_MODE_EXPORT, &mem_map_params, &src_memh_size, &src_memh), exit, st); coll_args.src_memh.local_memh = src_memh; coll_args.mask |= UCC_COLL_ARGS_FIELD_MEM_MAP_SRC_MEMH; } } else 
if (map_type == UCC_PT_MAP_TYPE_GLOBAL) { ucc_context_h ctx = comm->get_context(); ucc_mem_map_t segments[1]; ucc_mem_map_params_t mem_map_params; uint64_t dst_memh_size, src_memh_size; uint64_t dst_memh_size_max, src_memh_size_max; coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags |= UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS; mem_map_params.n_segments = 1; mem_map_params.segments = segments; mem_map_params.segments[0].address = dst_header->addr; mem_map_params.segments[0].len = dst_count_size; UCCCHECK_GOTO(ucc_mem_map(ctx, UCC_MEM_MAP_MODE_EXPORT, &mem_map_params, &dst_memh_size, &dst_memh), exit, st); // Synchronize memh size across all ranks to ensure consistent buffer allocation comm->allreduce(&dst_memh_size, &dst_memh_size_max, 1, UCC_OP_MAX, UCC_DT_UINT64); dst_memh_global = new ucc_mem_map_mem_h[comm->get_size()]; for (int i = 0; i < comm->get_size(); i++) { dst_memh_global[i] = new char[dst_memh_size_max]; if (i == comm->get_rank()) { memcpy(dst_memh_global[i], dst_memh, dst_memh_size); } comm->bcast(dst_memh_global[i], dst_memh_size_max, i); } for (int i = 0; i < comm->get_size(); i++) { ucc_mem_map(ctx, UCC_MEM_MAP_MODE_IMPORT, &mem_map_params, &dst_memh_size_max, &dst_memh_global[i]); } coll_args.dst_memh.global_memh = dst_memh_global; coll_args.mask |= UCC_COLL_ARGS_FIELD_MEM_MAP_DST_MEMH; coll_args.flags |= UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL; if (!is_inplace) { mem_map_params.segments[0].address = src_header->addr; mem_map_params.segments[0].len = src_count_size; UCCCHECK_GOTO(ucc_mem_map(ctx, UCC_MEM_MAP_MODE_EXPORT, &mem_map_params, &src_memh_size, &src_memh), exit, st); // Synchronize memh size across all ranks to ensure consistent buffer allocation comm->allreduce(&src_memh_size, &src_memh_size_max, 1, UCC_OP_MAX, UCC_DT_UINT64); src_memh_global = new ucc_mem_map_mem_h[comm->get_size()]; for (int i = 0; i < comm->get_size(); i++) { src_memh_global[i] = new char[src_memh_size_max]; if (i == comm->get_rank()) { memcpy(src_memh_global[i], 
src_memh, src_memh_size); } comm->bcast(src_memh_global[i], src_memh_size_max, i); } for (int i = 0; i < comm->get_size(); i++) { ucc_mem_map(ctx, UCC_MEM_MAP_MODE_IMPORT, &mem_map_params, &src_memh_size_max, &src_memh_global[i]); } coll_args.src_memh.global_memh = src_memh_global; coll_args.mask |= UCC_COLL_ARGS_FIELD_MEM_MAP_SRC_MEMH; coll_args.flags |= UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL; } } else if (map_type != UCC_PT_MAP_TYPE_NONE) { std::cerr << "unsupported map type for perftest alltoall" << std::endl; goto exit; } coll_args.global_work_buffer = comm->get_onesided_buf(); coll_args.mask |= UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER; if (is_inplace) { coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; } if (is_persistent) { coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; } return; exit: if (dst_header) { ucc_pt_free(dst_header); dst_header = NULL; } if (src_header) { ucc_pt_free(src_header); src_header = NULL; } throw std::runtime_error("failed to initialize alltoall arguments"); } ucc_status_t ucc_pt_coll_alltoall::init_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; args = coll_args; args.dst.info.count = generator->get_dst_count(); if (!UCC_IS_INPLACE(args)) { args.src.info.count = generator->get_src_count(); } return UCC_OK; } float ucc_pt_coll_alltoall::get_bw(float time_ms, int grsize, ucc_pt_test_args_t test_args) { ucc_coll_args_t &args = test_args.coll_args; float N = grsize; float S = args.src.info.count * ucc_dt_size(args.src.info.datatype); return (S / time_ms) * ((N - 1) / N) / 1000.0; } ucc_pt_coll_alltoall::~ucc_pt_coll_alltoall() { if (src_memh) { ucc_mem_unmap(&src_memh); } if (dst_memh) { ucc_mem_unmap(&dst_memh); } if (src_header) { ucc_pt_free(src_header); } if (dst_header) { ucc_pt_free(dst_header); } if (dst_memh_global) { for (int i = 0; i < comm->get_size(); i++) { if (dst_memh_global[i]) { 
ucc_mem_unmap(&dst_memh_global[i]); delete[] static_cast(dst_memh_global[i]); } } delete[] dst_memh_global; } if (src_memh_global) { for (int i = 0; i < comm->get_size(); i++) { if (src_memh_global[i]) { ucc_mem_unmap(&src_memh_global[i]); delete[] static_cast(src_memh_global[i]); } } delete[] src_memh_global; } } openucx-ucc-ec0bc8a/tools/perf/ucc_pt_cuda.cc0000664000175000017500000000312715133731560021561 0ustar alastairalastair/** * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #include "ucc_pt_cuda.h" #include #include #include ucc_pt_cuda_iface_t ucc_pt_cuda_iface = { .available = 0, }; #define LOAD_CUDA_SYM(_sym, _pt_sym) ({ \ void *h = dlsym(handle, _sym); \ if (dlerror() != NULL) { \ return; \ } \ ucc_pt_cuda_iface. _pt_sym = \ reinterpret_cast(h); \ }) void ucc_pt_cuda_init(void) { void *handle; handle = dlopen ("libcudart.so", RTLD_LAZY); if (!handle) { return; } LOAD_CUDA_SYM("cudaGetDeviceCount", getDeviceCount); LOAD_CUDA_SYM("cudaSetDevice", setDevice); LOAD_CUDA_SYM("cudaGetErrorString", getErrorString); LOAD_CUDA_SYM("cudaStreamCreateWithFlags", streamCreateWithFlags); LOAD_CUDA_SYM("cudaStreamDestroy", streamDestroy); LOAD_CUDA_SYM("cudaMalloc", cudaMalloc); LOAD_CUDA_SYM("cudaFree", cudaFree); LOAD_CUDA_SYM("cudaMemset", cudaMemset); LOAD_CUDA_SYM("cudaMallocManaged", cudaMallocManaged); LOAD_CUDA_SYM("cudaGetDeviceProperties", cudaGetDeviceProperties); LOAD_CUDA_SYM("cudaDeviceGetPCIBusId", cudaDeviceGetPCIBusId); ucc_pt_cuda_iface.available = 1; } openucx-ucc-ec0bc8a/tools/perf/ucc_pt_coll_barrier.cc0000664000175000017500000000167515133731560023312 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include "ucc_pt_coll.h" #include "ucc_perftest.h" #include #include #include ucc_pt_coll_barrier::ucc_pt_coll_barrier(ucc_pt_comm *communicator, ucc_pt_generator_base *generator) : ucc_pt_coll(communicator, generator) { has_inplace_ = false; has_reduction_ = false; has_range_ = false; has_bw_ = false; root_shift_ = 0; coll_args.mask = 0; coll_args.coll_type = UCC_COLL_TYPE_BARRIER; } ucc_status_t ucc_pt_coll_barrier::init_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; args = coll_args; return UCC_OK; } void ucc_pt_coll_barrier::free_args(ucc_pt_test_args_t &test_args) { return; } openucx-ucc-ec0bc8a/tools/perf/ucc_pt_coll_allgatherv.cc0000664000175000017500000000513215133731560024005 0ustar alastairalastair/** * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #include "ucc_pt_coll.h" #include "ucc_perftest.h" #include #include #include ucc_pt_coll_allgatherv::ucc_pt_coll_allgatherv(ucc_datatype_t dt, ucc_memory_type mt, bool is_inplace, bool is_persistent, ucc_pt_comm *communicator, ucc_pt_generator_base *generator) : ucc_pt_coll(communicator, generator) { has_inplace_ = true; has_reduction_ = false; has_range_ = true; has_bw_ = false; root_shift_ = 0; coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags = UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER; coll_args.coll_type = UCC_COLL_TYPE_ALLGATHERV; coll_args.src.info.datatype = dt; coll_args.src.info.mem_type = mt; coll_args.dst.info_v.datatype = dt; coll_args.dst.info_v.mem_type = mt; if (is_inplace) { coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; } if (is_persistent) { coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; } } ucc_status_t ucc_pt_coll_allgatherv::init_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; size_t dt_size = ucc_dt_size(coll_args.src.info.datatype); ucc_status_t st; 
args = coll_args; args.dst.info_v.counts = generator->get_dst_counts(); args.dst.info_v.displacements = generator->get_dst_displs(); UCCCHECK_GOTO(ucc_pt_alloc(&dst_header, generator->get_dst_count() * dt_size, args.dst.info_v.mem_type), exit, st); args.dst.info_v.buffer = dst_header->addr; if (!UCC_IS_INPLACE(args)) { args.src.info.count = generator->get_src_count(); UCCCHECK_GOTO( ucc_pt_alloc(&src_header, generator->get_src_count() * dt_size, args.src.info.mem_type), free_dst, st); args.src.info.buffer = src_header->addr; } return UCC_OK; free_dst: ucc_pt_free(dst_header); exit: return st; } void ucc_pt_coll_allgatherv::free_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; if (!UCC_IS_INPLACE(args)) { ucc_pt_free(src_header); } ucc_pt_free(dst_header); } openucx-ucc-ec0bc8a/tools/perf/ucc_pt_coll_bcast.cc0000664000175000017500000000416415133731560022754 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include "ucc_pt_coll.h" #include "ucc_perftest.h" #include #include #include ucc_pt_coll_bcast::ucc_pt_coll_bcast(ucc_datatype_t dt, ucc_memory_type mt, int root_shift, bool is_persistent, ucc_pt_comm *communicator, ucc_pt_generator_base *generator) : ucc_pt_coll(communicator, generator) { has_inplace_ = false; has_reduction_ = false; has_range_ = true; has_bw_ = true; root_shift_ = root_shift; coll_args.mask = 0; coll_args.flags = 0; coll_args.coll_type = UCC_COLL_TYPE_BCAST; coll_args.src.info.datatype = dt; coll_args.src.info.mem_type = mt; if (is_persistent) { coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; } } ucc_status_t ucc_pt_coll_bcast::init_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; size_t dt_size = ucc_dt_size(coll_args.src.info.datatype); ucc_status_t st; coll_args.root = test_args.coll_args.root; args = coll_args; args.src.info.count = generator->get_src_count(); UCCCHECK_GOTO(ucc_pt_alloc(&src_header, generator->get_src_count() * dt_size, args.src.info.mem_type), exit, st); args.src.info.buffer = src_header->addr; exit: return st; } void ucc_pt_coll_bcast::free_args(ucc_pt_test_args_t &test_args) { ucc_pt_free(src_header); } float ucc_pt_coll_bcast::get_bw(float time_ms, int grsize, ucc_pt_test_args_t test_args) { ucc_coll_args_t &args = test_args.coll_args; float S = args.src.info.count * ucc_dt_size(args.src.info.datatype); return S / time_ms / 1000.0; } openucx-ucc-ec0bc8a/tools/perf/ucc_pt_coll_allgather.cc0000664000175000017500000001101715133731560023616 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include "ucc_pt_coll.h" #include "ucc_perftest.h" #include #include #include ucc_pt_coll_allgather::ucc_pt_coll_allgather(ucc_datatype_t dt, ucc_memory_type mt, bool is_inplace, bool is_persistent, ucc_pt_map_type_t map_type, ucc_pt_comm *communicator, ucc_pt_generator_base *generator) : ucc_pt_coll(communicator, generator) { size_t src_count_size = generator->get_src_count_max() * ucc_dt_size(dt); size_t dst_count_size = generator->get_dst_count_max() * ucc_dt_size(dt); ucc_status_t st; has_inplace_ = true; has_reduction_ = false; has_range_ = true; has_bw_ = true; root_shift_ = 0; UCCCHECK_GOTO(ucc_pt_alloc(&dst_header, dst_count_size, mt), exit, st); if (!is_inplace) { UCCCHECK_GOTO(ucc_pt_alloc(&src_header, src_count_size, mt), exit, st); } coll_args.mask = 0; coll_args.flags = 0; coll_args.coll_type = UCC_COLL_TYPE_ALLGATHER; coll_args.src.info.datatype = dt; coll_args.src.info.mem_type = mt; coll_args.dst.info.buffer = dst_header->addr; coll_args.dst.info.datatype = dt; coll_args.dst.info.mem_type = mt; if (is_inplace) { coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; } else { coll_args.src.info.buffer = src_header->addr; } if (is_persistent) { coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; } if (map_type == UCC_PT_MAP_TYPE_LOCAL) { ucc_context_h ctx = comm->get_context(); ucc_mem_map_t segments[1]; ucc_mem_map_params_t mem_map_params; size_t dst_memh_size, src_memh_size; mem_map_params.n_segments = 1; mem_map_params.segments = segments; mem_map_params.segments[0].address = dst_header->addr; mem_map_params.segments[0].len = dst_count_size; UCCCHECK_GOTO(ucc_mem_map(ctx, UCC_MEM_MAP_MODE_EXPORT, &mem_map_params, &dst_memh_size, &dst_memh), exit, st); coll_args.dst_memh.local_memh = dst_memh; coll_args.mask |= UCC_COLL_ARGS_FIELD_MEM_MAP_DST_MEMH; if (!is_inplace) { mem_map_params.segments[0].address = src_header->addr; mem_map_params.segments[0].len = 
src_count_size; UCCCHECK_GOTO(ucc_mem_map(ctx, UCC_MEM_MAP_MODE_EXPORT, &mem_map_params, &src_memh_size, &src_memh), exit, st); coll_args.src_memh.local_memh = src_memh; coll_args.mask |= UCC_COLL_ARGS_FIELD_MEM_MAP_SRC_MEMH; } } else if (map_type != UCC_PT_MAP_TYPE_NONE) { std::cerr << "Only local mapping is supported for perftest allgather" << std::endl; goto exit; } return; exit: if (dst_header) { ucc_pt_free(dst_header); dst_header = NULL; } if (src_header) { ucc_pt_free(src_header); src_header = NULL; } throw std::runtime_error("failed to initialize allgather arguments"); } ucc_status_t ucc_pt_coll_allgather::init_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; size_t single_rank_count = generator->get_src_count(); args = coll_args; args.dst.info.count = single_rank_count * comm->get_size(); if (!UCC_IS_INPLACE(args)) { args.src.info.count = single_rank_count; } return UCC_OK; } float ucc_pt_coll_allgather::get_bw(float time_ms, int grsize, ucc_pt_test_args_t test_args) { ucc_coll_args_t &args = test_args.coll_args; float N = grsize; float S = args.dst.info.count * ucc_dt_size(args.dst.info.datatype); return (S / time_ms) * ((N - 1) / N) / 1000.0; } ucc_pt_coll_allgather::~ucc_pt_coll_allgather() { if (src_header) { ucc_pt_free(src_header); } if (dst_header) { ucc_pt_free(dst_header); } if (src_memh) { ucc_mem_unmap(&src_memh); } if (dst_memh) { ucc_mem_unmap(&dst_memh); } }openucx-ucc-ec0bc8a/tools/perf/ucc_pt_rocm.h0000664000175000017500000000303215133731560021442 0ustar alastairalastair/** * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (C) Advanced Micro Devices, Inc. 2022. ALL RIGHTS RESERVED. * * See file LICENSE for terms. 
*/ #ifndef UCC_PT_ROCM_H #define UCC_PT_ROCM_H #include typedef struct ucc_pt_rocm_iface { int available; int (*getDeviceCount)(int* count); int (*setDevice)(int device); char* (*getErrorString)(int err); } ucc_pt_rocm_iface_t; extern ucc_pt_rocm_iface_t ucc_pt_rocm_iface; void ucc_pt_rocm_init(void); #define hipSuccess 0 #define STR(x) #x #define HIP_CHECK(_call) \ do { \ int _status = (_call); \ if (hipSuccess != _status) { \ std::cerr << "UCC perftest error: " << \ ucc_pt_rocm_iface.getErrorString(_status) \ << " in " << STR(_call) << "\n"; \ return _status; \ } \ } while (0) static inline int ucc_pt_rocmGetDeviceCount(int *count) { if (!ucc_pt_rocm_iface.available) { return 1; } HIP_CHECK(ucc_pt_rocm_iface.getDeviceCount(count)); return 0; } static inline int ucc_pt_rocmSetDevice(int device) { if (!ucc_pt_rocm_iface.available) { return 1; } HIP_CHECK(ucc_pt_rocm_iface.setDevice(device)); return 0; } #endif openucx-ucc-ec0bc8a/tools/perf/ucc_pt_bootstrap.h0000664000175000017500000000437315133731560022530 0ustar alastairalastair/** * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/
#ifndef UCC_PT_BOOTSTRAP_H
#define UCC_PT_BOOTSTRAP_H

#include
#include
#include
#include
#include

/*
 * Abstract bootstrap layer of the perftest: holds the out-of-band
 * collective descriptors handed to UCC and derives, from a per-node hash,
 * how many ranks share this node (ppn) and this rank's index among them
 * (local_rank). Subclasses supply get_rank() and get_size().
 */
class ucc_pt_bootstrap {
protected:
    size_t node_hash;                   /* compared across ranks in find_ppn() */
    ucc_context_oob_coll_t context_oob;
    ucc_team_oob_coll_t team_oob;
    int ppn;                            /* -1 until find_ppn() has run */
    int local_rank;                     /* -1 until find_ppn() has run */

    /*
     * Gathers node_hash from every rank through team_oob.allgather and
     * derives both values: local_rank counts the ranks below this one with
     * the same hash, ppn counts all ranks with the same hash.
     */
    void find_ppn() {
        int comm_size = get_size();
        int comm_rank = get_rank();
        size_t *hashes = new size_t[comm_size];
        ucc_status_t st;
        void *req;

        ppn = 0;
        local_rank = 0;
        st = team_oob.allgather(&node_hash, hashes, sizeof(node_hash), team_oob.coll_info, &req);
        if (st != UCC_OK) {
            goto exit_err;
        }
        /* poll the out-of-band request until it leaves the in-progress state */
        do {
            st = team_oob.req_test(req);
        } while (st == UCC_INPROGRESS);
        if (st != UCC_OK) {
            goto exit_err;
        }
        team_oob.req_free(req);

        /* ranks before this one on the same node */
        for (int i = 0; i < comm_size; i++) {
            if (i == comm_rank) {
                break;
            }
            if (node_hash == hashes[i]) {
                local_rank++;
            }
        }
        /* all ranks on the same node, this one included */
        for (int i = 0; i < comm_size; i++) {
            if (node_hash == hashes[i]) {
                ppn++;
            }
        }
        delete[] hashes;
        return;
exit_err:
        // NOTE(review): the text between the error message and the hash
        // expression below is missing from this copy of the file (everything
        // between a '<' and a '>' was lost). What remains looks like the end
        // of find_ppn()'s error path followed by the tail of a constructor
        // that hashes the host name into node_hash and resets ppn/local_rank
        // to -1 -- restore from the upstream file before relying on it.
        std::cerr <<"failed to find ppn" <{}(std::string(hostname));
        ppn = -1;
        local_rank = -1;
    }

    virtual int get_rank() = 0;
    virtual int get_size() = 0;

    /* Number of ranks on this node; computed on first use. */
    int get_ppn() {
        if (ppn == -1) {
            find_ppn();
        }
        return ppn;
    }

    /* Index of this rank among the ranks on this node; computed on first use. */
    int get_local_rank() {
        if (local_rank == -1) {
            find_ppn();
        }
        return local_rank;
    }

    virtual ~ucc_pt_bootstrap() {};

    /* Returns a copy of the context-level out-of-band descriptor. */
    ucc_context_oob_coll_t get_context_oob() {
        return context_oob;
    }

    /* Returns a copy of the team-level out-of-band descriptor. */
    ucc_team_oob_coll_t get_team_oob() {
        return team_oob;
    }
};

#endif
openucx-ucc-ec0bc8a/tools/perf/ucc_pt_coll_alltoallv.cc0000664000175000017500000000734315133731560023654 0ustar alastairalastair
/**
 * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See file LICENSE for terms.
*/ #include "ucc_pt_coll.h" #include "ucc_perftest.h" #include #include #include ucc_pt_coll_alltoallv::ucc_pt_coll_alltoallv(ucc_datatype_t dt, ucc_memory_type mt, bool is_inplace, bool is_persistent, ucc_pt_comm *communicator, ucc_pt_generator_base *generator) : ucc_pt_coll(communicator, generator) { size_t src_count_max = generator->get_src_count_max(); size_t dst_count_max = generator->get_dst_count_max(); ucc_status_t st; has_inplace_ = true; has_reduction_ = false; has_range_ = true; has_bw_ = true; root_shift_ = 0; UCCCHECK_GOTO(ucc_pt_alloc(&dst_header, dst_count_max * ucc_dt_size(dt), mt), exit, st); if (!is_inplace) { UCCCHECK_GOTO(ucc_pt_alloc(&src_header, src_count_max * ucc_dt_size(dt), mt), exit, st); } coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll_args.coll_type = UCC_COLL_TYPE_ALLTOALLV; coll_args.dst.info_v.datatype = dt; coll_args.dst.info_v.mem_type = mt; coll_args.dst.info_v.buffer = dst_header->addr; coll_args.flags = UCC_COLL_ARGS_FLAG_CONTIG_SRC_BUFFER | UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER; if (is_inplace) { coll_args.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; } else { coll_args.src.info_v.buffer = src_header->addr; coll_args.src.info_v.datatype = dt; coll_args.src.info_v.mem_type = mt; } if (is_persistent) { coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; } return; exit: if (dst_header) { ucc_pt_free(dst_header); dst_header = NULL; } if (src_header) { ucc_pt_free(src_header); src_header = NULL; } throw std::runtime_error("failed to initialize alltoallv arguments"); } ucc_status_t ucc_pt_coll_alltoallv::init_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; args = coll_args; args.src.info_v.counts = (ucc_count_t *) generator->get_src_counts(); args.src.info_v.displacements = (ucc_aint_t *) generator->get_src_displs(); args.dst.info_v.counts = (ucc_count_t *) generator->get_dst_counts(); args.dst.info_v.displacements = (ucc_aint_t *) generator->get_dst_displs(); return UCC_OK; } float 
ucc_pt_coll_alltoallv::get_bw(float time_ms, int grsize, ucc_pt_test_args_t test_args) { ucc_coll_args_t &args = test_args.coll_args; float N = grsize; float S = 0; size_t src_size = 0, dst_size = 0; for (int i = 0; i < grsize; i++) { src_size += ucc_coll_args_get_count(&args, args.src.info_v.counts, i); dst_size += ucc_coll_args_get_count(&args, args.dst.info_v.counts, i); } src_size *= ucc_dt_size(args.src.info_v.datatype); dst_size *= ucc_dt_size(args.dst.info_v.datatype); S = src_size > dst_size ? src_size : dst_size; return (S / time_ms) * ((N - 1) / N) / 1000.0; } ucc_pt_coll_alltoallv::~ucc_pt_coll_alltoallv() { if (src_header) { ucc_pt_free(src_header); } if (dst_header) { ucc_pt_free(dst_header); } }openucx-ucc-ec0bc8a/tools/perf/ucc_pt_coll_gatherv.cc0000664000175000017500000000624715133731560023324 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #include "ucc_pt_coll.h" #include "ucc_perftest.h" #include #include #include ucc_pt_coll_gatherv::ucc_pt_coll_gatherv(ucc_datatype_t dt, ucc_memory_type mt, bool is_inplace, bool is_persistent, int root_shift, ucc_pt_comm *communicator, ucc_pt_generator_base *generator) : ucc_pt_coll(communicator, generator) { has_inplace_ = true; has_reduction_ = false; has_range_ = true; has_bw_ = false; root_shift_ = root_shift; coll_args.mask = 0; coll_args.flags = 0; coll_args.coll_type = UCC_COLL_TYPE_GATHERV; coll_args.src.info.datatype = dt; coll_args.src.info.mem_type = mt; coll_args.dst.info_v.datatype = dt; coll_args.dst.info_v.mem_type = mt; if (is_inplace) { coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } if (is_persistent) { coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; } } ucc_status_t ucc_pt_coll_gatherv::init_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; size_t dt_size = 
ucc_dt_size(coll_args.src.info.datatype); ucc_status_t st; bool is_root; coll_args.root = test_args.coll_args.root; args = coll_args; is_root = (comm->get_rank() == args.root); if (is_root || root_shift_) { args.dst.info_v.counts = generator->get_dst_counts(); args.dst.info_v.displacements = generator->get_dst_displs(); UCCCHECK_GOTO(ucc_pt_alloc(&dst_header, generator->get_dst_count() * dt_size, args.dst.info_v.mem_type), exit, st); args.dst.info_v.buffer = dst_header->addr; } if (!is_root || !UCC_IS_INPLACE(args) || root_shift_) { args.src.info.count = generator->get_src_count(); st = ucc_pt_alloc(&src_header, generator->get_src_count() * dt_size, args.src.info.mem_type); if (UCC_OK != st) { std::cerr << "UCC perftest error: " << ucc_status_string(st) << " in " << STR(_call) << "\n"; if (is_root || root_shift_) { goto free_dst; } else { goto exit; } } args.src.info.buffer = src_header->addr; } return UCC_OK; free_dst: ucc_pt_free(dst_header); exit: return st; } void ucc_pt_coll_gatherv::free_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; bool is_root = (comm->get_rank() == args.root); if (!is_root || !UCC_IS_INPLACE(args) || root_shift_) { ucc_pt_free(src_header); } if (is_root || root_shift_) { ucc_pt_free(dst_header); } } openucx-ucc-ec0bc8a/tools/perf/ucc_perftest.h0000664000175000017500000000317315133731560021641 0ustar alastairalastair/** * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include #include #include "config.h" extern "C" { #include "utils/ucc_malloc.h" } #define STR(x) #x #define UCCCHECK_GOTO(_call, _label, _status) \ do { \ _status = (_call); \ if (UCC_OK != _status) { \ std::cerr << "UCC perftest error: " << ucc_status_string(_status) \ << " in " << STR(_call) << __FILE__ << ":" \ << __LINE__<< "\n"; \ goto _label; \ } \ } while (0) #define UCC_MALLOC_CHECK_GOTO(_obj, _label, _status) \ do { \ if (!(_obj)) { \ _status = UCC_ERR_NO_MEMORY; \ std::cerr << "UCC perftest error: " << ucc_status_string(_status) \ << "\n"; \ goto _label; \ } \ } while (0) openucx-ucc-ec0bc8a/tools/perf/ucc_pt_cuda.h0000664000175000017500000001057615133731560021431 0ustar alastairalastair/** * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #ifndef UCC_PT_CUDA_H #define UCC_PT_CUDA_H #include #include #define cudaSuccess 0 #define cudaStreamNonBlocking 0x01 /**< Stream does not synchronize with stream 0 (the NULL stream) */ #define cudaMemAttachGlobal 0x01 /**< Memory can be accessed by any stream on any device*/ typedef struct CUStream_st *cudaStream_t; typedef struct cudaDeviceProp { char name[256]; char padding[2048]; /* take extra space to avoid future changes, real size is 1080 - 256 = 824 on x86 */ } cudaDeviceProp; #define STR(x) #x #define CUDA_CHECK(_call) \ do { \ int _status = (_call); \ if (cudaSuccess != _status) { \ std::cerr << "UCC perftest error: " << \ ucc_pt_cuda_iface.getErrorString(_status) \ << " in " << STR(_call) << "\n"; \ return _status; \ } \ } while (0) typedef struct ucc_pt_cuda_iface { int available; int (*getDeviceCount)(int* count); int (*setDevice)(int device); int (*getDeviceInfo)(std::string &info); int (*streamCreateWithFlags)(cudaStream_t *stream, unsigned int flags); int (*streamDestroy)(cudaStream_t stream); char* (*getErrorString)(int err); int (*cudaMalloc)(void **devptr, size_t size); int (*cudaMallocManaged)(void **ptr, size_t size, unsigned int 
flags); int (*cudaFree)(void *devptr); int (*cudaMemset)(void *devptr, int value, size_t count); int (*cudaGetDeviceProperties)(void *prop, int device); int (*cudaDeviceGetPCIBusId)(char *pciBusId, int len, int device); } ucc_pt_cuda_iface_t; extern ucc_pt_cuda_iface_t ucc_pt_cuda_iface; void ucc_pt_cuda_init(void); static inline int ucc_pt_cudaGetDeviceCount(int *count) { if (!ucc_pt_cuda_iface.available) { return 1; } CUDA_CHECK(ucc_pt_cuda_iface.getDeviceCount(count)); return 0; } static inline int ucc_pt_cudaSetDevice(int device) { if (!ucc_pt_cuda_iface.available) { return 1; } CUDA_CHECK(ucc_pt_cuda_iface.setDevice(device)); return 0; } static inline int ucc_pt_cudaStreamCreateWithFlags(cudaStream_t *stream, unsigned int flags) { if (!ucc_pt_cuda_iface.available) { return 1; } CUDA_CHECK(ucc_pt_cuda_iface.streamCreateWithFlags(stream, flags)); return 0; } static inline int ucc_pt_cudaStreamDestroy(cudaStream_t stream) { if (!ucc_pt_cuda_iface.available) { return 1; } CUDA_CHECK(ucc_pt_cuda_iface.streamDestroy(stream)); return 0; } static inline int ucc_pt_cudaMalloc(void **devptr, size_t size) { if (!ucc_pt_cuda_iface.available) { return 1; } CUDA_CHECK(ucc_pt_cuda_iface.cudaMalloc(devptr, size)); return 0; } static inline int ucc_pt_cudaMallocManaged(void **ptr, size_t size) { if (!ucc_pt_cuda_iface.available) { return 1; } CUDA_CHECK(ucc_pt_cuda_iface.cudaMallocManaged(ptr, size, cudaMemAttachGlobal)); return 0; } static inline int ucc_pt_cudaFree(void *devptr) { if (!ucc_pt_cuda_iface.available) { return 1; } CUDA_CHECK(ucc_pt_cuda_iface.cudaFree(devptr)); return 0; } static inline int ucc_pt_cudaMemset(void *devptr, int value, size_t count) { if (!ucc_pt_cuda_iface.available) { return 1; } CUDA_CHECK(ucc_pt_cuda_iface.cudaMemset(devptr, value, count)); return 0; } static inline int ucc_pt_cudaGetDeviceInfo(int device, std::string &info) { char pciBusId[256]; char hostname[256]; cudaDeviceProp prop; if (!ucc_pt_cuda_iface.available) { return 1; } if 
(gethostname(hostname, sizeof(hostname)) == 0) { info.append(hostname); info.append(" - "); } CUDA_CHECK(ucc_pt_cuda_iface.cudaGetDeviceProperties(&prop, device)); info.append(prop.name); info.append(" "); CUDA_CHECK(ucc_pt_cuda_iface.cudaDeviceGetPCIBusId(pciBusId, sizeof(pciBusId), device)); info.append(pciBusId); return 0; } #endif openucx-ucc-ec0bc8a/tools/perf/ucc_pt_comm.cc0000664000175000017500000002365015133731560021603 0ustar alastairalastair#include #include #include #include #include "ucc_pt_comm.h" #include "ucc_pt_bootstrap_mpi.h" #include "ucc_perftest.h" #include "ucc_pt_cuda.h" #include "ucc_pt_rocm.h" extern "C" { #include "utils/ucc_coll_utils.h" #include "components/mc/ucc_mc.h" } ucc_pt_comm::ucc_pt_comm(ucc_pt_comm_config config) { cfg = config; bootstrap = new ucc_pt_bootstrap_mpi(); } ucc_pt_comm::~ucc_pt_comm() { delete bootstrap; } void ucc_pt_comm::set_gpu_device() { int dev_count = 0; if (ucc_pt_cudaGetDeviceCount(&dev_count) == 0 && dev_count != 0) { int dev = bootstrap->get_local_rank() % dev_count; ucc_pt_cudaSetDevice(dev); std::string info; ucc_pt_cudaGetDeviceInfo(dev, info); std::cout << std::left << std::setw(8) << "Rank " << std::setw(8) << bootstrap->get_rank() << ": " << info << std::endl; std::cout << std::right; return; } if (ucc_pt_rocmGetDeviceCount(&dev_count) == 0 && dev_count != 0) { ucc_pt_rocmSetDevice(bootstrap->get_local_rank() % dev_count); } return; } int ucc_pt_comm::get_rank() { return bootstrap->get_rank(); } int ucc_pt_comm::get_size() { return bootstrap->get_size(); } ucc_ee_h ucc_pt_comm::get_ee() { ucc_ee_params_t ee_params; ucc_status_t status; if (!ee) { if (cfg.mt == UCC_MEMORY_TYPE_CUDA) { if (ucc_pt_cudaStreamCreateWithFlags((cudaStream_t*)&stream, cudaStreamNonBlocking)) { throw std::runtime_error("failed to create CUDA stream"); } ee_params.ee_type = UCC_EE_CUDA_STREAM; ee_params.ee_context_size = sizeof(cudaStream_t); ee_params.ee_context = stream; status = ucc_ee_create(team, &ee_params, &ee); if 
(status != UCC_OK) { std::cerr << "failed to create UCC EE: " << ucc_status_string(status); ucc_pt_cudaStreamDestroy((cudaStream_t)stream); throw std::runtime_error(ucc_status_string(status)); } } else { std::cerr << "execution engine is not supported for given memory type" << std::endl; throw std::runtime_error("not supported"); } } return ee; } ucc_ee_executor_t *ucc_pt_comm::get_executor() { ucc_ee_executor_params_t executor_params; ucc_status_t status; if (!executor) { executor_params.mask = UCC_EE_EXECUTOR_PARAM_FIELD_TYPE; if (cfg.mt == UCC_MEMORY_TYPE_HOST) { executor_params.ee_type = UCC_EE_CPU_THREAD; } else if ( cfg.mt == UCC_MEMORY_TYPE_CUDA || cfg.mt == UCC_MEMORY_TYPE_CUDA_MANAGED) { executor_params.ee_type = UCC_EE_CUDA_STREAM; } else if (cfg.mt == UCC_MEMORY_TYPE_ROCM) { executor_params.ee_type = UCC_EE_ROCM_STREAM; } else { std::cerr << "executor is not supported for given memory type" << std::endl; throw std::runtime_error("not supported"); } status = ucc_ee_executor_init(&executor_params, &executor); if (status != UCC_OK) { throw std::runtime_error("failed to init executor"); } } return executor; } ucc_team_h ucc_pt_comm::get_team() { return team; } ucc_context_h ucc_pt_comm::get_context() { return context; } ucc_status_t ucc_pt_comm::init() { ucc_lib_config_h lib_config; ucc_context_config_h ctx_config; ucc_lib_params_t lib_params; ucc_context_params_t ctx_params; ucc_team_params_t team_params; ucc_status_t st; std::string cfg_mod; ee = nullptr; executor = nullptr; stream = nullptr; onesided_buf = nullptr; if (cfg.mt != UCC_MEMORY_TYPE_HOST) { set_gpu_device(); } UCCCHECK_GOTO(ucc_lib_config_read("PERFTEST", nullptr, &lib_config), exit_err, st); std::memset(&lib_params, 0, sizeof(ucc_lib_params_t)); lib_params.mask = UCC_LIB_PARAM_FIELD_THREAD_MODE; lib_params.thread_mode = UCC_THREAD_SINGLE; UCCCHECK_GOTO(ucc_init(&lib_params, lib_config, &lib), free_lib_config, st); if (UCC_OK != ucc_mc_available(cfg.mt)) { std::cerr << "selected memory type " 
<< ucc_mem_type_str(cfg.mt) << " is not available" << std::endl; return UCC_ERR_INVALID_PARAM; } onesided_buf = ucc_calloc(1024, bootstrap->get_size(), "onesided_buf"); UCCCHECK_GOTO(ucc_context_config_read(lib, NULL, &ctx_config), free_lib, st); cfg_mod = std::to_string(bootstrap->get_size()); UCCCHECK_GOTO(ucc_context_config_modify(ctx_config, NULL, "ESTIMATED_NUM_EPS", cfg_mod.c_str()), free_ctx_config, st); cfg_mod = std::to_string(bootstrap->get_ppn()); UCCCHECK_GOTO(ucc_context_config_modify(ctx_config, NULL, "ESTIMATED_NUM_PPN", cfg_mod.c_str()), free_ctx_config, st); cfg_mod = std::to_string(bootstrap->get_local_rank()); UCCCHECK_GOTO(ucc_context_config_modify(ctx_config, NULL, "NODE_LOCAL_ID", cfg_mod.c_str()), free_ctx_config, st); std::memset(&ctx_params, 0, sizeof(ucc_context_params_t)); ctx_params.mask = UCC_CONTEXT_PARAM_FIELD_TYPE | UCC_CONTEXT_PARAM_FIELD_OOB | UCC_CONTEXT_PARAM_FIELD_MEM_PARAMS; ctx_params.type = UCC_CONTEXT_SHARED; ctx_params.oob = bootstrap->get_context_oob(); ucc_mem_map_t map_segments[1]; map_segments[0].address = onesided_buf; map_segments[0].len = 1024; ctx_params.mem_params.segments = map_segments; ctx_params.mem_params.n_segments = 1; UCCCHECK_GOTO(ucc_context_create(lib, &ctx_params, ctx_config, &context), free_ctx_config, st); team_params.mask = UCC_TEAM_PARAM_FIELD_EP | UCC_TEAM_PARAM_FIELD_EP_RANGE | UCC_TEAM_PARAM_FIELD_OOB; team_params.oob = bootstrap->get_team_oob(); team_params.ep = bootstrap->get_rank(); team_params.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG; UCCCHECK_GOTO(ucc_team_create_post(&context, 1, &team_params, &team), free_ctx, st); do { st = ucc_team_create_test(team); if (st == UCC_INPROGRESS) { ucc_context_progress(context); } } while(st == UCC_INPROGRESS); UCCCHECK_GOTO(st, free_ctx, st); ucc_context_config_release(ctx_config); ucc_lib_config_release(lib_config); return UCC_OK; free_ctx: ucc_context_destroy(context); free_ctx_config: ucc_context_config_release(ctx_config); free_lib: ucc_finalize(lib); 
free_lib_config: ucc_lib_config_release(lib_config); exit_err: return st; } ucc_status_t ucc_pt_comm::finalize() { ucc_status_t status; if (ee) { status = ucc_ee_destroy(ee); if (status != UCC_OK) { std::cerr << "ucc ee destroy error: " << ucc_status_string(status); } if (cfg.mt == UCC_MEMORY_TYPE_CUDA) { ucc_pt_cudaStreamDestroy((cudaStream_t)stream); } else { std::cerr << "execution engine is not supported for given memory type" << std::endl; throw std::runtime_error("not supported"); } } if (executor) { status = ucc_ee_executor_finalize(executor); if (status != UCC_OK) { std::cerr << "ucc executor finalize error: " << ucc_status_string(status); } } do { status = ucc_team_destroy(team); if (status == UCC_INPROGRESS) { ucc_context_progress(context); } } while (status == UCC_INPROGRESS); if (status != UCC_OK) { std::cerr << "ucc team destroy error: " << ucc_status_string(status); } ucc_context_destroy(context); if (onesided_buf) { ucc_free(onesided_buf); } ucc_finalize(lib); return UCC_OK; } ucc_status_t ucc_pt_comm::barrier() { ucc_coll_args_t args; ucc_coll_req_h req; args.mask = 0; args.coll_type = UCC_COLL_TYPE_BARRIER; ucc_collective_init(&args, &req, team); ucc_collective_post(req); do { ucc_context_progress(context); } while (ucc_collective_test(req) == UCC_INPROGRESS); ucc_collective_finalize(req); return UCC_OK; } ucc_status_t ucc_pt_comm::allreduce(void* in, void* out, size_t size, ucc_reduction_op_t op, ucc_datatype_t dt) { ucc_coll_args_t args; ucc_coll_req_h req; args.mask = 0; args.coll_type = UCC_COLL_TYPE_ALLREDUCE; args.op = op; args.src.info.buffer = in; args.src.info.count = size; args.src.info.datatype = dt; args.src.info.mem_type = UCC_MEMORY_TYPE_HOST; args.dst.info.buffer = out; args.dst.info.count = size; args.dst.info.datatype = dt; args.dst.info.mem_type = UCC_MEMORY_TYPE_HOST; ucc_collective_init(&args, &req, team); ucc_collective_post(req); do { ucc_context_progress(context); } while (ucc_collective_test(req) == UCC_INPROGRESS); 
ucc_collective_finalize(req); return UCC_OK; } ucc_status_t ucc_pt_comm::bcast(void *data, size_t size, int root) { ucc_coll_args_t args; ucc_coll_req_h req; args.mask = 0; args.coll_type = UCC_COLL_TYPE_BCAST; args.src.info.buffer = data; args.src.info.count = size; args.src.info.datatype = UCC_DT_INT8; args.src.info.mem_type = UCC_MEMORY_TYPE_HOST; args.root = root; ucc_collective_init(&args, &req, team); ucc_collective_post(req); do { ucc_context_progress(context); } while (ucc_collective_test(req) == UCC_INPROGRESS); ucc_collective_finalize(req); return UCC_OK; } void *ucc_pt_comm::get_onesided_buf() { return onesided_buf; }openucx-ucc-ec0bc8a/tools/perf/Makefile.am0000664000175000017500000000270615133731560021037 0ustar alastairalastair# # Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # $COPYRIGHT$ # # Additional copyrights may follow # # $HEADER$ # bin_PROGRAMS = ucc_perftest ucc_perftest_SOURCES = \ ucc_perftest.cc \ ucc_pt_config.cc \ ucc_pt_comm.cc \ ucc_pt_cuda.cc \ ucc_pt_rocm.cc \ ucc_pt_benchmark.cc \ ucc_pt_bootstrap_mpi.cc \ ucc_pt_coll.cc \ ucc_pt_coll_allgather.cc \ ucc_pt_coll_allgatherv.cc \ ucc_pt_coll_allreduce.cc \ ucc_pt_coll_alltoall.cc \ ucc_pt_coll_alltoallv.cc \ ucc_pt_coll_barrier.cc \ ucc_pt_coll_bcast.cc \ ucc_pt_coll_gather.cc \ ucc_pt_coll_gatherv.cc \ ucc_pt_coll_reduce.cc \ ucc_pt_coll_reduce_scatter.cc \ ucc_pt_coll_reduce_scatterv.cc \ ucc_pt_coll_scatter.cc \ ucc_pt_coll_scatterv.cc \ ucc_pt_op_memcpy.cc \ ucc_pt_op_reduce.cc \ ucc_pt_op_reduce_strided.cc \ generator/ucc_pt_generator_exp.cc \ generator/ucc_pt_generator_file.cc \ generator/ucc_pt_generator_traffic_matrix.cc CXX=$(MPICXX) LD=$(MPICXX) ucc_perftest_CPPFLAGS = $(BASE_CPPFLAGS) ucc_perftest_CXXFLAGS = -std=gnu++11 $(BASE_CXXFLAGS) ucc_perftest_LDFLAGS = -Wl,--rpath-link=${UCS_LIBDIR} ucc_perftest_LDADD = $(UCC_TOP_BUILDDIR)/src/libucc.la -ldl 
openucx-ucc-ec0bc8a/tools/perf/ucc_pt_op_reduce.cc0000664000175000017500000000513315133731560022611 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #include "ucc_pt_coll.h" #include "ucc_perftest.h" #include #include #include ucc_pt_op_reduce::ucc_pt_op_reduce(ucc_datatype_t dt, ucc_memory_type mt, ucc_reduction_op_t op, int nbufs, ucc_pt_comm *communicator, ucc_pt_generator_base *generator) : ucc_pt_coll(communicator, generator) { has_inplace_ = false; has_reduction_ = true; has_range_ = true; has_bw_ = true; if (nbufs == UCC_PT_DEFAULT_N_BUFS) { nbufs = 2; } if (nbufs < 2) { throw std::runtime_error("dt reduce op requires at least 2 bufs"); } if (nbufs > UCC_EE_EXECUTOR_NUM_BUFS) { throw std::runtime_error("dt reduce op supports up to " + std::to_string(UCC_EE_EXECUTOR_NUM_BUFS) + " bufs"); } data_type = dt; mem_type = mt; reduce_op = op; num_bufs = nbufs; } ucc_status_t ucc_pt_op_reduce::init_args(ucc_pt_test_args_t &test_args) { ucc_ee_executor_task_args_t &args = test_args.executor_args; size_t dt_size = ucc_dt_size(data_type); size_t size = generator->get_src_count() * dt_size; ucc_status_t st; int i; UCCCHECK_GOTO(ucc_pt_alloc(&dst_header, size, mem_type), exit, st); UCCCHECK_GOTO(ucc_pt_alloc(&src_header, size * num_bufs, mem_type), free_dst, st); args.task_type = UCC_EE_EXECUTOR_TASK_REDUCE; args.reduce.dst = dst_header->addr; args.reduce.n_srcs = num_bufs; args.reduce.count = generator->get_src_count(); args.reduce.dt = data_type; args.reduce.op = reduce_op; args.flags = 0; for (i = 0; i < num_bufs; i++) { args.reduce.srcs[i] = PTR_OFFSET(src_header->addr, i * size); } return UCC_OK; free_dst: ucc_pt_free(dst_header); exit: return st; } float ucc_pt_op_reduce::get_bw(float time_ms, int grsize, ucc_pt_test_args_t test_args) { ucc_ee_executor_task_args_t &args = test_args.executor_args; float S = args.reduce.count * ucc_dt_size(data_type); return (num_bufs + 1) * 
(S / time_ms) / 1000.0; } void ucc_pt_op_reduce::free_args(ucc_pt_test_args_t &test_args) { ucc_pt_free(src_header); ucc_pt_free(dst_header); } openucx-ucc-ec0bc8a/tools/perf/ucc_pt_comm.h0000664000175000017500000000217115133731560021440 0ustar alastairalastair/** * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #ifndef UCC_PT_COMM_H #define UCC_PT_COMM_H #include #include "ucc_pt_config.h" #include "ucc_pt_bootstrap.h" #include "ucc_pt_bootstrap_mpi.h" extern "C" { #include "components/ec/ucc_ec.h" } class ucc_pt_comm { ucc_pt_comm_config cfg; ucc_lib_h lib; ucc_context_h context; ucc_team_h team; void *stream; ucc_ee_h ee; ucc_ee_executor_t *executor; ucc_pt_bootstrap *bootstrap; void set_gpu_device(); void *onesided_buf; public: ucc_pt_comm(ucc_pt_comm_config config); int get_rank(); int get_size(); ucc_ee_executor_t* get_executor(); ucc_ee_h get_ee(); ucc_team_h get_team(); ucc_context_h get_context(); void *get_onesided_buf(); ~ucc_pt_comm(); ucc_status_t init(); ucc_status_t barrier(); ucc_status_t allreduce(void* in, void *out, size_t size, ucc_reduction_op_t op, ucc_datatype_t dt); ucc_status_t bcast(void *data, size_t size, int root); ucc_status_t finalize(); }; #endif openucx-ucc-ec0bc8a/tools/perf/ucc_pt_coll_gather.cc0000664000175000017500000000637615133731560023141 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include "ucc_pt_coll.h" #include "ucc_perftest.h" #include #include #include ucc_pt_coll_gather::ucc_pt_coll_gather(ucc_datatype_t dt, ucc_memory_type mt, bool is_inplace, bool is_persistent, int root_shift, ucc_pt_comm *communicator, ucc_pt_generator_base *generator): ucc_pt_coll(communicator, generator) { has_inplace_ = true; has_reduction_ = false; has_range_ = true; has_bw_ = true; root_shift_ = root_shift; coll_args.mask = 0; coll_args.flags = 0; coll_args.coll_type = UCC_COLL_TYPE_GATHER; coll_args.src.info.datatype = dt; coll_args.src.info.mem_type = mt; coll_args.dst.info.datatype = dt; coll_args.dst.info.mem_type = mt; if (is_inplace) { coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } if (is_persistent) { coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; } } ucc_status_t ucc_pt_coll_gather::init_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; size_t dt_size = ucc_dt_size(coll_args.src.info.datatype); ucc_status_t st_src, st_dst; bool is_root; coll_args.root = test_args.coll_args.root; args = coll_args; args.dst.info.count = generator->get_dst_count(); args.src.info.count = generator->get_src_count(); is_root = (comm->get_rank() == args.root); if (is_root || root_shift_) { UCCCHECK_GOTO(ucc_pt_alloc(&dst_header, generator->get_dst_count() * dt_size, args.dst.info.mem_type), exit, st_dst); args.dst.info.buffer = dst_header->addr; } if (!is_root || !UCC_IS_INPLACE(args) || root_shift_) { UCCCHECK_GOTO( ucc_pt_alloc(&src_header, generator->get_src_count() * dt_size, args.src.info.mem_type), free_dst, st_src); args.src.info.buffer = src_header->addr; } return UCC_OK; free_dst: if ((is_root || root_shift_) && st_dst == UCC_OK) { ucc_pt_free(dst_header); } return st_src; exit: return st_dst; } float ucc_pt_coll_gather::get_bw(float time_ms, int grsize, ucc_pt_test_args_t test_args) { ucc_coll_args_t &args = test_args.coll_args; float N = 
grsize - 1; float S = args.src.info.count * ucc_dt_size(args.src.info.datatype); return (S * N) / time_ms / 1000.0; } void ucc_pt_coll_gather::free_args(ucc_pt_test_args_t &test_args) { ucc_coll_args_t &args = test_args.coll_args; bool is_root = (comm->get_rank() == args.root); if (!is_root || !UCC_IS_INPLACE(args) || root_shift_) { ucc_pt_free(src_header); } if (is_root || root_shift_) { ucc_pt_free(dst_header); } } openucx-ucc-ec0bc8a/tools/perf/ucc_pt_benchmark.h0000664000175000017500000000225715133731560022444 0ustar alastairalastair/** * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #ifndef UCC_PT_BENCH_H #define UCC_PT_BENCH_H #include "ucc_pt_config.h" #include "ucc_pt_coll.h" #include "generator/ucc_pt_generator.h" #include "ucc_pt_comm.h" #include "utils/ucc_coll_utils.h" #include class ucc_pt_benchmark { ucc_pt_benchmark_config config; ucc_pt_comm *comm; ucc_pt_coll *coll; ucc_pt_generator_base *generator; void print_header(); void print_time(size_t count, ucc_pt_test_args_t args, double time_avg, double time_min, double time_max); public: ucc_pt_benchmark(ucc_pt_benchmark_config cfg, ucc_pt_comm *communicator); ucc_status_t run_bench() noexcept; ucc_status_t run_single_coll_test(ucc_coll_args_t args, int nwarmup, int niter, double &time) noexcept; ucc_status_t run_single_executor_test(ucc_ee_executor_task_args_t args, int nwarmup, int niter, double &time) noexcept; ~ucc_pt_benchmark(); }; #endif openucx-ucc-ec0bc8a/tools/perf/ucc_pt_bootstrap_mpi.cc0000664000175000017500000000315115133731560023524 0ustar alastairalastair#include "ucc_pt_bootstrap_mpi.h" static ucc_status_t mpi_oob_allgather(void *sbuf, void *rbuf, size_t msglen, void *coll_info, void **req) { MPI_Comm comm = (MPI_Comm)(uintptr_t)coll_info; MPI_Request request; MPI_Iallgather(sbuf, msglen, MPI_BYTE, rbuf, msglen, MPI_BYTE, comm, &request); *req = (void *)(uintptr_t)request; return UCC_OK; } static ucc_status_t 
mpi_oob_allgather_test(void *req) { MPI_Request request = (MPI_Request)(uintptr_t)req; int completed; MPI_Test(&request, &completed, MPI_STATUS_IGNORE); return completed ? UCC_OK : UCC_INPROGRESS; } static ucc_status_t mpi_oob_allgather_free(void *req) { return UCC_OK; } ucc_pt_bootstrap_mpi::ucc_pt_bootstrap_mpi() { MPI_Init(NULL, NULL); MPI_Comm_rank(MPI_COMM_WORLD, &rank); MPI_Comm_size(MPI_COMM_WORLD, &size); context_oob.coll_info = (void*)(uintptr_t)MPI_COMM_WORLD; context_oob.allgather = mpi_oob_allgather; context_oob.req_test = mpi_oob_allgather_test; context_oob.req_free = mpi_oob_allgather_free; context_oob.n_oob_eps = size; context_oob.oob_ep = rank; team_oob.coll_info = (void*)(uintptr_t)MPI_COMM_WORLD; team_oob.allgather = mpi_oob_allgather; team_oob.req_test = mpi_oob_allgather_test; team_oob.req_free = mpi_oob_allgather_free; team_oob.n_oob_eps = size; team_oob.oob_ep = rank; } int ucc_pt_bootstrap_mpi::get_rank() { return rank; } int ucc_pt_bootstrap_mpi::get_size() { return size; } ucc_pt_bootstrap_mpi::~ucc_pt_bootstrap_mpi() { MPI_Finalize(); } openucx-ucc-ec0bc8a/.github/0000775000175000017500000000000015133731560016242 5ustar alastairalastairopenucx-ucc-ec0bc8a/.github/workflows/0000775000175000017500000000000015133731560020277 5ustar alastairalastairopenucx-ucc-ec0bc8a/.github/workflows/asan-test.yaml0000664000175000017500000000311715133731560023064 0ustar alastairalastairname: ASAN Tests on: [push, pull_request] env: OPEN_UCX_LINK: https://github.com/openucx/ucx OPEN_UCX_BRANCH: master CLANG_VER: 17 jobs: gtest-asan: runs-on: ubuntu-22.04 steps: - name: Install dependencies run: | sudo apt-get update sudo apt-get install -y --no-install-recommends wget gpg # Setup LLVM repository sudo mkdir -p /etc/apt/keyrings wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo gpg --dearmor -o /etc/apt/keyrings/llvm.gpg echo "deb [signed-by=/etc/apt/keyrings/llvm.gpg] http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VER} main" | sudo tee 
/etc/apt/sources.list.d/llvm.list sudo apt-get update sudo apt-get install -y --no-install-recommends clang-${CLANG_VER} clang++-${CLANG_VER} libclang-rt-${CLANG_VER}-dev - name: Get UCX run: git clone ${OPEN_UCX_LINK} -b ${OPEN_UCX_BRANCH} /tmp/ucx - name: Build UCX run: | cd /tmp/ucx && ./autogen.sh CC=clang-${CLANG_VER} CXX=clang++-${CLANG_VER} ./contrib/configure-release --without-java --without-go --disable-numa --prefix $PWD/install make -j install - uses: actions/checkout@v1 - name: Run gtest ASAN run: | export ASAN_OPTIONS=fast_unwind_on_malloc=0:detect_leaks=1:print_suppressions=0 export LSAN_OPTIONS=report_objects=1 ./autogen.sh CFLAGS="-fsanitize=address" CC=clang-${CLANG_VER} CXX=clang++-${CLANG_VER} ./configure --prefix=/tmp/ucc/install --with-ucx=/tmp/ucx/install --enable-gtest make -j install ./test/gtest/gtest openucx-ucc-ec0bc8a/.github/workflows/main.yaml0000664000175000017500000001110115133731560022101 0ustar alastairalastairname: OpenMPI tests on: [push, pull_request] env: OPEN_UCX_LINK: https://github.com/openucx/ucx OPEN_UCX_BRANCH: master OPEN_MPI_LINK: https://github.com/open-mpi/ompi.git OPEN_MPI_BRANCH: v5.0.x IMB_LINK: https://github.com/intel/mpi-benchmarks.git IMB_COLLS: allgather,allgatherv,allreduce,alltoall,alltoallv,barrier,bcast,gather,gatherv,reduce,reduce_scatter,reduce_scatter_block,scatter,scatterv jobs: tests: runs-on: ubuntu-latest steps: - name: Install dependencies run: | sudo apt-get update sudo apt-get install -y --no-install-recommends doxygen doxygen-latex - name: Get UCX run: git clone ${OPEN_UCX_LINK} -b ${OPEN_UCX_BRANCH} /tmp/ucx - name: Build UCX run: | cd /tmp/ucx && ./autogen.sh ./contrib/configure-release --without-java --without-go --disable-numa --prefix $PWD/install make -j install - uses: actions/checkout@v1 - name: Build UCC run: | ./autogen.sh ./configure --prefix=/tmp/ucc/install --enable-gtest --with-ucx=/tmp/ucx/install make -j`nproc` install make gtest - name: Run ucc_info run: | 
/tmp/ucc/install/bin/ucc_info -vc - name: Run CMake tests run: | set -e cmake -S test/cmake -B /tmp/cmake-ucc -DCMAKE_PREFIX_PATH=/tmp/ucc/install cd /tmp/cmake-ucc cmake --build . ./test_ucc - name: Get OMPI run: | git clone ${OPEN_MPI_LINK} -b ${OPEN_MPI_BRANCH} /tmp/ompi cd /tmp/ompi git submodule update --init --recursive - name: Build OMPI run: > cd /tmp/ompi ./autogen.pl --exclude pml-cm,mtl,coll-adapt,coll-han,coll-inter,coll-ftagree ./configure --prefix=/tmp/ompi/install --enable-mpirun-prefix-by-default --disable-mpi-fortran --disable-man-pages --with-ucx=/tmp/ucx/install --with-ucc=/tmp/ucc/install make -j install - name: Build ucc_perftest (with OMPI) run: | set -e CC=/tmp/ompi/install/bin/mpicc CXX=/tmp/ompi/install/bin/mpicxx \ ./configure --prefix=/tmp/ucc/install --enable-gtest --with-ucx=/tmp/ucx/install --with-mpi=/tmp/ompi/install make -C tools/perf -j`nproc` make -C tools/perf install - name: Run ucc_perftest run: | set -e test -x /tmp/ucc/install/bin/ucc_perftest export LD_LIBRARY_PATH=/tmp/ucc/install/lib:/tmp/ucx/install/lib:/tmp/ompi/install/lib:$LD_LIBRARY_PATH COLLS=( allgather allgatherv allreduce alltoall alltoallv barrier bcast gather gatherv reduce reduce_scatter reduce_scatterv scatterv ) for c in "${COLLS[@]}"; do echo "Running ucc_perftest -c ${c}" /tmp/ompi/install/bin/mpirun \ -np 4 -H localhost:4 \ --bind-to none \ --mca pml ucx \ --mca pml_ucx_tls any \ --mca pml_ucx_devices any \ --mca coll_ucc_enable 0 \ -x LD_LIBRARY_PATH \ -x UCC_LOG_LEVEL=info \ -x UCC_TLS=ucp \ -x UCC_CONFIG_FILE= \ /tmp/ucc/install/bin/ucc_perftest \ -c "${c}" -m host -d float32 -b 1024 -e 1024 -n 10 -w 2 done INPLACE_COLLS=( allgather allgatherv allreduce gather gatherv reduce reduce_scatter reduce_scatterv ) for c in "${INPLACE_COLLS[@]}"; do echo "Running ucc_perftest (inplace) -c ${c}" /tmp/ompi/install/bin/mpirun \ -np 4 -H localhost:4 \ --bind-to none \ --mca pml ucx \ --mca pml_ucx_tls any \ --mca pml_ucx_devices any \ --mca coll_ucc_enable 0 \ -x 
LD_LIBRARY_PATH \ -x UCC_LOG_LEVEL=info \ -x UCC_TLS=ucp \ -x UCC_CONFIG_FILE= \ /tmp/ucc/install/bin/ucc_perftest \ -c "${c}" -m host -d float32 -b 1024 -e 1024 -n 10 -w 2 -i done - name: Get IMB run: git clone ${IMB_LINK} /tmp/imb - name: Build IMB run: | cd /tmp/imb make CC=/tmp/ompi/install/bin/mpicc CXX=/tmp/ompi/install/bin/mpicxx CPPFLAGS="-DCHECK=1" -j IMB-MPI1 - name: Run IMB-DCHECK run: > /tmp/ompi/install/bin/mpirun -np 8 -H localhost:8 --bind-to none --mca pml ucx --mca pml_ucx_tls any --mca pml_ucx_devices any --mca coll_ucc_priority 100 --mca coll_ucc_enable 1 /tmp/imb/IMB-MPI1 ${IMB_COLLS} -iter 10 -iter_policy off openucx-ucc-ec0bc8a/.github/workflows/docs.yaml0000664000175000017500000000114015133731560022107 0ustar alastairalastairname: Docs on: [push, pull_request] jobs: docs: runs-on: ubuntu-latest steps: - name: Install dependencies run: | sudo apt-get update sudo apt-get install -y --no-install-recommends doxygen doxygen-latex cm-super texlive-fonts-recommended - uses: actions/checkout@v1 - name: Build UCC docs run: | ./autogen.sh ./configure --with-docs-only make docs - name: Upload docs uses: actions/upload-artifact@v4 with: name: docs path: ${{ github.workspace }}/docs/doxygen-doc/ucc.pdf retention-days: 7 openucx-ucc-ec0bc8a/.github/workflows/clang-tidy-rocm.yaml0000664000175000017500000000537615133731560024167 0ustar alastairalastairname: Linter-ROCM on: [push, pull_request] env: OPEN_UCX_LINK: https://github.com/openucx/ucx OPEN_UCX_BRANCH: master CLANG_VER: 17 ROCM_VER: 5.6.1 LIBRARY_PATH: /tmp/ucx/install/lib LD_LIBRARY_PATH: /tmp/ucx/install/lib jobs: clang-tidy: runs-on: ubuntu-22.04 steps: - name: Install dependencies run: | sudo apt-get update # Install basic dependencies sudo apt-get install -y --no-install-recommends wget gpg # Setup LLVM repository sudo mkdir -p /etc/apt/keyrings wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo gpg --dearmor -o /etc/apt/keyrings/llvm.gpg echo "deb [signed-by=/etc/apt/keyrings/llvm.gpg] 
http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VER} main" | sudo tee /etc/apt/sources.list.d/llvm.list # Setup ROCm repository wget https://repo.radeon.com/rocm/rocm.gpg.key -O - | sudo gpg --dearmor -o /etc/apt/keyrings/rocm.gpg echo "deb [arch=amd64 signed-by=/etc/apt/keyrings/rocm.gpg] https://repo.radeon.com/rocm/apt/${ROCM_VER} jammy main" | sudo tee /etc/apt/sources.list.d/rocm.list echo -e 'Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' | sudo tee /etc/apt/preferences.d/rocm-pin-600 # Update PATH for ROCm echo 'export PATH=$PATH:/opt/rocm/bin:/opt/rocm/llvm/bin' | sudo tee -a /etc/profile.d/rocm.sh source /etc/profile.d/rocm.sh # Install all required packages sudo apt-get update sudo apt-get install -y --no-install-recommends \ clang-tidy-${CLANG_VER} \ bear \ rocm-hip-sdk sudo ln -sf /opt/rocm-${ROCM_VER} /opt/rocm - name: Get UCX run: git clone ${OPEN_UCX_LINK} -b ${OPEN_UCX_BRANCH} /tmp/ucx - name: Build UCX run: | cd /tmp/ucx && ./autogen.sh CC=gcc CXX=g++ ./contrib/configure-release --without-java --without-go --disable-numa --prefix $PWD/install --with-rocm=/opt/rocm make -j install - uses: actions/checkout@v1 - name: Build UCC run: | ./autogen.sh CC=clang-${CLANG_VER} CXX=clang++-${CLANG_VER} ./configure --prefix=/tmp/ucc/install --with-ucx=/tmp/ucx/install --with-rocm=/opt/rocm --with-rccl=/opt/rocm --enable-assert bear --output /tmp/compile_commands.json -- make -j - name: Run clang-tidy run: | echo "Workspace: ${GITHUB_WORKSPACE}" cd ${GITHUB_WORKSPACE} run-clang-tidy-${CLANG_VER} -p /tmp/ -header-filter='^(?!.*(${GITHUB_WORKSPACE}/src/components/ec/rocm/kernel/.*|${GITHUB_WORKSPACE}/src/components/mc/rocm/kernel/.*)).*$' "^(?!.*\.cu$).*$" 2>&1 | tee /tmp/clang_tidy.log nerrors=$(grep 'error:' /tmp/clang_tidy.log | wc -l) if [ $nerrors -ne 0 ]; then exit 125; fi openucx-ucc-ec0bc8a/.github/workflows/blossom-ci.yml0000664000175000017500000000510315133731560023070 0ustar alastairalastair# A workflow to trigger ci on hybrid 
infra (github + self hosted runner) name: Blossom-CI run-name: > ${{ github.event_name == 'workflow_dispatch' && format( 'Blossom CI • Jenkins Job: {0}{1} • PR #{2}', fromJson(github.event.inputs.args).job, fromJson(github.event.inputs.args).build && format(' #{0}', fromJson(github.event.inputs.args).build) || '', fromJson(github.event.inputs.args).pr ) || '' }} on: issue_comment: types: [created] workflow_dispatch: inputs: platform: description: 'runs-on argument' required: false args: description: 'argument' required: false jobs: Authorization: name: Authorization runs-on: blossom outputs: args: ${{ env.args }} # This job only runs for pull request comments if: github.event.comment.body == '/build' steps: - name: Check if comment is issued by authorized person run: blossom-ci env: OPERATION: 'AUTH' REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }} REPO_KEY_DATA: ${{ secrets.BLOSSOM_KEY }} Vulnerability-scan: name: Vulnerability scan needs: [Authorization] runs-on: ubuntu-latest steps: - name: Checkout code uses: actions/checkout@v2 with: repository: ${{ fromJson(needs.Authorization.outputs.args).repo }} ref: ${{ fromJson(needs.Authorization.outputs.args).ref }} lfs: 'true' - name: Run blossom action uses: NVIDIA/blossom-action@main env: REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }} REPO_KEY_DATA: ${{ secrets.BLOSSOM_KEY }} with: args1: ${{ fromJson(needs.Authorization.outputs.args).args1 }} args2: ${{ fromJson(needs.Authorization.outputs.args).args2 }} args3: ${{ fromJson(needs.Authorization.outputs.args).args3 }} Job-trigger: name: Start ci job needs: [Vulnerability-scan] runs-on: blossom steps: - name: Start ci job run: blossom-ci env: OPERATION: 'START-CI-JOB' CI_SERVER: ${{ secrets.CI_SERVER }} REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }} Upload-Log: name: Upload log runs-on: blossom if: github.event_name == 'workflow_dispatch' steps: - name: Jenkins log for pull request ${{ fromJson(github.event.inputs.args).pr }} (click here) run: blossom-ci env: OPERATION: 'POST-PROCESSING' 
CI_SERVER: ${{ secrets.CI_SERVER }} REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }} openucx-ucc-ec0bc8a/.github/workflows/clang-tidy.yaml0000664000175000017500000000331115133731560023214 0ustar alastairalastairname: Linter on: [push, pull_request] env: OPEN_UCX_LINK: https://github.com/openucx/ucx OPEN_UCX_BRANCH: master CLANG_VER: 17 jobs: clang-tidy: runs-on: ubuntu-22.04 steps: - name: Install dependencies run: | sudo apt-get update sudo apt-get install -y --no-install-recommends wget gpg # Setup LLVM repository sudo mkdir -p /etc/apt/keyrings wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo gpg --dearmor -o /etc/apt/keyrings/llvm.gpg echo "deb [signed-by=/etc/apt/keyrings/llvm.gpg] http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VER} main" | sudo tee /etc/apt/sources.list.d/llvm.list sudo apt-get update sudo apt-get install -y --no-install-recommends clang-tidy-${CLANG_VER} bear - name: Get UCX run: git clone ${OPEN_UCX_LINK} -b ${OPEN_UCX_BRANCH} /tmp/ucx - name: Build UCX run: | cd /tmp/ucx && ./autogen.sh CC=clang-${CLANG_VER} CXX=clang++-${CLANG_VER} ./contrib/configure-release --without-java --without-go --disable-numa --prefix $PWD/install make -j install - uses: actions/checkout@v1 - name: Build UCC run: | ./autogen.sh CC=clang-${CLANG_VER} CXX=clang++-${CLANG_VER} ./configure --prefix=/tmp/ucc/install --with-ucx=/tmp/ucx/install --enable-assert bear --output /tmp/compile_commands.json -- make -j - name: Run clang-tidy run: | echo "Workspace: ${GITHUB_WORKSPACE}" cd ${GITHUB_WORKSPACE} run-clang-tidy-${CLANG_VER} -p /tmp/ 2>&1 | tee /tmp/clang_tidy.log nerrors=$(grep 'error:' /tmp/clang_tidy.log | wc -l) if [ $nerrors -ne 0 ]; then exit 125; fi openucx-ucc-ec0bc8a/.github/workflows/codestyle.yaml0000664000175000017500000000655615133731560023172 0ustar alastairalastairname: Codestyle on: [pull_request] # Cancel in-progress runs for the same PR concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 
cancel-in-progress: true env: GIT_CF: https://raw.githubusercontent.com/llvm/llvm-project/release/21.x/clang/tools/clang-format/git-clang-format LLVM_VERSION: 21 jobs: check-codestyle: runs-on: ubuntu-22.04 name: Check code style defaults: run: shell: bash steps: - name: Install dependencies run: | sudo apt-get update sudo apt-get install -y --no-install-recommends wget lsb-release software-properties-common gnupg wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | sudo tee /etc/apt/trusted.gpg.d/llvm.asc sudo add-apt-repository -y "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${LLVM_VERSION} main" sudo apt-get update sudo apt-get install -y --no-install-recommends clang-format-${LLVM_VERSION} # Install git-clang-format (no version suffix) curl -fsSL $GIT_CF -o git-clang-format chmod +x ./git-clang-format sudo mv ./git-clang-format /usr/bin/git-clang-format - name: Checking out repository uses: actions/checkout@v4 with: ref: ${{ github.event.pull_request.head.sha }} fetch-depth: 0 - name: Check commit title run: | set -eE range="remotes/origin/$GITHUB_BASE_REF..HEAD" check_title() { msg=$1 if [ ${#msg} -gt 50 ] then if ! echo $msg | grep -qP '^Merge' then echo "Commit title is too long: ${#msg}" return 1 fi fi H1="CODESTYLE|REVIEW|CORE|UTIL|TEST|API|DOCS|TOOLS|BUILD|MC|EC|SCHEDULE|TOPO" H2="CI|CL/|TL/|MC/|EC/|UCP|SHM|NCCL|SHARP|BASIC|HIER|DOCA_UROM|CUDA|CPU|EE|RCCL|ROCM|SELF|MLX5" if ! echo $msg | grep -qP '^Merge |^'"(($H1)|($H2))"'+: \w' then echo "Wrong header" return 1 fi if [ "${msg: -1}" = "." 
] then echo "Dot at the end of title" return 1 fi return 0 } ok=1 for sha1 in `git log $range --format="%h"` do title=`git log -1 --format="%s" $sha1` if check_title "$title" then echo "Good commit title: '$title'" else echo "Bad commit title: '$title'" ok=0 fi echo "--------------------------------------------------" done if [ $ok -ne 1 ] then exit 1 fi - name: Check code format run: | set -eEuo pipefail echo "Commit ${{ github.event.pull_request.base.sha }}" diff=`git-clang-format --binary=clang-format-${LLVM_VERSION} --style=file --diff ${{ github.event.pull_request.base.sha }}` || true if [ "$diff" = "no modified files to format" ] || [ "$diff" = "clang-format did not modify any files" ] then echo "Format check PASS" else echo "Format check FAILED" echo "" echo "Please format your code using:" echo " git-clang-format --binary=clang-format-${LLVM_VERSION} ${{ github.event.pull_request.base.sha }}" echo "" echo "Formatting differences:" echo "$diff" fi openucx-ucc-ec0bc8a/.github/workflows/clang-tidy-nvidia.yaml0000664000175000017500000000663015133731560024473 0ustar alastairalastairname: Linter-NVIDIA on: [push, pull_request] env: OPEN_UCX_LINK: https://github.com/openucx/ucx OPEN_UCX_BRANCH: master HPCX_LINK: https://content.mellanox.com/hpc/hpc-x/v2.22.1rc4/hpcx-v2.22.1-gcc-doca_ofed-ubuntu22.04-cuda12-x86_64.tbz CLANG_VER: 17 MLNX_OFED_VER: 24.10-2.1.8.0 CUDA_VER: 12-8 LIBRARY_PATH: /tmp/ucx/install/lib LD_LIBRARY_PATH: /tmp/ucx/install/lib jobs: clang-tidy: runs-on: ubuntu-22.04 steps: - name: Install dependencies run: | sudo apt-get update sudo apt-get install -y --no-install-recommends wget gpg # Setup LLVM repository sudo mkdir -p /etc/apt/keyrings wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo gpg --dearmor -o /etc/apt/keyrings/llvm.gpg echo "deb [signed-by=/etc/apt/keyrings/llvm.gpg] http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VER} main" | sudo tee /etc/apt/sources.list.d/llvm.list sudo apt-get update sudo apt-get install -y 
--no-install-recommends clang-tidy-${CLANG_VER} bear clang-${CLANG_VER} clang++-${CLANG_VER} - name: Install extra rdma dependencies run: | wget --no-verbose http://content.mellanox.com/ofed/MLNX_OFED-${MLNX_OFED_VER}/MLNX_OFED_LINUX-${MLNX_OFED_VER}-ubuntu22.04-x86_64.tgz sudo tar -xvzf MLNX_OFED_LINUX-${MLNX_OFED_VER}-ubuntu22.04-x86_64.tgz sudo chmod -R a+rwx MLNX_OFED_LINUX-${MLNX_OFED_VER}-ubuntu22.04-x86_64 sudo MLNX_OFED_LINUX-${MLNX_OFED_VER}-ubuntu22.04-x86_64/mlnxofedinstall --skip-unsupported-devices-check --user-space-only --without-fw-update --force --basic -vvv - name: Install extra cuda dependencies run: | wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb sudo dpkg -i cuda-keyring_1.1-1_all.deb sudo apt-get update sudo apt-get install -y --no-install-recommends cuda-cudart-dev-${CUDA_VER} cuda-nvcc-${CUDA_VER} cuda-nvml-dev-${CUDA_VER} - name: Get UCX run: git clone ${OPEN_UCX_LINK} -b ${OPEN_UCX_BRANCH} /tmp/ucx - name: Build UCX run: | cd /tmp/ucx && ./autogen.sh CC=clang-${CLANG_VER} CXX=clang++-${CLANG_VER} ./contrib/configure-release --without-java --without-go --disable-numa --prefix $PWD/install make -j install - name: Download HPCX run: | cd /tmp wget --no-verbose ${HPCX_LINK} tar xjf hpcx-v2.22.1-gcc-doca_ofed-ubuntu22.04-cuda12-x86_64.tbz mv hpcx-v2.22.1-gcc-doca_ofed-ubuntu22.04-cuda12-x86_64 hpcx - uses: actions/checkout@v1 - name: Build UCC run: | ./autogen.sh CC=clang-${CLANG_VER} CXX=clang++-${CLANG_VER} ./configure --with-tls=ucp,mlx5,cuda,self,sharp --with-sharp=/tmp/hpcx/sharp --prefix=/tmp/ucc/install --with-ucx=/tmp/ucx/install --with-cuda=/usr/local/cuda --with-nvcc-gencode="-gencode=arch=compute_80,code=sm_80" --enable-assert bear --output /tmp/compile_commands.json -- make -j - name: Run clang-tidy run: | echo "Workspace: ${GITHUB_WORKSPACE}" cd ${GITHUB_WORKSPACE} run-clang-tidy-${CLANG_VER} -p /tmp/ 
-header-filter='^(?!.*(${GITHUB_WORKSPACE}/src/components/ec/cuda/kernel/.*|${GITHUB_WORKSPACE}/src/components/mc/cuda/kernel/.*)).*$' "^(?!.*\.cu$).*$" 2>&1 | tee /tmp/clang_tidy.log nerrors=$(grep 'error:' /tmp/clang_tidy.log | wc -l) if [ $nerrors -ne 0 ]; then exit 125; fi openucx-ucc-ec0bc8a/.github/PULL_REQUEST_TEMPLATE.md0000664000175000017500000000046615133731560022051 0ustar alastairalastair## What _Describe what this PR is doing._ ## Why ? _Justification for the PR. If there is existing issue/bug please reference. For bug fixes why and what can be merged in a single item._ ## How ? _It is optional but for complex PRs please provide information about the design, architecture, approach, etc._ openucx-ucc-ec0bc8a/README.md0000664000175000017500000001061315133731560016162 0ustar alastairalastair# Unified Collective Communication (UCC) UCC is a collective communication operations API and library that is flexible, complete, and feature-rich for current and emerging programming models and runtimes. - [Design Goals](#design-goals) - [API](https://openucx.github.io/ucc/) - [Building](#compiling-and-installing) - [Community](#community) - [Contributing](#contributing) - [License](#license) - [Publication](#publication) ## Design Goals * Highly scalable and performant collectives for HPC, AI/ML and I/O workloads * Nonblocking collective operations that cover a variety of programming models * Flexible resource allocation model * Support for relaxed ordering model * Flexible synchronous model * Repetitive collective operations (init once and invoke multiple times) * Hardware collectives are a first-class citizen ### UCC Component Architecture ![](docs/images/ucc_components.png) ## Contributing Thanks for your interest in contributing to UCC, please see our technical and legal guidelines in the [contributing](CONTRIBUTING.md) file. 
All contributors have to comply with ["Membership Voluntary Consensus Standard"](https://ucfconsortium.org/policy/) and ["Export Compliant Contribution Submissions"](https://ucfconsortium.org/policy/) policies. ## License UCC is BSD-style licensed, as found in the [LICENSE](LICENSE) file. ## Required packages * [UCX](https://github.com/openucx/ucx) * UCC uses utilities provided by UCX's UCS component * [CUDA](https://developer.nvidia.com/cuda-toolkit) (optional) * UCC supports CUDA collectives. To compile with CUDA support, install [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 11.0 or above. * [HIP](https://rocmdocs.amd.com/en/latest/Programming_Guides/HIP-GUIDE.html) (optional) * UCC supports AMD GPUs using HIP. Instructions for installing ROCM/HIP can be found at [AMD ROCM](https://rocmdocs.amd.com/en/latest/Installation_Guide/Installation_new.html). * Doxygen * UCC uses Doxygen for generating API documentation ## Compiling and Installing ### Developer's Build ```sh $ ./autogen.sh $ ./configure --prefix= --with-ucx= $ make ``` ### Build Documentation ```sh $ ./autogen.sh $ ./configure --prefix= --with-docs-only $ make docs ``` ### Open MPI and UCC collectives #### Compile UCX  ```sh $ git clone https://github.com/openucx/ucx $ cd ucx $ ./autogen.sh; ./configure --prefix=; make -j install ``` #### Compile UCC ```sh $ git clone https://github.com/openucx/ucc $ cd ucc $ ./autogen.sh; ./configure --prefix= --with-ucx=; make -j install ``` #### Compile Open MPI  ```sh $ git clone https://github.com/open-mpi/ompi $ cd ompi $ ./autogen.pl; ./configure --prefix= --with-ucx= --with-ucc=; make -j install ``` #### Run MPI programs ```sh $ mpirun -np 2 --mca coll_ucc_enable 1 --mca coll_ucc_priority 100 ./my_mpi_app ``` #### Run OpenSHMEM programs ```sh $ mpirun -np 2 --mca scoll_ucc_enable 1 --mca scoll_ucc_priority 100 ./my_openshmem_app ``` ### SUPPORTED Transports * UCX/UCP - InfiniBand, ROCE, Cray Gemini and Aries, Shared Memory * SHARP * CUDA * NCCL * RCCL 
* MLX5 ### Publication To cite UCC in a publication, please use the following BibTex entry: ``` @inproceedings{DBLP:conf/hoti/VenkataPLBALBDS24, author = {Manjunath Gorentla Venkata and Valentine Petrov and Sergey Lebedev and Devendar Bureddy and Ferrol Aderholdt and Joshua Ladd and Gil Bloch and Mike Dubman and Gilad Shainer}, title = {Unified Collective Communication {(UCC):} An Unified Library for CPU, GPU, and {DPU} Collectives}, booktitle = {{IEEE} Symposium on High-Performance Interconnects, {HOTI} 2024, Albuquerque, NM, USA, August 21-23, 2024}, pages = {37--46}, publisher = {{IEEE}}, year = {2024}, url = {https://doi.org/10.1109/HOTI63208.2024.00018}, doi = {10.1109/HOTI63208.2024.00018}, timestamp = {Thu, 19 Sep 2024 11:00:54 +0200}, biburl = {https://dblp.org/rec/conf/hoti/VenkataPLBALBDS24.bib}, bibsource = {dblp computer science bibliography, https://dblp.org} } ``` openucx-ucc-ec0bc8a/.azure/0000775000175000017500000000000015133731560016106 5ustar alastairalastairopenucx-ucc-ec0bc8a/.azure/azure-pipelines-pr.yml0000664000175000017500000000605115133731560022366 0ustar alastairalastair# See https://aka.ms/yaml # This pipeline to be run on PRs trigger: none pr: branches: include: - master - v*.*.x paths: exclude: - .gitignore - .readthedocs.yaml - contrib/pr_merge_check.py - docs/source - docs/CodeStyle.md - docs/LoggingStyle.md - docs/OptimizationStyle.md - README.md - NEWS - AUTHORS resources: containers: - container: fedora image: rdmz-harbor.rdmz.labs.mlnx/ucx/fedora33:1 stages: - stage: Codestyle jobs: # Check that the code is formatted according to the code style guidelines - job: format displayName: format code pool: name: MLNX demands: - ucx_docker -equals yes container: fedora steps: - checkout: self clean: true fetchDepth: 100 - bash: | set -x git log -1 HEAD git log -1 HEAD^ BASE_SOURCEVERSION=$(git rev-parse HEAD^) echo "Checking code format on diff ${BASE_SOURCEVERSION}..${BUILD_SOURCEVERSION}" git-clang-format --diff ${BASE_SOURCEVERSION} 
${BUILD_SOURCEVERSION} > format.patch echo "Generated patch file:" cat format.patch if [ "`cat format.patch`" = "no modified files to format" ]; then exit fi git apply format.patch if ! git diff --quiet --exit-code then url="https://github.com/openucx/ucx/wiki/Code-style-checking" echo "##vso[task.complete result=SucceededWithIssues;]DONE'Code is not formatted according to the code style, see $url for more info.'" echo "##vso[task.logissue type=warning]'Code is not formatted'" fi - stage: Test jobs: - job: Gtest timeoutInMinutes: 120 workspace: clean: all pool: name: MLNX demands: - ucx_bf -equals yes steps: - checkout: self - bash: | set -eE git clone --depth 1 -b master https://github.com/openucx/ucx.git ucx cd ucx ./autogen.sh mkdir -p ucx_build_dir cd ucx_build_dir ../configure --without-java --prefix=$(Build.Repository.LocalPath)/install_ucx gcc -v make -s -j `nproc` make install displayName: Build ucx artifact timeoutInMinutes: 40 - bash: | set -xEe ./autogen.sh mkdir -p build cd build ../configure --with-ucx=$(Build.Repository.LocalPath)/install_ucx \ --prefix=$(Build.Repository.LocalPath)/install --enable-gtest make -j install displayName: Build ucc artifact timeoutInMinutes: 60 - bash: | cd build make gtest displayName: Launch Gtest timeoutInMinutes: 120 openucx-ucc-ec0bc8a/contrib/0000775000175000017500000000000015133731560016342 5ustar alastairalastairopenucx-ucc-ec0bc8a/contrib/doca_urom_ucc_plugin/0000775000175000017500000000000015133731560022522 5ustar alastairalastairopenucx-ucc-ec0bc8a/contrib/doca_urom_ucc_plugin/common/0000775000175000017500000000000015133731560024012 5ustar alastairalastairopenucx-ucc-ec0bc8a/contrib/doca_urom_ucc_plugin/common/urom_ucc.h0000664000175000017500000001516215133731560026004 0ustar alastairalastair/* * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES, ALL RIGHTS RESERVED. 
* * This software product is a proprietary product of NVIDIA CORPORATION & * AFFILIATES (the "Company") and all right, title, and interest in and to the * software product, including all associated intellectual property rights, are * and shall remain exclusively with the Company. * * This software product is governed by the End User License Agreement * provided with the software product. * */ #ifndef UROM_UCC_H_ #define UROM_UCC_H_ #include #include #ifdef __cplusplus extern "C" { #endif /* UCC serializing next raw, iter points to the offset place and returns the buffer start */ #define urom_ucc_serialize_next_raw(_iter, _type, _offset) \ ({ \ _type *_result = (_type *)(*(_iter)); \ *(_iter) = UCS_PTR_BYTE_OFFSET(*(_iter), _offset); \ _result; \ }) /* UCC command types */ enum urom_worker_ucc_cmd_type { UROM_WORKER_CMD_UCC_LIB_CREATE, /* UCC library create command */ UROM_WORKER_CMD_UCC_LIB_DESTROY, /* UCC library destroy command */ UROM_WORKER_CMD_UCC_CONTEXT_CREATE, /* UCC context create command */ UROM_WORKER_CMD_UCC_CONTEXT_DESTROY, /* UCC context destroy command */ UROM_WORKER_CMD_UCC_TEAM_CREATE, /* UCC team create command */ UROM_WORKER_CMD_UCC_COLL, /* UCC collective create command */ UROM_WORKER_CMD_UCC_CREATE_PASSIVE_DATA_CHANNEL, /* UCC passive data channel command */ }; /* * UCC library create command structure * * Input parameters for creating the library handle. The semantics of the * parameters are defined by ucc.h On successful completion of * urom_worker_cmd_ucc_lib_create, The UROM worker will generate a notification * on the notification queue. This notification has reference to local library * handle on the worker. The implementation can choose to create shadow handles * or safely pack the library handle on the BlueCC worker to the AEU. 
*/ struct urom_worker_cmd_ucc_lib_create { void *params; /* UCC library parameters */ }; /* UCC context create command structure */ struct urom_worker_cmd_ucc_context_create { union { int64_t start; /* The started index */ int64_t *array; /* Set stride to <= 0 if array is used */ }; int64_t stride; /* Set number of strides */ int64_t size; /* Set stride size */ void *base_va; /* Shared buffer address */ uint64_t len; /* Buffer length */ }; /* UCC passive data channel command structure */ struct urom_worker_cmd_ucc_pass_dc { void *ucp_addr; /* UCP worker address on host */ size_t addr_len; /* UCP worker address length */ }; /* UCC context destroy command structure */ struct urom_worker_cmd_ucc_context_destroy { ucc_context_h context_h; /* UCC context pointer */ }; /* UCC team create command structure */ struct urom_worker_cmd_ucc_team_create { int64_t start; /* Team start index */ int64_t stride; /* Number of strides */ int64_t size; /* Stride size */ ucc_context_h context_h; /* UCC context */ }; /* UCC collective command structure */ struct urom_worker_cmd_ucc_coll { ucc_coll_args_t *coll_args; /* Collective arguments */ ucc_team_h team; /* UCC team */ int use_xgvmi; /* If operation uses XGVMI */ void *work_buffer; /* Work buffer */ size_t work_buffer_size; /* Buffer size */ size_t team_size; /* Team size */ }; /* UROM UCC worker command structure */ struct urom_worker_ucc_cmd { enum urom_worker_ucc_cmd_type cmd_type; uint64_t dpu_worker_id; /* DPU worker id as part of the team */ union { struct urom_worker_cmd_ucc_lib_create lib_create_cmd; /* Lib create command */ struct urom_worker_cmd_ucc_context_create context_create_cmd; /* Context create command */ struct urom_worker_cmd_ucc_context_destroy context_destroy_cmd; /* Context destroy command */ struct urom_worker_cmd_ucc_team_create team_create_cmd; /* Team create command */ struct urom_worker_cmd_ucc_coll coll_cmd; /* UCC collective command */ struct urom_worker_cmd_ucc_pass_dc pass_dc_create_cmd; /* Passive 
data channel command */ }; }; /* UCC notification types */ enum urom_worker_ucc_notify_type { UROM_WORKER_NOTIFY_UCC_LIB_CREATE_COMPLETE, /* Create UCC library on DPU notification */ UROM_WORKER_NOTIFY_UCC_LIB_DESTROY_COMPLETE, /* Destroy UCC library on DPU notification */ UROM_WORKER_NOTIFY_UCC_CONTEXT_CREATE_COMPLETE, /* Create UCC context on DPU notification */ UROM_WORKER_NOTIFY_UCC_CONTEXT_DESTROY_COMPLETE, /* Destroy UCC context on DPU notification */ UROM_WORKER_NOTIFY_UCC_TEAM_CREATE_COMPLETE, /* Create UCC team on DPU notification */ UROM_WORKER_NOTIFY_UCC_COLLECTIVE_COMPLETE, /* UCC collective completion notification */ UROM_WORKER_NOTIFY_UCC_PASSIVE_DATA_CHANNEL_COMPLETE, /* UCC data channel completion notification */ }; /* UCC context create notification structure */ struct urom_worker_ucc_notify_context_create { ucc_context_h context; /* Pointer to UCC context */ }; /* UCC team create notification structure */ struct urom_worker_ucc_notify_team_create { ucc_team_h team; /* Pointer to UCC team */ }; /* UCC collective notification structure */ struct urom_worker_ucc_notify_collective { ucc_status_t status; /* UCC collective status */ }; /* UCC passive data channel notification structure */ struct urom_worker_ucc_notify_pass_dc { ucc_status_t status; /* UCC data channel status */ }; /* UROM UCC worker notification structure */ struct urom_worker_notify_ucc { enum urom_worker_ucc_notify_type notify_type; uint64_t dpu_worker_id; /* DPU worker id */ union { struct urom_worker_ucc_notify_context_create context_create_nqe; /* Context create notification */ struct urom_worker_ucc_notify_team_create team_create_nqe; /* Team create notification */ struct urom_worker_ucc_notify_collective coll_nqe; /* Collective notification */ struct urom_worker_ucc_notify_pass_dc pass_dc_nqe; /* Passive data channel notification */ }; }; typedef struct ucc_worker_key_buf { size_t src_len; size_t dst_len; char rkeys[1024]; } ucc_worker_key_buf; #ifdef __cplusplus } /* extern "C" 
*/ #endif #endif /* UROM_UCC_H_ */ openucx-ucc-ec0bc8a/contrib/doca_urom_ucc_plugin/dpu/0000775000175000017500000000000015133731560023312 5ustar alastairalastairopenucx-ucc-ec0bc8a/contrib/doca_urom_ucc_plugin/dpu/worker_ucc.c0000664000175000017500000024642115133731560025632 0ustar alastairalastair/* * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES, ALL RIGHTS RESERVED. * * This software product is a proprietary product of NVIDIA CORPORATION & * AFFILIATES (the "Company") and all right, title, and interest in and to the * software product, including all associated intellectual property rights, are * and shall remain exclusively with the Company. * * This software product is governed by the End User License Agreement * provided with the software product. * */ #define _GNU_SOURCE #include #include #include #include #include #include #include #include "worker_ucc.h" #include "../common/urom_ucc.h" DOCA_LOG_REGISTER(UROM::WORKER::UCC); static uint64_t plugin_version = 0x01; /* UCC plugin DPU version */ static volatile uint64_t *queue_front; /* Front queue node */ static volatile uint64_t *queue_tail; /* Tail queue node */ static volatile uint64_t *queue_size; /* Queue size */ static int ucc_component_enabled; /* Shared between worker threads */ static pthread_t context_progress_thread; /* UCC progress thread context */ static uint64_t queue_lock = 0; /* Threads queue lock */ static pthread_t *progress_thread = NULL; /* Progress threads array */ /* UCC opts structure */ struct worker_ucc_opts worker_ucc_opts = { .num_progress_threads = 1, .ppw = 32, .tpp = 1, .list_size = 64, .num_psync = 128, .dpu_worker_binding_stride = 1, }; /* Progress thread arguments structure */ struct thread_args { uint64_t thread_id; /* Progress thread id */ struct urom_worker_ucc *ucc_worker; /* UCC worker context */ }; /* Determine number of cores by counting the number of lines containing "processor" in /proc/cpuinfo */ int get_ncores() { static int core_count = 0; int count = 0; FILE 
*fptr; char str[100]; char *pos; int index; // just read the file once and return the stored value on subsequent calls if (core_count != 0) { return core_count; } fptr = fopen("/proc/cpuinfo", "rb"); if (fptr == NULL) { printf("Failed to open /proc/cpuinfo\n"); exit(EXIT_FAILURE); } while ((fgets(str, 100, fptr)) != NULL) { index = 0; while ((pos = strstr(str + index, "processor")) != NULL) { index = (pos - str) + 1; count++; } } fclose(fptr); core_count = count; return count; } void dpu_thread_set_affinity_specific_core(int core_id) { cpu_set_t cpuset; CPU_ZERO(&cpuset); if (core_id >=0 && core_id < get_ncores()) { CPU_SET(core_id, &cpuset); pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset); } else { printf("bad core id: %d\n", core_id); exit(-1); } } void dpu_thread_set_affinity(int thread_id) { int coreid = thread_id; int do_stride = worker_ucc_opts.dpu_worker_binding_stride; int num_threads = worker_ucc_opts.num_progress_threads; int num_cores; int stride; cpu_set_t cpuset; num_cores = get_ncores(); stride = num_cores / num_threads; CPU_ZERO(&cpuset); if(do_stride) { stride = do_stride; if (num_threads % 2 != 0) { stride = 1; } coreid *= stride; } if (coreid >=0 && coreid < num_cores) { CPU_SET(coreid, &cpuset); pthread_setaffinity_np(progress_thread[thread_id], sizeof(cpuset), &cpuset); } } /* * Find available queue element * * @ctx_id [in]: UCC context id * @ucc_worker [in]: UCC command descriptor * @ret_qe [out]: set available queue element * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t find_qe_slot(uint64_t ctx_id, struct urom_worker_ucc *ucc_worker, struct ucc_queue_element **ret_qe) { int thread_id = ctx_id % worker_ucc_opts.num_progress_threads; uint64_t next = (queue_tail[thread_id] + 1) % worker_ucc_opts.list_size; int curr = queue_tail[thread_id]; if (next == queue_front[thread_id]) { *ret_qe = NULL; return DOCA_ERROR_FULL; } *ret_qe = &ucc_worker->queue[thread_id][curr]; if ((*ret_qe)->in_use != 0) { 
*ret_qe = NULL; return DOCA_ERROR_BAD_STATE; } queue_tail[thread_id] = next; return DOCA_SUCCESS; } /* * Open UCC worker plugin * * @ctx [in]: DOCA UROM worker context * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t urom_worker_ucc_open(struct urom_worker_ctx *ctx) { uint64_t i, j; doca_error_t result; ucs_status_t status; ucp_params_t ucp_params; ucp_config_t *ucp_config; ucp_worker_params_t worker_params; struct urom_worker_ucc *ucc_worker; if (ctx == NULL) { return DOCA_ERROR_INVALID_VALUE; } ucc_worker = calloc(1, sizeof(*ucc_worker)); if (ucc_worker == NULL) { DOCA_LOG_ERR("Failed to allocate UCC worker context"); return DOCA_ERROR_NO_MEMORY; } if (worker_ucc_opts.num_progress_threads < MIN_THREADS) { worker_ucc_opts.num_progress_threads = MIN_THREADS; DOCA_LOG_WARN("Number of threads for UCC Offload " "must be 1 or more, set to 1"); } ucc_worker->ctx_id = 0; ucc_worker->nr_connections = 0; ucc_worker->ucc_data = calloc(worker_ucc_opts.ppw * worker_ucc_opts.tpp, sizeof(struct ucc_data)); if (ucc_worker->ucc_data == NULL) { DOCA_LOG_ERR("Failed to allocate UCC worker context"); result = DOCA_ERROR_NO_MEMORY; goto ucc_free; } ucc_worker->queue = (struct ucc_queue_element **) malloc(sizeof(struct ucc_queue_element *) * worker_ucc_opts.num_progress_threads); if (ucc_worker->queue == NULL) { DOCA_LOG_ERR("Failed to allocate UCC elements queue"); result = DOCA_ERROR_NO_MEMORY; goto ucc_data_free; } for (i = 0; i < worker_ucc_opts.num_progress_threads; i++) { ucc_worker->queue[i] = calloc(worker_ucc_opts.list_size, sizeof(struct ucc_queue_element)); if (ucc_worker->queue[i] == NULL) { DOCA_LOG_ERR("Failed to allocate queue elements"); result = DOCA_ERROR_NO_MEMORY; goto queue_free; } } queue_front = (volatile uint64_t *) calloc(worker_ucc_opts.num_progress_threads, sizeof(uint64_t)); if (queue_front == NULL) { result = DOCA_ERROR_NO_MEMORY; goto queue_free; } queue_tail = (volatile uint64_t *) 
calloc(worker_ucc_opts.num_progress_threads, sizeof(uint64_t)); if (queue_tail == NULL) { result = DOCA_ERROR_NO_MEMORY; goto queue_front_free; } queue_size = (volatile uint64_t *) calloc(worker_ucc_opts.num_progress_threads, sizeof(uint64_t)); if (queue_size == NULL) { result = DOCA_ERROR_NO_MEMORY; goto queue_tail_free; } status = ucp_config_read(NULL, NULL, &ucp_config); if (status != UCS_OK) { DOCA_LOG_ERR("Failed to read UCP config"); goto queue_size_free; } status = ucp_config_modify(ucp_config, "PROTO_ENABLE", "y"); if (status != UCS_OK) { DOCA_LOG_ERR("Failed to read UCP config"); ucp_config_release(ucp_config); goto queue_size_free; } ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES; ucp_params.features = UCP_FEATURE_TAG | UCP_FEATURE_RMA | UCP_FEATURE_AMO64 | UCP_FEATURE_EXPORTED_MEMH; status = ucp_init(&ucp_params, ucp_config, &ucc_worker->ucp_data.ucp_context); ucp_config_release(ucp_config); if (status != UCS_OK) { DOCA_LOG_ERR("Failed to initialized UCP"); result = DOCA_ERROR_DRIVER; goto queue_size_free; } worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; worker_params.thread_mode = UCS_THREAD_MODE_MULTI; status = ucp_worker_create(ucc_worker->ucp_data.ucp_context, &worker_params, &ucc_worker->ucp_data.ucp_worker); if (status != UCS_OK) { DOCA_LOG_ERR("Unable to create ucp worker"); result = DOCA_ERROR_DRIVER; goto ucp_cleanup; } ucc_worker->ucp_data.eps = kh_init(ep); if (ucc_worker->ucp_data.eps == NULL) { DOCA_LOG_ERR("Failed to init EP hashtable map"); result = DOCA_ERROR_DRIVER; goto worker_destroy; } ucc_worker->ucp_data.memh = kh_init(memh); if (ucc_worker->ucp_data.memh == NULL) { DOCA_LOG_ERR("Failed to init memh hashtable map"); result = DOCA_ERROR_DRIVER; goto eps_destroy; } ucc_worker->ucp_data.rkeys = kh_init(rkeys); if (ucc_worker->ucp_data.rkeys == NULL) { DOCA_LOG_ERR("Failed to init rkeys hashtable map"); result = DOCA_ERROR_DRIVER; goto memh_destroy; } ucc_worker->ids = kh_init(ctx_id); if (ucc_worker->ids == NULL) { 
DOCA_LOG_ERR("Failed to init ids hashtable map"); result = DOCA_ERROR_DRIVER; goto rkeys_destroy; } ucc_worker->super = ctx; ucc_worker->list_lock = 0; ucc_component_enabled = 1; ucs_list_head_init(&ucc_worker->completed_reqs); ctx->plugin_ctx = ucc_worker; DOCA_LOG_INFO("UCC worker open flow is done"); return DOCA_SUCCESS; rkeys_destroy: kh_destroy(rkeys, ucc_worker->ucp_data.rkeys); memh_destroy: kh_destroy(memh, ucc_worker->ucp_data.memh); eps_destroy: kh_destroy(ep, ucc_worker->ucp_data.eps); worker_destroy: ucp_worker_destroy(ucc_worker->ucp_data.ucp_worker); ucp_cleanup: ucp_cleanup(ucc_worker->ucp_data.ucp_context); queue_size_free: free((void *)queue_size); queue_tail_free: free((void *)queue_tail); queue_front_free: free((void *)queue_front); queue_free: for (j = 0; j < i; j++) free(ucc_worker->queue[j]); free(ucc_worker->queue); ucc_data_free: free(ucc_worker->ucc_data); ucc_free: free(ucc_worker); return result; } static void ucc_worker_join_and_free_threads() { uint64_t i; if (progress_thread) { for (i = 0; i < worker_ucc_opts.num_progress_threads; i++) { pthread_join(progress_thread[i], NULL); } free(progress_thread); progress_thread = NULL; } } /* * Close UCC worker plugin * * @worker_ctx [in]: DOCA UROM worker context */ static void urom_worker_ucc_close(struct urom_worker_ctx *worker_ctx) { struct urom_worker_ucc *ucc_worker = worker_ctx->plugin_ctx; uint64_t i; if (worker_ctx == NULL) return; ucc_component_enabled = 0; ucc_worker_join_and_free_threads(); /* Destroy hash tables */ kh_destroy(rkeys, ucc_worker->ucp_data.rkeys); kh_destroy(memh, ucc_worker->ucp_data.memh); kh_destroy(ep, ucc_worker->ucp_data.eps); kh_destroy(ctx_id, ucc_worker->ids); /* UCP cleanup */ ucp_worker_destroy(ucc_worker->ucp_data.ucp_worker); ucp_cleanup(ucc_worker->ucp_data.ucp_context); /* UCC worker resources destroy */ free((void *)queue_size); free((void *)queue_tail); free((void *)queue_front); free(ucc_worker->ucc_data); /* Queue elements destroy */ for (i = 0; i < 
worker_ucc_opts.num_progress_threads; i++) free(ucc_worker->queue[i]); free(ucc_worker->queue); /* UCC worker destroy */ free(ucc_worker); } /* * Unpacking UCC worker command * * @packed_cmd [in]: packed worker command * @packed_cmd_len [in]: packed worker command length * @cmd [out]: set unpacked UROM worker command * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t urom_worker_ucc_cmd_unpack(void *packed_cmd, size_t packed_cmd_len, struct urom_worker_cmd **cmd) { uint64_t extended_mem = 0; void *ptr; int is_count_64; int is_disp_64; size_t team_size; size_t disp_pack_size; size_t count_pack_size; ucc_coll_args_t *coll_args; struct urom_worker_ucc_cmd *ucc_cmd; if (packed_cmd_len < sizeof(struct urom_worker_ucc_cmd)) { DOCA_LOG_INFO("Invalid packed command length"); return DOCA_ERROR_INVALID_VALUE; } *cmd = packed_cmd; ptr = packed_cmd + ucs_offsetof(struct urom_worker_cmd, plugin_cmd) + sizeof(struct urom_worker_ucc_cmd); ucc_cmd = (struct urom_worker_ucc_cmd *)(*cmd)->plugin_cmd; switch (ucc_cmd->cmd_type) { case UROM_WORKER_CMD_UCC_LIB_CREATE: ucc_cmd->lib_create_cmd.params = ptr; extended_mem += sizeof(ucc_lib_params_t); break; case UROM_WORKER_CMD_UCC_COLL: coll_args = ptr; ucc_cmd->coll_cmd.coll_args = ptr; ptr += sizeof(ucc_coll_args_t); extended_mem += sizeof(ucc_coll_args_t); if (ucc_cmd->coll_cmd.work_buffer_size > 0) { ucc_cmd->coll_cmd.work_buffer = ptr; ptr += ucc_cmd->coll_cmd.work_buffer_size; extended_mem += ucc_cmd->coll_cmd.work_buffer_size; } if (coll_args->coll_type == UCC_COLL_TYPE_ALLTOALLV || coll_args->coll_type == UCC_COLL_TYPE_ALLGATHERV || coll_args->coll_type == UCC_COLL_TYPE_GATHERV || coll_args->coll_type == UCC_COLL_TYPE_REDUCE_SCATTERV || coll_args->coll_type == UCC_COLL_TYPE_SCATTERV) { team_size = ucc_cmd->coll_cmd.team_size; is_count_64 = ((coll_args->mask & UCC_COLL_ARGS_FIELD_FLAGS) && (coll_args->flags & UCC_COLL_ARGS_FLAG_COUNT_64BIT)); is_disp_64 = ((coll_args->mask & 
UCC_COLL_ARGS_FIELD_FLAGS) && (coll_args->flags & UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT)); count_pack_size = ((is_count_64) ? sizeof(uint64_t) : sizeof(uint32_t)) * team_size; disp_pack_size = ((is_disp_64) ? sizeof(uint64_t) : sizeof(uint32_t)) * team_size; coll_args->src.info_v.counts = ptr; ptr += count_pack_size; extended_mem += count_pack_size; coll_args->dst.info_v.counts = ptr; ptr += count_pack_size; extended_mem += count_pack_size; coll_args->src.info_v.displacements = ptr; ptr += disp_pack_size; extended_mem += disp_pack_size; coll_args->dst.info_v.displacements = ptr; ptr += disp_pack_size; extended_mem += disp_pack_size; } break; case UROM_WORKER_CMD_UCC_CREATE_PASSIVE_DATA_CHANNEL: ucc_cmd->pass_dc_create_cmd.ucp_addr = ptr; extended_mem += ucc_cmd->pass_dc_create_cmd.addr_len; break; default: DOCA_LOG_ERR("Invalid UCC cmd: %u", ucc_cmd->cmd_type); break; } if ((*cmd)->len != extended_mem + sizeof(struct urom_worker_ucc_cmd)) { DOCA_LOG_ERR("Invalid UCC command length"); return DOCA_ERROR_INVALID_VALUE; } return DOCA_SUCCESS; } /* * UCC worker safe push notification function * * @ucc_worker [in]: UCC worker context * @nd [in]: UROM worker notification descriptor */ static void ucc_worker_safe_push_notification(struct urom_worker_ucc *ucc_worker, struct urom_worker_notif_desc *nd) { uint64_t lvalue = 0; lvalue = ucs_atomic_cswap64(&ucc_worker->list_lock, 0, 1); while (lvalue != 0) lvalue = ucs_atomic_cswap64(&ucc_worker->list_lock, 0, 1); ucs_list_add_tail(&ucc_worker->completed_reqs, &nd->entry); lvalue = ucs_atomic_cswap64(&ucc_worker->list_lock, 1, 0); } /* * UCC worker host destination remove * * @ucc_worker [in]: UCC worker context * @dest_id [in]: Host client dest id */ static void worker_ucc_dest_remove(struct urom_worker_ucc *ucc_worker, uint64_t dest_id) { khint_t k; k = kh_get(ctx_id, ucc_worker->ids, dest_id); if (k == kh_end(ucc_worker->ids)) { DOCA_LOG_ERR("Destination id - %lu does not exist", dest_id); return; } kh_del(ctx_id, 
ucc_worker->ids, k); ucc_worker->ctx_id--; } /* * UCC worker host destinations lookup function * * @ucc_worker [in]: UCC worker context * @dest_id [in]: Host client dest id * @ctx_id [out]: Host client context id * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t worker_ucc_dest_lookup(struct urom_worker_ucc *ucc_worker, uint64_t dest_id, uint64_t *ctx_id) { int ret; khint_t k; k = kh_get(ctx_id, ucc_worker->ids, dest_id); if (k != kh_end(ucc_worker->ids)) { *ctx_id = kh_value(ucc_worker->ids, k); return DOCA_SUCCESS; } *ctx_id = ucc_worker->ctx_id; k = kh_put(ctx_id, ucc_worker->ids, dest_id, &ret); if (ret < 0) { DOCA_LOG_ERR("Failed to put new context id"); return DOCA_ERROR_DRIVER; } ucc_worker->ctx_id++; kh_value(ucc_worker->ids, k) = *ctx_id; DOCA_LOG_DBG("UCC worker added connection %ld", *ctx_id); return DOCA_SUCCESS; } /* * Handle UCC library create command * * @ucc_worker [in]: UCC worker context * @cmd_desc [in]: UCC command descriptor * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t urom_worker_ucc_lib_create(struct urom_worker_ucc *ucc_worker, struct urom_worker_cmd_desc *cmd_desc) { struct urom_worker_cmd *cmd = (struct urom_worker_cmd *) &cmd_desc->worker_cmd; struct urom_worker_ucc_cmd *ucc_cmd = (struct urom_worker_ucc_cmd *) cmd->plugin_cmd; uint64_t ctx_id; uint64_t i; doca_error_t result; ucc_status_t ucc_status; ucc_lib_config_h lib_config; ucc_lib_params_t *lib_params; struct urom_worker_notify *notif; struct urom_worker_notif_desc *nd; struct urom_worker_notify_ucc *ucc_notif; /* Prepare notification */ nd = calloc(1, sizeof(*nd) + sizeof(*ucc_notif)); if (nd == NULL) return DOCA_ERROR_NO_MEMORY; nd->dest_id = cmd_desc->dest_id; notif = (struct urom_worker_notify *)&nd->worker_notif; notif->type = cmd->type; notif->urom_context = cmd->urom_context; notif->len = sizeof(*ucc_notif); notif->status = DOCA_SUCCESS; ucc_notif = (struct urom_worker_notify_ucc *)notif->plugin_notif; 
ucc_notif->notify_type = UROM_WORKER_NOTIFY_UCC_LIB_CREATE_COMPLETE; ucc_notif->dpu_worker_id = ucc_cmd->dpu_worker_id; lib_params = ucc_cmd->lib_create_cmd.params; lib_params->mask |= UCC_LIB_PARAM_FIELD_THREAD_MODE; lib_params->thread_mode = UCC_THREAD_MULTIPLE; result = worker_ucc_dest_lookup(ucc_worker, cmd_desc->dest_id, &ctx_id); if (result != DOCA_SUCCESS) { DOCA_LOG_ERR("Failed to lookup command destination"); goto fail; } ucc_worker->nr_connections++; if (ucc_worker->nr_connections > worker_ucc_opts.ppw) { DOCA_LOG_ERR("Too many processes connected to a single worker"); result = DOCA_ERROR_FULL; goto dest_remove; } if (UCC_OK != ucc_lib_config_read(NULL, NULL, &lib_config)) { DOCA_LOG_ERR("Failed to read UCC lib config"); result = DOCA_ERROR_DRIVER; goto reduce_conn; } for (i = 0; i < worker_ucc_opts.tpp; i++) { ucc_status = ucc_init(lib_params, lib_config, &ucc_worker->ucc_data[ctx_id*worker_ucc_opts.tpp + i].ucc_lib); if (ucc_status != UCC_OK) { DOCA_LOG_ERR("Failed to init UCC lib"); result = DOCA_ERROR_DRIVER; goto reduce_conn; } } ucc_lib_config_release(lib_config); DOCA_LOG_DBG("Created UCC lib successfully"); notif->status = DOCA_SUCCESS; ucc_worker_safe_push_notification(ucc_worker, nd); return notif->status; reduce_conn: ucc_worker->nr_connections--; dest_remove: worker_ucc_dest_remove(ucc_worker, cmd_desc->dest_id); fail: DOCA_LOG_ERR("Failed to create UCC lib"); notif->status = result; ucc_worker_safe_push_notification(ucc_worker, nd); return result; } /* * UCC library destroy * * @ucc_worker [in]: UCC worker context * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t ucc_worker_lib_destroy(struct urom_worker_ucc *ucc_worker) { doca_error_t result = DOCA_SUCCESS; uint64_t j, k; int64_t i; ucc_status_t status; ucc_component_enabled = 0; ucc_worker_join_and_free_threads(); for (j = 0; j < ucc_worker->nr_connections; j++) { for (k = 0; k < worker_ucc_opts.tpp; k++) { struct ucc_data *ucc_ptr = 
&ucc_worker->ucc_data[j*worker_ucc_opts.tpp + k]; for (i = 0; i < ucc_ptr->n_teams; i++) { if (!ucc_ptr->ucc_team[i]) { continue; } status = ucc_team_destroy(ucc_ptr->ucc_team[i]); if (status != UCC_OK) { DOCA_LOG_ERR("Failed to destroy UCC team of " "data index %lu and team index %ld", j, i); result = DOCA_ERROR_DRIVER; } free(ucc_ptr->pSync); } if (ucc_ptr->ucc_context) { status = ucc_context_destroy(ucc_ptr->ucc_context); if (status != UCC_OK) { DOCA_LOG_ERR("Failed to destroy UCC context of " "UCC data index %lu", j); result = DOCA_ERROR_DRIVER; } ucc_ptr->ucc_context = NULL; } if (ucc_ptr->ucc_lib) { status = ucc_finalize(ucc_ptr->ucc_lib); if (status != UCC_OK) { DOCA_LOG_ERR("Failed to finalize UCC lib " "of UCC data index %lu", j); result = DOCA_ERROR_DRIVER; } } } } return result; } /* * Handle UCC library destroy command * * @ucc_worker [in]: UCC worker context * @cmd_desc [in]: UCC command descriptor * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t urom_worker_ucc_lib_destroy(struct urom_worker_ucc *ucc_worker, struct urom_worker_cmd_desc *cmd_desc) { struct urom_worker_cmd *cmd = (struct urom_worker_cmd *) &cmd_desc->worker_cmd; struct urom_worker_ucc_cmd *ucc_cmd = (struct urom_worker_ucc_cmd *) cmd->plugin_cmd; struct urom_worker_notify *notif; struct urom_worker_notif_desc *nd; struct urom_worker_notify_ucc *ucc_notif; /* Prepare notification */ nd = calloc(1, sizeof(*nd) + sizeof(*ucc_notif)); if (nd == NULL) return DOCA_ERROR_NO_MEMORY; nd->dest_id = cmd_desc->dest_id; notif = (struct urom_worker_notify *)&nd->worker_notif; notif->type = cmd->type; notif->urom_context = cmd->urom_context; notif->len = sizeof(*ucc_notif); ucc_notif = (struct urom_worker_notify_ucc *)notif->plugin_notif; ucc_notif->notify_type = UROM_WORKER_NOTIFY_UCC_LIB_DESTROY_COMPLETE; ucc_notif->dpu_worker_id = ucc_cmd->dpu_worker_id; notif->status = ucc_worker_lib_destroy(ucc_worker); ucc_worker_safe_push_notification(ucc_worker, nd); return 
notif->status; } /* * Thread progress handles queue collective element * * @qe [in]: UCC thread queue element * @ucc_worker [in]: UCC worker context * @thread_id [in]: UCC thread id * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t handle_progress_thread_coll_element(struct ucc_queue_element *qe, struct urom_worker_ucc *ucc_worker, int thread_id) { int64_t lvalue = 0; ucc_status_t ucc_status = UCC_OK; doca_error_t status = DOCA_SUCCESS; ucc_status_t tmp_status; struct ucc_queue_element *qe_back; struct urom_worker_notify_ucc *ucc_notif; if (!qe->posted) { ucc_status = ucc_collective_post(qe->coll_req); if (UCC_OK != ucc_status) { DOCA_LOG_ERR("Failed to post UCC collective: %s", ucc_status_string(ucc_status)); status = DOCA_ERROR_DRIVER; goto exit; } qe->posted = 1; } ucc_status = ucc_collective_test(qe->coll_req); if (ucc_status == UCC_INPROGRESS) { ucc_context_progress(ucc_worker->ucc_data[qe->ctx_id].ucc_context); lvalue = ucs_atomic_cswap64(&queue_lock, 0, 1); while (lvalue != 0) lvalue = ucs_atomic_cswap64(&queue_lock, 0, 1); status = find_qe_slot(qe->ctx_id, ucc_worker, &qe_back); lvalue = ucs_atomic_cswap64(&queue_lock, 1, 0); if (status != DOCA_SUCCESS) { DOCA_LOG_ERR("Failed to find queue slot for team creation"); ucc_status = UCC_ERR_NO_RESOURCE; goto exit; } *qe_back = *qe; qe->in_use = 0; queue_front[thread_id] = (queue_front[thread_id] + 1) % worker_ucc_opts.list_size; return DOCA_ERROR_IN_PROGRESS; } else if (ucc_status == UCC_OK) { if (qe->barrier) { pthread_barrier_wait(qe->barrier); if (qe->nd != NULL) { pthread_barrier_destroy(qe->barrier); free(qe->barrier); qe->barrier = NULL; } } if (qe->key_duplicate_per_rank) { free(qe->key_duplicate_per_rank); qe->key_duplicate_per_rank = NULL; } if (qe->old_dest) { DOCA_LOG_DBG("Putting data back to host %p with size %lu", qe->old_dest, qe->data_size); if (qe->dest_packed_key != NULL) { status = ucc_rma_put_host( ucc_worker->ucc_data[qe->ctx_id].local_work_buffer + 
qe->data_size, qe->old_dest, qe->data_size, qe->ctx_id, qe->dest_packed_key, ucc_worker); if (status != DOCA_SUCCESS) { DOCA_LOG_ERR("Failed to find queue slot for team creation"); goto exit; } } else { status = ucc_rma_put( ucc_worker->ucc_data[qe->ctx_id].local_work_buffer + qe->data_size, qe->old_dest, qe->data_size, MAX_HOST_DEST_ID, qe->myrank, qe->ctx_id, ucc_worker); if (status != DOCA_SUCCESS) { DOCA_LOG_ERR("Failed to find queue slot for team creation"); goto exit; } } } if (qe->gwbi != NULL && qe->nd != NULL) { free(qe->gwbi); } } else { DOCA_LOG_ERR("ucc_collective_test() returned failure (%d)", ucc_status); status = DOCA_ERROR_DRIVER; goto exit; } status = DOCA_SUCCESS; tmp_status = ucc_collective_test(qe->coll_req); if (tmp_status != UCC_OK) { ucc_status = (ucc_status == UCC_OK) ? tmp_status : ucc_status; status = DOCA_ERROR_DRIVER; } tmp_status = ucc_collective_finalize(qe->coll_req); if (tmp_status != UCC_OK) { ucc_status = (ucc_status == UCC_OK) ? tmp_status : ucc_status; status = DOCA_ERROR_DRIVER; } exit: if (qe->nd != NULL) { ucc_notif = (struct urom_worker_notify_ucc *) qe->nd->worker_notif.plugin_notif; ucc_notif->coll_nqe.status = ucc_status; qe->nd->worker_notif.status = status; ucc_worker_safe_push_notification(ucc_worker, qe->nd); } queue_front[thread_id] = (queue_front[thread_id] + 1) % worker_ucc_opts.list_size; ucs_atomic_add64(&queue_size[thread_id], -1); qe->in_use = 0; return status; } /* * Thread progress handles queue team element * * @qe [in]: UCC thread queue element * @ucc_worker [in]: UCC worker context * @thread_id [in]: UCC thread id * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t handle_progress_thread_team_element(struct ucc_queue_element *qe, struct urom_worker_ucc *ucc_worker, int thread_id) { struct urom_worker_notify_ucc *ucc_notif = NULL; int64_t lvalue = 0; ucc_status_t ucc_status = UCC_OK; doca_error_t status = DOCA_SUCCESS; struct ucc_queue_element *qe_back; if(qe->nd != NULL) { 
ucc_notif = (struct urom_worker_notify_ucc *) qe->nd->worker_notif.plugin_notif; } ucc_status = ucc_team_create_test( ucc_worker->ucc_data[qe->ctx_id].ucc_team[qe->team_id]); if (ucc_status == UCC_INPROGRESS) { ucc_status = ucc_context_progress( ucc_worker->ucc_data[qe->ctx_id].ucc_context); lvalue = ucs_atomic_cswap64(&queue_lock, 0, 1); while (lvalue != 0) lvalue = ucs_atomic_cswap64(&queue_lock, 0, 1); status = find_qe_slot(qe->ctx_id, ucc_worker, &qe_back); lvalue = ucs_atomic_cswap64(&queue_lock, 1, 0); if (status != DOCA_SUCCESS) goto exit; *qe_back = *qe; queue_front[thread_id] = (queue_front[thread_id] + 1) % worker_ucc_opts.list_size; qe->in_use = 0; return DOCA_ERROR_IN_PROGRESS; } else if (ucc_status != UCC_OK) { DOCA_LOG_ERR("UCC team create test failed (%d) on team %ld for ctx %ld", ucc_status, qe->team_id, qe->ctx_id); if (ucc_notif) ucc_notif->team_create_nqe.team = NULL; status = DOCA_ERROR_DRIVER; } else { if (qe->barrier) { pthread_barrier_wait(qe->barrier); if (qe->nd != NULL) { pthread_barrier_destroy(qe->barrier); free(qe->barrier); qe->barrier = NULL; } } DOCA_LOG_INFO("Finished team creation (%ld:%ld)", qe->ctx_id, qe->team_id); if (ucc_notif) { ucc_notif->team_create_nqe.team = ucc_worker->ucc_data[qe->ctx_id].ucc_team[qe->team_id]; } status = DOCA_SUCCESS; } exit: free(qe->coll_ctx); if (qe->nd != NULL) { qe->nd->worker_notif.status = status; ucc_worker_safe_push_notification(ucc_worker, qe->nd); } queue_front[thread_id] = (queue_front[thread_id] + 1) % worker_ucc_opts.list_size; ucs_atomic_add64(&queue_size[thread_id], -1); qe->in_use = 0; return status; } /* * Progress context thread main function * * @arg [in]: UCC worker arg * @return: NULL (dummy return because of pthread requirement) */ static void *urom_worker_ucc_progress_thread(void *arg) { struct thread_args *targs = (struct thread_args *)arg; int thread_id = targs->thread_id; struct urom_worker_ucc *ucc_worker = targs->ucc_worker; doca_error_t status = DOCA_SUCCESS; int i; int 
front; int size; struct ucc_queue_element *qe; while (ucc_component_enabled) { size = queue_size[thread_id]; for (i = 0; i < size; i++) { front = queue_front[thread_id]; qe = &ucc_worker->queue[thread_id][front]; if (qe->in_use != 1) { DOCA_LOG_WARN("Found queue element in " "queue and marked not in use"); continue; } if (qe->type == UCC_WORKER_QUEUE_ELEMENT_TYPE_TEAM_CREATE) { status = handle_progress_thread_team_element(qe, ucc_worker, thread_id); if (status == DOCA_ERROR_IN_PROGRESS) continue; if (status != DOCA_SUCCESS) goto exit; } else if (qe->type == UCC_WORKER_QUEUE_ELEMENT_TYPE_COLLECTIVE) { status = handle_progress_thread_coll_element(qe, ucc_worker, thread_id); if (status == DOCA_ERROR_IN_PROGRESS) continue; if (status != DOCA_SUCCESS) goto exit; } else DOCA_LOG_ERR("Unknown queue element type"); } sched_yield(); } exit: pthread_exit(NULL); } /* * UCC OOB allgather free * * @req [in]: allgather request data * @return: UCC_OK on success and UCC_ERR otherwise */ static ucc_status_t urom_worker_ucc_oob_allgather_free(void *req) { free(req); return UCC_OK; } /* * UCC oob allgather function * * @sbuf [in]: local buffer to send to other processes * @rbuf [in]: global buffer to includes other processes source buffer * @msglen [in]: source buffer length * @oob_coll_ctx [in]: collection info * @req [out]: set allgather request data * @return: UCC_OK on success and UCC_ERR otherwise */ static ucc_status_t urom_worker_ucc_oob_allgather(void *sbuf, void *rbuf, size_t msglen, void *oob_coll_ctx, void **req) { struct coll_ctx *ctx = (struct coll_ctx *) oob_coll_ctx; char *recv_buf; int index; int size; int i; struct oob_allgather_req *oob_req; size = ctx->size; index = ctx->index; oob_req = malloc(sizeof(*oob_req)); if (oob_req == NULL) { DOCA_LOG_ERR("Failed to allocate OOB UCC request"); return UCC_ERR_NO_MEMORY; } oob_req->sbuf = sbuf; oob_req->rbuf = rbuf; oob_req->msglen = msglen; oob_req->oob_coll_ctx = oob_coll_ctx; oob_req->iter = 0; oob_req->status = 
calloc(ctx->size * 2, sizeof(int)); *req = oob_req; for (i = 0; i < size; i++) { recv_buf = (char *)rbuf + i * msglen; ucc_recv_nb(recv_buf, msglen, i, ctx->ucc_worker, &oob_req->status[i]); } for (i = 0; i < size; i++) { ucc_send_nb(sbuf, msglen, index, i, ctx->ucc_worker, &oob_req->status[i + size]); } return UCC_OK; } /* * UCC oob allgather test function * * @req [in]: UCC allgather request * @return: UCC_OK on success and UCC_ERR otherwise */ static ucc_status_t urom_worker_ucc_oob_allgather_test(void *req) { int nr_probes = 5; struct coll_ctx *ctx; struct oob_allgather_req *oob_req; int i; int probe_count; int nr_done; int size; oob_req = (struct oob_allgather_req *)req; ctx = (struct coll_ctx *)oob_req->oob_coll_ctx; size = ctx->size; for (probe_count = 0; probe_count < nr_probes; probe_count++) { nr_done = 0; for (i = 0; i < size * 2; i++) { if (oob_req->status[i] != 1 && ctx->ucc_worker->ucp_data.ucp_worker != NULL) { ucp_worker_progress(ctx->ucc_worker->ucp_data.ucp_worker); } else { ++nr_done; } } if (nr_done == size * 2) return UCC_OK; } return UCC_INPROGRESS; } /* * Handle UCC context creation of progress threads * * @arg [in]: UCC worker context argument * @return: NULL (dummy return because of pthread requirement) */ static void *urom_worker_ucc_ctx_progress_thread(void *arg) { struct ctx_thread_args *args = (struct ctx_thread_args *)arg; ucc_mem_map_t **maps = NULL; size_t len = args->len; int64_t size = args->size; int64_t start = args->start; int64_t stride = args->stride; int64_t myrank = args->myrank; uint64_t dest_id = args->dest_id; struct urom_worker_ucc *ucc_worker = args->ucc_worker; ucc_context_params_t ctx_params = {0}; struct urom_worker_notif_desc *nd; struct urom_worker_notify *notif; struct urom_worker_notify_ucc *ucc_notif; struct thread_args *targs; uint64_t n_threads, i, j; int ret; uint64_t ctx_id; char str_buf[256]; ucc_status_t ucc_status; doca_error_t status; struct coll_ctx **coll_ctx; ucc_context_config_h ctx_config; nd = 
calloc(1, sizeof(*nd) + sizeof(*ucc_notif)); if (nd == NULL) { status = DOCA_ERROR_NO_MEMORY; goto exit; } nd->dest_id = args->dest_id; notif = (struct urom_worker_notify *)&nd->worker_notif; notif->type = args->notif_type; notif->len = sizeof(*ucc_notif); notif->urom_context = args->urom_context; ucc_notif = (struct urom_worker_notify_ucc *) notif->plugin_notif; ucc_notif->notify_type = UROM_WORKER_NOTIFY_UCC_CONTEXT_CREATE_COMPLETE; ucc_notif->dpu_worker_id = args->myrank; status = worker_ucc_dest_lookup(ucc_worker, dest_id, &ctx_id); if (status != DOCA_SUCCESS) { DOCA_LOG_ERR("Failed to lookup command destination"); goto fail; } maps = (ucc_mem_map_t **) calloc(worker_ucc_opts.tpp, sizeof(ucc_mem_map_t *)); coll_ctx = (struct coll_ctx **) calloc(worker_ucc_opts.tpp, sizeof(struct coll_ctx *)); for (i = 0; i < worker_ucc_opts.tpp; i++) { uint64_t thread_ctx_id = ctx_id * worker_ucc_opts.tpp + i; if (ucc_worker->ucc_data[thread_ctx_id].ucc_lib == NULL) { DOCA_LOG_ERR("Attempting to create UCC context " "without first initializing a UCC lib"); status = DOCA_ERROR_BAD_STATE; goto fail; } if (ucc_context_config_read(ucc_worker->ucc_data[thread_ctx_id].ucc_lib, NULL, &ctx_config) != UCC_OK) { DOCA_LOG_ERR("Failed to read UCC context config"); status = DOCA_ERROR_DRIVER; goto fail; } /* Set to sliding window */ if (UCC_OK != ucc_context_config_modify( ctx_config, "tl/ucp", "TUNE", "allreduce:0-inf:@sliding_window")) { DOCA_LOG_ERR("Failed to modify TL_UCP_TUNE UCC lib config"); status = DOCA_ERROR_DRIVER; goto cfg_release; } /* Set estimated num of eps */ sprintf(str_buf, "%ld", size); ucc_status = ucc_context_config_modify(ctx_config, NULL, "ESTIMATED_NUM_EPS", str_buf); if (ucc_status != UCC_OK) { DOCA_LOG_ERR("UCC context config modify " "failed for estimated_num_eps"); status = DOCA_ERROR_DRIVER; goto cfg_release; } ucc_worker->ucc_data[thread_ctx_id].local_work_buffer = calloc(1, len * 2); if (ucc_worker->ucc_data[thread_ctx_id].local_work_buffer == NULL) { 
DOCA_LOG_ERR("Failed to allocate local work buffer"); status = DOCA_ERROR_NO_MEMORY; goto cfg_release; } ucc_worker->ucc_data[thread_ctx_id].pSync = calloc(worker_ucc_opts.num_psync, sizeof(long)); if (ucc_worker->ucc_data[thread_ctx_id].pSync == NULL) { DOCA_LOG_ERR("Failed to pSync array"); status = DOCA_ERROR_NO_MEMORY; goto buf_free; } ucc_worker->ucc_data[thread_ctx_id].len = len * 2; maps[i] = (ucc_mem_map_t *)calloc(3, sizeof(ucc_mem_map_t)); if (maps[i] == NULL) { DOCA_LOG_ERR("Failed to allocate UCC memory map array"); status = DOCA_ERROR_NO_MEMORY; goto psync_free; } maps[i][0].address = ucc_worker->ucc_data[thread_ctx_id].local_work_buffer; maps[i][0].len = len * 2; maps[i][1].address = ucc_worker->ucc_data[thread_ctx_id].pSync; maps[i][1].len = worker_ucc_opts.num_psync * sizeof(long); coll_ctx[i] = (struct coll_ctx *)malloc(sizeof(struct coll_ctx)); if (coll_ctx[i] == NULL) { DOCA_LOG_ERR("Failed to allocate UCC worker coll context"); status = DOCA_ERROR_NO_MEMORY; goto maps_free; } if (stride <= 0) {/* This is an array of ids */ coll_ctx[i]->pids = (int64_t *)start; } else { coll_ctx[i]->start = start; } coll_ctx[i]->stride = stride; coll_ctx[i]->size = size; coll_ctx[i]->index = myrank; coll_ctx[i]->ucc_worker = ucc_worker; ctx_params.mask = UCC_CONTEXT_PARAM_FIELD_OOB | UCC_CONTEXT_PARAM_FIELD_MEM_PARAMS; ctx_params.oob.allgather = urom_worker_ucc_oob_allgather; ctx_params.oob.req_test = urom_worker_ucc_oob_allgather_test; ctx_params.oob.req_free = urom_worker_ucc_oob_allgather_free; ctx_params.oob.coll_info = (void *)coll_ctx[i]; ctx_params.oob.n_oob_eps = size; ctx_params.oob.oob_ep = myrank; ctx_params.mem_params.segments = maps[i]; ctx_params.mem_params.n_segments = 2; ucc_status = ucc_context_create( ucc_worker->ucc_data[thread_ctx_id].ucc_lib, &ctx_params, ctx_config, &ucc_worker->ucc_data[thread_ctx_id].ucc_context); if (ucc_status != UCC_OK) { DOCA_LOG_ERR("Failed to create UCC context"); status = DOCA_ERROR_DRIVER; goto coll_free; } 
ucc_context_config_release(ctx_config); } if (ctx_id == 0) { n_threads = worker_ucc_opts.num_progress_threads; targs = calloc(n_threads, sizeof(*targs)); if (targs == NULL) { DOCA_LOG_ERR("Failed to create threads args"); status = DOCA_ERROR_NO_MEMORY; goto context_destroy; } progress_thread = calloc(n_threads, sizeof(*progress_thread)); if (progress_thread == NULL) { DOCA_LOG_ERR("Failed to create threads args"); status = DOCA_ERROR_NO_MEMORY; goto targs_free; } DOCA_LOG_DBG("Creating [%ld] progress %lu threads", myrank, n_threads); for (i = 0; i < n_threads; i++) { targs[i].thread_id = i; targs[i].ucc_worker = ucc_worker; ret = pthread_create(&progress_thread[i], NULL, urom_worker_ucc_progress_thread, (void *)&targs[i]); if (ret != 0) { DOCA_LOG_ERR("Failed to create progress thread"); status = DOCA_ERROR_IO_FAILED; goto threads_free; } } } status = DOCA_SUCCESS; ucc_notif->context_create_nqe.context = ucc_worker->ucc_data[ctx_id].ucc_context; DOCA_LOG_DBG("UCC context created, ctx_id %lu, context %p", ctx_id, ucc_worker->ucc_data[ctx_id].ucc_context); goto exit; threads_free: for (j = 0; j < i; j++) { pthread_cancel(progress_thread[j]); } free(progress_thread); targs_free: free(targs); context_destroy: for(i = 0; i < worker_ucc_opts.tpp; i++) { if(ucc_worker->ucc_data[ctx_id*worker_ucc_opts.tpp + i].ucc_context) { ucc_context_destroy( ucc_worker->ucc_data[ctx_id*worker_ucc_opts.tpp + i].ucc_context); } } coll_free: for(i = 0; i < worker_ucc_opts.tpp; i++) { if(coll_ctx[i]) { free(coll_ctx[i]); } } free(coll_ctx); maps_free: for(i = 0; i < worker_ucc_opts.tpp; i++) { if(maps[i]) { free(maps[i]); } } free(maps); psync_free: for(i = 0; i < worker_ucc_opts.tpp; i++) { if(ucc_worker->ucc_data[ctx_id*worker_ucc_opts.tpp + i].pSync) { free(ucc_worker->ucc_data[ctx_id*worker_ucc_opts.tpp + i].pSync); } } buf_free: for(i = 0; i < worker_ucc_opts.tpp; i++) { if(ucc_worker->ucc_data[ctx_id*worker_ucc_opts.tpp + i]. 
local_work_buffer) { free(ucc_worker->ucc_data[ctx_id*worker_ucc_opts.tpp + i]. local_work_buffer); } } cfg_release: ucc_context_config_release(ctx_config); fail: exit: nd->worker_notif.status = status; ucc_worker_safe_push_notification(ucc_worker, nd); free(args); pthread_exit(NULL); } /* * Handle UCC context create command * * @ucc_worker [in]: UCC worker context * @cmd_desc [in]: UCC command descriptor * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t urom_worker_ucc_context_create(struct urom_worker_ucc *ucc_worker, struct urom_worker_cmd_desc *cmd_desc) { struct urom_worker_cmd *cmd = (struct urom_worker_cmd *) &cmd_desc->worker_cmd; struct urom_worker_ucc_cmd *ucc_cmd = (struct urom_worker_ucc_cmd *) cmd->plugin_cmd; struct urom_worker_notif_desc *nd; struct urom_worker_notify_ucc *ucc_notif; int ret; struct ctx_thread_args *args; args = calloc(1, sizeof(*args)); if (args == NULL) { return DOCA_ERROR_NO_MEMORY; } args->notif_type = cmd->type; args->urom_context = cmd->urom_context; args->start = ucc_cmd->context_create_cmd.start; args->stride = ucc_cmd->context_create_cmd.stride; args->size = ucc_cmd->context_create_cmd.size; args->myrank = ucc_cmd->dpu_worker_id; args->base_va = ucc_cmd->context_create_cmd.base_va; args->len = ucc_cmd->context_create_cmd.len; args->dest_id = cmd_desc->dest_id; args->ucc_worker = ucc_worker; ret = pthread_create(&context_progress_thread, NULL, urom_worker_ucc_ctx_progress_thread, (void *)args); if (ret != 0) { nd = calloc(1, sizeof(*nd) + sizeof(*ucc_notif)); if (nd == NULL) { return DOCA_ERROR_NO_MEMORY; } nd->dest_id = cmd_desc->dest_id; nd->worker_notif.status = DOCA_ERROR_IO_FAILED; nd->worker_notif.type = cmd->type; nd->worker_notif.len = sizeof(*ucc_notif); nd->worker_notif.urom_context = cmd->urom_context; ucc_notif = (struct urom_worker_notify_ucc *) nd->worker_notif.plugin_notif; ucc_notif->dpu_worker_id = ucc_cmd->dpu_worker_id; ucc_notif->notify_type = 
UROM_WORKER_NOTIFY_UCC_CONTEXT_CREATE_COMPLETE; ucc_worker_safe_push_notification(ucc_worker, nd); return DOCA_ERROR_IO_FAILED; } return DOCA_SUCCESS; } /* * Handle UCC context destroy command * * @ucc_worker [in]: UCC worker context * @cmd_desc [in]: UCC command descriptor * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t urom_worker_ucc_context_destroy(struct urom_worker_ucc *ucc_worker, struct urom_worker_cmd_desc *cmd_desc) { struct urom_worker_cmd *cmd = (struct urom_worker_cmd *) &cmd_desc->worker_cmd; struct urom_worker_ucc_cmd *ucc_cmd = (struct urom_worker_ucc_cmd *) cmd->plugin_cmd; uint64_t ctx_id, i; doca_error_t status; struct urom_worker_notify *notif; struct urom_worker_notif_desc *nd; struct urom_worker_notify_ucc *ucc_notif; /* Prepare notification */ nd = calloc(1, sizeof(*nd) + sizeof(*ucc_notif)); if (nd == NULL) { return DOCA_ERROR_NO_MEMORY; } nd->dest_id = cmd_desc->dest_id; notif = (struct urom_worker_notify *)&nd->worker_notif; notif->type = cmd->type; notif->urom_context = cmd->urom_context; notif->len = sizeof(*ucc_notif); ucc_notif = (struct urom_worker_notify_ucc *) notif->plugin_notif; ucc_notif->notify_type = UROM_WORKER_NOTIFY_UCC_CONTEXT_DESTROY_COMPLETE; ucc_notif->dpu_worker_id = ucc_cmd->dpu_worker_id; status = worker_ucc_dest_lookup(ucc_worker, cmd_desc->dest_id, &ctx_id); if (status != DOCA_SUCCESS) { DOCA_LOG_ERR("Failed to lookup command destination"); goto exit; } for (i = 0; i < worker_ucc_opts.tpp; i++) { uint64_t thread_ctx_id = ctx_id * worker_ucc_opts.tpp + i; if (ucc_worker->ucc_data[thread_ctx_id].ucc_context) { if (ucc_context_destroy( ucc_worker->ucc_data[thread_ctx_id].ucc_context) != UCC_OK) { DOCA_LOG_ERR("Failed to destroy UCC context"); status = DOCA_ERROR_DRIVER; goto exit; } ucc_worker->ucc_data[thread_ctx_id].ucc_context = NULL; } } status = DOCA_SUCCESS; exit: notif->status = status; ucc_worker_safe_push_notification(ucc_worker, nd); return status; } /* * Handle UCC team 
command * * @ucc_worker [in]: UCC worker context * @cmd_desc [in]: UCC command descriptor * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t urom_worker_ucc_team_create(struct urom_worker_ucc *ucc_worker, struct urom_worker_cmd_desc *cmd_desc) { struct urom_worker_cmd *cmd = (struct urom_worker_cmd *) &cmd_desc->worker_cmd; struct urom_worker_ucc_cmd *ucc_cmd = (struct urom_worker_ucc_cmd *) cmd->plugin_cmd; size_t curr_team = 0; struct urom_worker_notif_desc *nd; struct urom_worker_notify_ucc *ucc_notif; uint64_t ctx_id, i; ucc_ep_map_t map; doca_error_t status; ucc_status_t ucc_status; struct coll_ctx *coll_ctx; struct ucc_queue_element *qe; ucc_team_params_t team_params; struct urom_worker_notify *notif; pthread_barrier_t *barrier; uint64_t lvalue; /* Prepare notification */ nd = calloc(1, sizeof(*nd) + sizeof(*ucc_notif)); if (nd == NULL) { return DOCA_ERROR_NO_MEMORY; } nd->dest_id = cmd_desc->dest_id; notif = (struct urom_worker_notify *)&nd->worker_notif; notif->type = cmd->type; notif->urom_context = cmd->urom_context; notif->len = sizeof(*ucc_notif); ucc_notif = (struct urom_worker_notify_ucc *) notif->plugin_notif; ucc_notif->notify_type = UROM_WORKER_NOTIFY_UCC_TEAM_CREATE_COMPLETE; ucc_notif->dpu_worker_id = ucc_cmd->dpu_worker_id; status = worker_ucc_dest_lookup(ucc_worker, cmd_desc->dest_id, &ctx_id); if (status != DOCA_SUCCESS) { DOCA_LOG_ERR("Failed to lookup command destination"); goto exit; } barrier = malloc(sizeof(pthread_barrier_t)); pthread_barrier_init(barrier, NULL, worker_ucc_opts.tpp); for (i = 0; i < worker_ucc_opts.tpp; i++) { uint64_t thread_ctx_id = ctx_id * worker_ucc_opts.tpp + i; curr_team = ucc_worker->ucc_data[thread_ctx_id].n_teams; if (ucc_worker->ucc_data[thread_ctx_id].ucc_context == NULL || ucc_cmd->team_create_cmd.context_h != ucc_worker->ucc_data[ctx_id].ucc_context) { DOCA_LOG_ERR("Attempting to create UCC " "team over non-existent context"); status = DOCA_ERROR_INVALID_VALUE; goto exit; } 
if (ucc_cmd->team_create_cmd.stride <= 0) { map.type = UCC_EP_MAP_ARRAY; map.ep_num = ucc_cmd->team_create_cmd.size; map.array.map = (void *)ucc_cmd->team_create_cmd.start; map.array.elem_size = 8; } else { map.type = UCC_EP_MAP_STRIDED; map.ep_num = ucc_cmd->team_create_cmd.size; map.strided.start = ucc_cmd->team_create_cmd.start; map.strided.stride = ucc_cmd->team_create_cmd.stride; } team_params.mask = UCC_TEAM_PARAM_FIELD_EP | UCC_TEAM_PARAM_FIELD_TEAM_SIZE | UCC_TEAM_PARAM_FIELD_EP_MAP | UCC_TEAM_PARAM_FIELD_EP_RANGE; team_params.ep = ucc_cmd->dpu_worker_id; team_params.ep_map = map; team_params.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG; team_params.team_size = ucc_cmd->team_create_cmd.size; coll_ctx = (struct coll_ctx *) malloc(sizeof(*coll_ctx)); if (coll_ctx == NULL) { DOCA_LOG_ERR("Failed to allocate collective context"); status = DOCA_ERROR_NO_MEMORY; goto exit; } coll_ctx->start = ucc_cmd->team_create_cmd.start; coll_ctx->stride = ucc_cmd->team_create_cmd.stride; coll_ctx->size = ucc_cmd->team_create_cmd.size; coll_ctx->index = ucc_cmd->dpu_worker_id; coll_ctx->ucc_worker = ucc_worker; if (ucc_worker->ucc_data[thread_ctx_id].ucc_team == NULL) { ucc_worker->ucc_data[thread_ctx_id].ucc_team = malloc(sizeof(ucc_worker->ucc_data[thread_ctx_id].ucc_team)); if (ucc_worker->ucc_data[thread_ctx_id].ucc_team == NULL) { status = DOCA_ERROR_NO_MEMORY; goto coll_free; } } ucc_status = ucc_team_create_post( &ucc_worker->ucc_data[thread_ctx_id].ucc_context, 1, &team_params, &ucc_worker->ucc_data[thread_ctx_id].ucc_team[curr_team]); if (ucc_status != UCC_OK) { DOCA_LOG_ERR("ucc_team_create_post() failed"); status = DOCA_ERROR_DRIVER; goto team_free; } ucc_worker->ucc_data[thread_ctx_id].n_teams++; lvalue = ucs_atomic_cswap64(&queue_lock, 0, 1); while (lvalue != 0) { lvalue = ucs_atomic_cswap64(&queue_lock, 0, 1); } status = find_qe_slot(thread_ctx_id, ucc_worker, &qe); lvalue = ucs_atomic_cswap64(&queue_lock, 1, 0); if (status != DOCA_SUCCESS) { DOCA_LOG_ERR("Failed to 
find queue slot for team creation"); goto team_free; } qe->type = UCC_WORKER_QUEUE_ELEMENT_TYPE_TEAM_CREATE; qe->coll_ctx = coll_ctx; qe->dest_id = cmd_desc->dest_id; qe->ctx_id = thread_ctx_id; qe->team_id = curr_team; qe->myrank = ucc_cmd->dpu_worker_id; qe->in_use = 1; qe->barrier = barrier; if (i == 0) { qe->nd = nd; } else { qe->nd = NULL; } ucs_atomic_add64( &queue_size[thread_ctx_id % worker_ucc_opts.num_progress_threads], 1); continue; team_free: free(ucc_worker->ucc_data[thread_ctx_id].ucc_team); coll_free: free(coll_ctx); goto exit; } return DOCA_SUCCESS; exit: notif->status = status; ucc_worker_safe_push_notification(ucc_worker, nd); return status; } size_t urom_worker_get_dt_size(ucc_datatype_t dt) { size_t size_mod = 8; switch (dt) { case UCC_DT_INT8: case UCC_DT_UINT8: size_mod = sizeof(char); break; case UCC_DT_INT32: case UCC_DT_UINT32: case UCC_DT_FLOAT32: size_mod = sizeof(int); break; case UCC_DT_INT64: case UCC_DT_UINT64: case UCC_DT_FLOAT64: size_mod = sizeof(uint64_t); break; case UCC_DT_INT128: case UCC_DT_UINT128: case UCC_DT_FLOAT128: size_mod = sizeof(__int128_t); break; default: break; } return size_mod; } static doca_error_t post_nthreads_colls( uint64_t ctx_id, struct urom_worker_ucc *ucc_worker, ucc_coll_args_t *coll_args, ucc_team_h ucc_team, uint64_t myrank, int in_place, ucc_tl_ucp_allreduce_sw_global_work_buf_info_t *gwbi, struct urom_worker_notif_desc *nd, struct urom_worker_cmd_desc *cmd_desc, struct urom_worker_notify *notif, ucc_worker_key_buf *key_duplicate_per_rank) { doca_error_t status = DOCA_SUCCESS; pthread_barrier_t *barrier = NULL; int64_t team_idx = 0; size_t threads = worker_ucc_opts.tpp; size_t src_count = coll_args->src.info.count; size_t dst_count = coll_args->dst.info.count; size_t src_thread_count = src_count / threads; size_t dst_thread_count = dst_count / threads; size_t src_thread_size = src_thread_count * urom_worker_get_dt_size( coll_args->src.info.datatype); size_t dst_thread_size = dst_thread_count * 
urom_worker_get_dt_size( coll_args->dst.info.datatype); void *src_buf = coll_args->src.info.buffer; void *dst_buf = coll_args->dst.info.buffer; ucc_coll_req_h coll_req; struct ucc_queue_element *qe; ucc_status_t ucc_status; size_t i; uint64_t lvalue; int64_t j; coll_args->mask |= UCC_COLL_ARGS_FIELD_FLAGS | UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER; coll_args->flags |= UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS; coll_args->global_work_buffer = gwbi; barrier = malloc(sizeof(pthread_barrier_t)); pthread_barrier_init(barrier, NULL, worker_ucc_opts.tpp); for(i = 0; i < threads; i++) { uint64_t thread_ctx_id = ctx_id*worker_ucc_opts.tpp + i; gwbi = malloc(sizeof(ucc_tl_ucp_allreduce_sw_global_work_buf_info_t)); if (gwbi == NULL) { DOCA_LOG_ERR("Failed to initialize UCC collective: " "Couldnt malloc global work buffer"); status = DOCA_ERROR_DRIVER; goto fail; } gwbi->packed_src_memh = key_duplicate_per_rank[i].rkeys; gwbi->packed_dst_memh = key_duplicate_per_rank[i].rkeys + key_duplicate_per_rank[i].src_len; coll_args->global_work_buffer = gwbi; if(!in_place) { coll_args->src.info.count = src_thread_count; } coll_args->dst.info.count = dst_thread_count; if(!in_place) { coll_args->src.info.buffer = src_buf + i * src_thread_size; } coll_args->dst.info.buffer = dst_buf + i * dst_thread_size; if (i == threads - 1) { if(!in_place) { coll_args->src.info.count += src_count % threads; } coll_args->dst.info.count += dst_count % threads; } if (i == 0) { // the threads made these teams at the same time, so their index is the same in their arrays // TODO: is there a better way to associate these teams with each other? maybe use a map? for (j = 0; j < ucc_worker->ucc_data[thread_ctx_id].n_teams; j++) { if (ucc_worker->ucc_data[thread_ctx_id].ucc_team[j] == ucc_team) { team_idx = j; break; } } } ucc_status = ucc_collective_init(coll_args, &coll_req, ucc_worker->ucc_data[thread_ctx_id]. 
ucc_team[team_idx]); if (UCC_OK != ucc_status) { DOCA_LOG_ERR("Failed to initialize UCC collective: %s", ucc_status_string(ucc_status)); status = DOCA_ERROR_DRIVER; goto fail; } if (thread_ctx_id >= worker_ucc_opts.num_progress_threads) { DOCA_LOG_ERR("Warning--possible deadlock: multiple threads posting" "to the same queue, and the qe is going to barrier. " "Ensure tpp < num progress threads to avoid this\n"); } lvalue = ucs_atomic_cswap64(&queue_lock, 0, 1); while (lvalue != 0) lvalue = ucs_atomic_cswap64(&queue_lock, 0, 1); status = find_qe_slot(thread_ctx_id, ucc_worker, &qe); lvalue = ucs_atomic_cswap64(&queue_lock, 1, 0); if (status != DOCA_SUCCESS) { DOCA_LOG_ERR("Failed to find queue slot for team creation"); goto req_destroy; } qe->type = UCC_WORKER_QUEUE_ELEMENT_TYPE_COLLECTIVE; qe->coll_req = coll_req; qe->myrank = myrank; qe->dest_id = cmd_desc->dest_id; qe->old_dest = NULL; qe->data_size = 0; qe->gwbi = gwbi; qe->dest_packed_key = NULL; qe->ctx_id = thread_ctx_id; qe->in_use = 1; qe->posted = 0; qe->barrier = barrier; qe->key_duplicate_per_rank = key_duplicate_per_rank; if (i == 0) { qe->nd = nd; } else { qe->nd = NULL; } ucs_atomic_add64( &queue_size[thread_ctx_id % worker_ucc_opts.num_progress_threads], 1); } return DOCA_SUCCESS; req_destroy: ucc_collective_finalize(coll_req); fail: notif->status = status; ucc_worker_safe_push_notification(ucc_worker, nd); return status; } /* * Handle UCC collective init command * * @ucc_worker [in]: UCC worker context * @cmd_desc [in]: UCC command descriptor * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t urom_worker_ucc_coll_init(struct urom_worker_ucc *ucc_worker, struct urom_worker_cmd_desc *cmd_desc) { size_t size = 0; size_t size_mod = 8; void *old_dest = NULL; void *packed_key = NULL; struct urom_worker_cmd *cmd = (struct urom_worker_cmd *)&cmd_desc->worker_cmd; struct urom_worker_ucc_cmd *ucc_cmd = (struct urom_worker_ucc_cmd *)cmd->plugin_cmd; 
ucc_tl_ucp_allreduce_sw_global_work_buf_info_t *gwbi = NULL; int in_place = 0; ucc_worker_key_buf *key_duplicate_per_rank; ucc_worker_key_buf *keys; uint64_t ctx_id, myrank, lvalue, i; ucc_team_h team; void *work_buffer; doca_error_t status; ucc_coll_req_h coll_req; ucc_status_t ucc_status; ucc_coll_args_t *coll_args; struct ucc_queue_element *qe; struct urom_worker_notify *notif; struct urom_worker_notif_desc *nd; struct urom_worker_notify_ucc *ucc_notif; /* Prepare notification */ nd = calloc(1, sizeof(*nd) + sizeof(*ucc_notif)); if (nd == NULL) { return DOCA_ERROR_NO_MEMORY; } nd->dest_id = cmd_desc->dest_id; notif = (struct urom_worker_notify *) &nd->worker_notif; notif->type = cmd->type; notif->urom_context = cmd->urom_context; notif->len = sizeof(*ucc_notif); ucc_notif = (struct urom_worker_notify_ucc *) notif->plugin_notif; ucc_notif->notify_type = UROM_WORKER_NOTIFY_UCC_COLLECTIVE_COMPLETE; ucc_notif->dpu_worker_id = ucc_cmd->dpu_worker_id; if (ucc_cmd->coll_cmd.team == NULL) { DOCA_LOG_ERR("Attempting to perform UCC collective without a UCC team"); status = DOCA_ERROR_INVALID_VALUE; goto fail; } status = worker_ucc_dest_lookup(ucc_worker, cmd_desc->dest_id, &ctx_id); if (status != DOCA_SUCCESS) { DOCA_LOG_ERR("Failed to lookup command destination"); goto fail; } if (ucc_cmd->coll_cmd.work_buffer_size > 0 && ucc_cmd->coll_cmd.work_buffer) { work_buffer = ucc_cmd->coll_cmd.work_buffer; } else { work_buffer = NULL; } team = ucc_cmd->coll_cmd.team; coll_args = ucc_cmd->coll_cmd.coll_args; myrank = ucc_cmd->dpu_worker_id; COLL_CHECK(ucc_worker, ctx_id, status); if ( (coll_args->mask & UCC_COLL_ARGS_FIELD_FLAGS ) && (coll_args->flags & UCC_COLL_ARGS_FLAG_IN_PLACE) ) { in_place = 1; } if (coll_args->mask & UCC_COLL_ARGS_FIELD_CB) /* Cannot support callbacks to host data.. 
just won't work */ coll_args->mask = coll_args->mask & (~UCC_COLL_ARGS_FIELD_CB); if (coll_args->coll_type == UCC_COLL_TYPE_ALLTOALL || coll_args->coll_type == UCC_COLL_TYPE_ALLTOALLV) { if (!ucc_cmd->coll_cmd.use_xgvmi) { size_mod = urom_worker_get_dt_size(coll_args->src.info.datatype); size = coll_args->src.info.count * size_mod; if (coll_args->mask & UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER) { /* Perform get based on passed information */ keys = work_buffer; status = ucc_rma_get_host( ucc_worker->ucc_data[ctx_id].local_work_buffer, coll_args->src.info.buffer, size, ctx_id, keys->rkeys, ucc_worker); if (status != DOCA_SUCCESS) { DOCA_LOG_ERR("UCC component unable to obtain source buffer"); goto fail; } packed_key = keys->rkeys + keys->src_len; } else { /* Perform get based on domain information */ status = ucc_rma_get( ucc_worker->ucc_data[ctx_id].local_work_buffer, coll_args->src.info.buffer, size, MAX_HOST_DEST_ID, myrank, ctx_id, ucc_worker); if (status != DOCA_SUCCESS) { DOCA_LOG_ERR("UCC component unable to obtain source buffer"); goto fail; } } coll_args->src.info.buffer = ucc_worker->ucc_data[ctx_id].local_work_buffer; old_dest = coll_args->dst.info.buffer; coll_args->dst.info.buffer = ucc_worker->ucc_data[ctx_id].local_work_buffer + size; } if (!(coll_args->mask & UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER) || !work_buffer) { coll_args->mask |= UCC_COLL_ARGS_FIELD_FLAGS | UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER; coll_args->flags |= UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS; coll_args->global_work_buffer = ucc_worker->ucc_data[ctx_id].pSync + (ucc_worker->ucc_data[ctx_id].psync_offset % worker_ucc_opts.num_psync); ucc_worker->ucc_data[ctx_id].psync_offset++; } else { if (work_buffer != NULL) { coll_args->global_work_buffer = work_buffer; } } } else if (coll_args->coll_type == UCC_COLL_TYPE_ALLREDUCE || coll_args->coll_type == UCC_COLL_TYPE_ALLGATHER) { if (!ucc_cmd->coll_cmd.use_xgvmi) { DOCA_LOG_ERR("Failed to initialize UCC collective:" "Allreduce must use xgvmi"); 
status = DOCA_ERROR_DRIVER; goto fail; } if (!(coll_args->mask & UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER) || !work_buffer) { DOCA_LOG_ERR("Failed to initialize UCC collective:" "Allreduce must use global work buffer"); status = DOCA_ERROR_DRIVER; goto fail; } keys = work_buffer; gwbi = malloc(sizeof(ucc_tl_ucp_allreduce_sw_global_work_buf_info_t)); if (gwbi == NULL) { DOCA_LOG_ERR("Failed to initialize UCC collective: " "Couldnt malloc global work buffer"); status = DOCA_ERROR_DRIVER; goto fail; } gwbi->packed_src_memh = keys->rkeys; gwbi->packed_dst_memh = keys->rkeys + keys->src_len; key_duplicate_per_rank = malloc(sizeof(ucc_worker_key_buf) * worker_ucc_opts.tpp); if (key_duplicate_per_rank == NULL) { printf("couldnt malloc key_duplicate_per_rank\n"); } for (i = 0; i < worker_ucc_opts.tpp; i++) { memcpy(key_duplicate_per_rank[i].rkeys, keys->rkeys, keys->src_len + keys->dst_len); key_duplicate_per_rank[i].src_len = keys->src_len; key_duplicate_per_rank[i].dst_len = keys->dst_len; } status = post_nthreads_colls( ctx_id, ucc_worker, coll_args, team, myrank, in_place, gwbi, nd, cmd_desc, notif, key_duplicate_per_rank); return status; } ucc_status = ucc_collective_init(coll_args, &coll_req, team); if (UCC_OK != ucc_status) { DOCA_LOG_ERR("Failed to initialize UCC collective: %s", ucc_status_string(ucc_status)); status = DOCA_ERROR_DRIVER; goto fail; } ucc_status = ucc_collective_post(coll_req); if (UCC_OK != ucc_status) { DOCA_LOG_ERR("Failed to post UCC collective: %s", ucc_status_string(ucc_status)); status = DOCA_ERROR_DRIVER; goto req_destroy; } lvalue = ucs_atomic_cswap64(&queue_lock, 0, 1); while (lvalue != 0) { lvalue = ucs_atomic_cswap64(&queue_lock, 0, 1); } status = find_qe_slot(ctx_id, ucc_worker, &qe); lvalue = ucs_atomic_cswap64(&queue_lock, 1, 0); if (status != DOCA_SUCCESS) { DOCA_LOG_ERR("Failed to find queue slot for team creation"); goto req_destroy; } qe->type = UCC_WORKER_QUEUE_ELEMENT_TYPE_COLLECTIVE; qe->coll_req = coll_req; qe->myrank = myrank; 
qe->dest_id = cmd_desc->dest_id; if (!ucc_cmd->coll_cmd.use_xgvmi) { DOCA_LOG_DBG("Setting old dest to %p", old_dest); qe->old_dest = old_dest; qe->data_size = size; } else { qe->old_dest = NULL; qe->data_size = 0; } qe->gwbi = gwbi; qe->dest_packed_key = packed_key; qe->ctx_id = ctx_id; qe->in_use = 1; qe->posted = 1; qe->barrier = NULL; qe->nd = nd; ucs_atomic_add64(&queue_size[ctx_id % worker_ucc_opts.num_progress_threads], 1); return DOCA_SUCCESS; req_destroy: ucc_collective_finalize(coll_req); fail: notif->status = status; ucc_worker_safe_push_notification(ucc_worker, nd); return status; } /* * Handle UCC passive data channel create command * * @ucc_worker [in]: UCC worker context * @cmd_desc [in]: UCC command descriptor * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t urom_worker_ucc_pass_dc_create(struct urom_worker_ucc *ucc_worker, struct urom_worker_cmd_desc *cmd_desc) { struct urom_worker_cmd *cmd = (struct urom_worker_cmd *) &cmd_desc->worker_cmd; struct urom_worker_ucc_cmd *ucc_cmd = (struct urom_worker_ucc_cmd *) cmd->plugin_cmd; uint64_t ctx_id; ucp_ep_h new_ep; doca_error_t status; ucs_status_t ucs_status; ucp_ep_params_t ep_params; struct urom_worker_notify *notif; struct urom_worker_notif_desc *nd; struct urom_worker_notify_ucc *ucc_notif; /* Prepare notification */ nd = calloc(1, sizeof(*nd) + sizeof(*ucc_notif)); if (nd == NULL) return DOCA_ERROR_NO_MEMORY; nd->dest_id = cmd_desc->dest_id; notif = (struct urom_worker_notify *)&nd->worker_notif; notif->type = cmd->type; notif->urom_context = cmd->urom_context; notif->len = sizeof(*ucc_notif); ucc_notif = (struct urom_worker_notify_ucc *) notif->plugin_notif; ucc_notif->notify_type = UROM_WORKER_NOTIFY_UCC_PASSIVE_DATA_CHANNEL_COMPLETE; ucc_notif->dpu_worker_id = ucc_cmd->dpu_worker_id; status = worker_ucc_dest_lookup(ucc_worker, cmd_desc->dest_id, &ctx_id); if (status != DOCA_SUCCESS) { DOCA_LOG_ERR("Failed to lookup command destination"); goto fail; } if 
(ucc_worker->ucc_data[ctx_id].host == NULL) { ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS | UCP_EP_PARAM_FIELD_ERR_HANDLER | UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE; ep_params.err_handler.cb = urom_ep_err_cb; ep_params.err_handler.arg = NULL; ep_params.err_mode = UCP_ERR_HANDLING_MODE_PEER; ep_params.address = ucc_cmd->pass_dc_create_cmd.ucp_addr; ucs_status = ucp_ep_create(ucc_worker->ucp_data.ucp_worker, &ep_params, &new_ep); if (ucs_status != UCS_OK) { DOCA_LOG_ERR("ucp_ep_create() returned: %s", ucs_status_string(ucs_status)); status = DOCA_ERROR_DRIVER; goto fail; } ucc_worker->ucc_data[ctx_id].host = new_ep; DOCA_LOG_DBG("Created passive data channel for host for rank %lu", ucc_cmd->dpu_worker_id); } else { DOCA_LOG_DBG("Passive data channel already created"); } status = DOCA_SUCCESS; fail: notif->status = status; ucc_worker_safe_push_notification(ucc_worker, nd); return status; } /* * Handle UROM UCC worker commands function * * @ctx [in]: DOCA UROM worker context * @cmd_list [in]: command descriptor list to handle * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t urom_worker_ucc_worker_cmd(struct urom_worker_ctx *ctx, ucs_list_link_t *cmd_list) { doca_error_t status = DOCA_SUCCESS; struct urom_worker_ucc *ucc_worker = (struct urom_worker_ucc *) ctx->plugin_ctx; struct urom_worker_ucc_cmd *ucc_cmd; struct urom_worker_cmd_desc *cmd_desc; struct urom_worker_cmd *cmd; while (!ucs_list_is_empty(cmd_list)) { cmd_desc = ucs_list_extract_head(cmd_list, struct urom_worker_cmd_desc, entry); status = urom_worker_ucc_cmd_unpack(&cmd_desc->worker_cmd, cmd_desc->worker_cmd.len, &cmd); if (status != DOCA_SUCCESS) { free(cmd_desc); return status; } ucc_cmd = (struct urom_worker_ucc_cmd *)cmd->plugin_cmd; switch (ucc_cmd->cmd_type) { case UROM_WORKER_CMD_UCC_LIB_CREATE: status = urom_worker_ucc_lib_create(ucc_worker, cmd_desc); break; case UROM_WORKER_CMD_UCC_LIB_DESTROY: status = urom_worker_ucc_lib_destroy(ucc_worker, cmd_desc); 
break; case UROM_WORKER_CMD_UCC_CONTEXT_CREATE: status = urom_worker_ucc_context_create(ucc_worker, cmd_desc); break; case UROM_WORKER_CMD_UCC_CONTEXT_DESTROY: status = urom_worker_ucc_context_destroy(ucc_worker, cmd_desc); break; case UROM_WORKER_CMD_UCC_TEAM_CREATE: status = urom_worker_ucc_team_create(ucc_worker, cmd_desc); break; case UROM_WORKER_CMD_UCC_COLL: status = urom_worker_ucc_coll_init(ucc_worker, cmd_desc); break; case UROM_WORKER_CMD_UCC_CREATE_PASSIVE_DATA_CHANNEL: status = urom_worker_ucc_pass_dc_create(ucc_worker, cmd_desc); break; default: DOCA_LOG_INFO("Invalid UCC command type: %u", ucc_cmd->cmd_type); status = DOCA_ERROR_INVALID_VALUE; break; } free(cmd_desc); if (status != DOCA_SUCCESS) { return status; } } return status; } /* * Get UCC worker address * * UROM worker calls the function twice, first one to get address length and second one to get address data * * @worker_ctx [in]: DOCA UROM worker context * @addr [out]: set worker address * @addr_len [out]: set worker address length * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t urom_worker_ucc_addr(struct urom_worker_ctx *worker_ctx, void *addr, uint64_t *addr_len) { struct urom_worker_ucc *ucc_worker = (struct urom_worker_ucc *) worker_ctx->plugin_ctx; ucs_status_t status; if (ucc_worker->ucp_data.worker_address == NULL) { status = ucp_worker_get_address(ucc_worker->ucp_data.ucp_worker, &ucc_worker->ucp_data.worker_address, &ucc_worker->ucp_data.ucp_addrlen); if (status != UCS_OK) { DOCA_LOG_ERR("Failed to get ucp worker address"); return DOCA_ERROR_INITIALIZATION; } } if (*addr_len < ucc_worker->ucp_data.ucp_addrlen) { /* Return required buffer size on error */ *addr_len = ucc_worker->ucp_data.ucp_addrlen; return DOCA_ERROR_INVALID_VALUE; } *addr_len = ucc_worker->ucp_data.ucp_addrlen; memcpy(addr, ucc_worker->ucp_data.worker_address, *addr_len); return DOCA_SUCCESS; } /* * Check UCC worker tasks progress to get notifications * * @ctx [in]: DOCA UROM 
worker context * @notif_list [out]: set notification descriptors for completed tasks * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t urom_worker_ucc_progress(struct urom_worker_ctx *ctx, ucs_list_link_t *notif_list) { uint64_t lvalue = 0; struct urom_worker_ucc *ucc_worker = (struct urom_worker_ucc *) ctx->plugin_ctx; struct urom_worker_notif_desc *nd; if (ucs_list_is_empty(&ucc_worker->completed_reqs)) { return DOCA_ERROR_EMPTY; } if (ucc_component_enabled) { lvalue = ucs_atomic_cswap64(&ucc_worker->list_lock, 0, 1); while (lvalue != 0) { lvalue = ucs_atomic_cswap64(&ucc_worker->list_lock, 0, 1); } } while (!ucs_list_is_empty(&ucc_worker->completed_reqs)) { nd = ucs_list_extract_head(&ucc_worker->completed_reqs, struct urom_worker_notif_desc, entry); ucs_list_add_tail(notif_list, &nd->entry); } if (ucc_component_enabled) { lvalue = ucs_atomic_cswap64(&ucc_worker->list_lock, 1, 0); } return DOCA_SUCCESS; } /* * Packing UCC notification * * @notif [in]: UCC notification to pack * @packed_notif_len [in/out]: set packed notification command buffer size * @packed_notif [out]: set packed notification command buffer * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t urom_worker_ucc_notif_pack(struct urom_worker_notify *notif, size_t *packed_notif_len, void *packed_notif) { void *pack_tail = packed_notif; int pack_len; void *pack_head; /* Pack base command */ pack_len = ucs_offsetof(struct urom_worker_notify, plugin_notif) + sizeof(struct urom_worker_notify_ucc); pack_head = urom_ucc_serialize_next_raw(&pack_tail, void, pack_len); memcpy(pack_head, notif, pack_len); *packed_notif_len = pack_len; return DOCA_SUCCESS; } /* Define UROM UCC plugin interface, set plugin functions */ static struct urom_worker_ucc_iface urom_worker_ucc = { .super.open = urom_worker_ucc_open, .super.close = urom_worker_ucc_close, .super.addr = urom_worker_ucc_addr, .super.worker_cmd = urom_worker_ucc_worker_cmd, .super.progress = 
urom_worker_ucc_progress, .super.notif_pack = urom_worker_ucc_notif_pack, }; doca_error_t urom_plugin_get_iface(struct urom_plugin_iface *iface) { if (iface == NULL) { return DOCA_ERROR_INVALID_VALUE; } DOCA_STRUCT_CTOR(urom_worker_ucc.super); *iface = urom_worker_ucc.super; return DOCA_SUCCESS; } doca_error_t urom_plugin_get_version(uint64_t *version) { if (version == NULL) { return DOCA_ERROR_INVALID_VALUE; } *version = plugin_version; return DOCA_SUCCESS; } openucx-ucc-ec0bc8a/contrib/doca_urom_ucc_plugin/dpu/worker_ucc.h0000664000175000017500000003262515133731560025636 0ustar alastairalastair/* * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES, ALL RIGHTS RESERVED. * * This software product is a proprietary product of NVIDIA CORPORATION & * AFFILIATES (the "Company") and all right, title, and interest in and to the * software product, including all associated intellectual property rights, are * and shall remain exclusively with the Company. * * This software product is governed by the End User License Agreement * provided with the software product. 
* */ #ifndef WORKER_UCC_H_ #define WORKER_UCC_H_ #include #include #include #include #include #include #include "../common/urom_ucc.h" #define MAX_HOST_DEST_ID INT_MAX /* Maximum destination host id */ #define MIN_THREADS 1 /* Minimum number of threads per UCC worker */ /* Collective operation check macro */ #define COLL_CHECK(ucc_worker, ctx_id, status) \ { \ if (ucc_worker->ucc_data[ctx_id].ucc_lib == NULL) { \ DOCA_LOG_ERR("Attempting to perform ucc collective " \ "without initialization"); \ status = DOCA_ERROR_NOT_FOUND; \ goto fail; \ } \ \ if (ucc_worker->ucc_data[ctx_id].ucc_context == NULL) { \ DOCA_LOG_ERR("Attempting to perform ucc collective " \ "without a ucc context"); \ status = DOCA_ERROR_NOT_FOUND; \ goto fail; \ } \ } /* UCC serializing next raw, iter points to the offset place and returns the buffer start */ #define urom_ucc_serialize_next_raw(_iter, _type, _offset) \ ({ \ _type *_result = (_type *)(*(_iter)); \ *(_iter) = UCS_PTR_BYTE_OFFSET(*(_iter), _offset); \ _result; \ }) /* Worker UCC options */ struct worker_ucc_opts { uint64_t num_progress_threads; /* Number of threads */ uint64_t dpu_worker_binding_stride; /* Each worker thread is bound to this far apart core # from each other */ uint64_t ppw; /* Number of processes per worker */ uint64_t tpp; /* Threads per host process--create this many duplicate ucc contexts/teams/collectives per single host cmd */ uint64_t list_size; /* Size of progress list */ uint64_t num_psync; /* Number of synchronization/work scratch buffers to allocate for collectives */ }; /* UCC worker queue elements types */ enum ucc_worker_queue_element_type { UCC_WORKER_QUEUE_ELEMENT_TYPE_TEAM_CREATE, /* Team element queue type */ UCC_WORKER_QUEUE_ELEMENT_TYPE_COLLECTIVE, /* Collective element queue type */ }; /* UROM UCC worker interface */ struct urom_worker_ucc_iface { struct urom_plugin_iface super; /* DOCA UROM worker plugin interface */ }; /* UCC data structure */ struct ucc_data { ucc_lib_h ucc_lib; /* UCC lib 
handle */ ucc_lib_attr_t ucc_lib_attr; /* UCC lib attribute structure */ ucc_context_h ucc_context; /* UCC context */ ucc_team_h *ucc_team; /* Array of UCC team members */ int64_t n_teams; /* Array size */ long *pSync; /* Pointer to synchronization/work scratch buffers */ uint64_t psync_offset; /* Synchronization buffer offset */ void *local_work_buffer; /* Local work buffer */ size_t len; /* Buffer length */ ucp_ep_h host; /* The host data endpoint */ }; /* EP map */ KHASH_MAP_INIT_INT64(ep, ucp_ep_h); /* Memory handles map */ KHASH_MAP_INIT_INT64(memh, ucp_mem_h); /* Remote key map */ KHASH_MAP_INIT_INT64(rkeys, ucp_rkey_h); /* UCP data structure */ struct ucc_ucp_data { ucp_context_h ucp_context; /* UCP context */ ucp_worker_h ucp_worker; /* UCP worker */ ucp_address_t *worker_address; /* UCP worker address */ size_t ucp_addrlen; /* UCP worker address length */ khash_t(ep) *eps; /* EP hashtable map */ khash_t(memh) *memh; /* Memh hashtable map */ khash_t(rkeys) *rkeys; /* Rkey hashtable map */ }; /* Context ids map */ KHASH_MAP_INIT_INT64(ctx_id, uint64_t); struct urom_worker_ucc { struct urom_worker_ctx *super; /* DOCA base worker context */ struct ucc_data *ucc_data; /* UCC data structure */ struct ucc_ucp_data ucp_data; /* UCP data structure */ uint64_t list_lock; /* List lock field */ ucs_list_link_t completed_reqs; /* List of completed requests */ struct ucc_queue_element **queue; /* Elements queue */ khash_t(ctx_id) *ids; /* Ids hashtable map */ uint64_t ctx_id; /* Context id, incremented with every new dest id */ uint64_t nr_connections; /* Number of connections */ }; /* UCC worker thread args */ struct ctx_thread_args { uint64_t notif_type; /* Notification type */ uint64_t urom_context; /* UROM context */ int64_t start; /* Start index */ int64_t stride; /* Number of strides */ int64_t size; /* The work buffer size */ int64_t myrank; /* Current thread rank */ void *base_va; /* Buffer host address */ size_t len; /* Total buffer length */ uint64_t dest_id; 
/* Destination id */ struct urom_worker_ucc *ucc_worker; /* UCC worker structure */ }; /* UCC collective context structure */ struct coll_ctx { union { int64_t start; /* Collective start for single team */ int64_t *pids; /* Collective team pids */ }; int64_t stride; /* Number of strides */ int64_t size; /* The work buffer size */ int64_t index; /* Current collective member index */ struct urom_worker_ucc *ucc_worker; /* UCC worker context */ }; typedef struct ucc_tl_ucp_allreduce_sw_global_work_buf_info { void *packed_src_memh; void *packed_dst_memh; } ucc_tl_ucp_allreduce_sw_global_work_buf_info_t; /* UCC queue element structure */ struct ucc_queue_element { enum ucc_worker_queue_element_type type; /* Element type */ volatile int64_t in_use; /* If element in use */ volatile int64_t posted; /* If element was posted */ uint64_t dest_id; /* Element destination id */ uint64_t ctx_id; /* Element context id */ uint64_t myrank; /* Element rank */ pthread_barrier_t *barrier; /* If not null, call this barrier */ void *old_dest; /* Old element destination */ size_t data_size; /* Data size */ ucc_coll_req_h coll_req; /* UCC collective request */ struct coll_ctx *coll_ctx; /* UCC worker collective context */ uint64_t team_id; /* Team id */ void *dest_packed_key; /* Destination data packed key */ struct urom_worker_notif_desc *nd; /* Element notification descriptor */ ucc_worker_key_buf *key_duplicate_per_rank; /* per-rank copy of keys */ ucc_tl_ucp_allreduce_sw_global_work_buf_info_t *gwbi; /* gwbi ptr */ }; /* UCC oob allgather request */ struct oob_allgather_req { void *sbuf; /* Local buffer */ void *rbuf; /* Remote buffer */ size_t msglen; /* Message length */ void *oob_coll_ctx; /* OOB collective context */ int iter; /* Interation */ int index; /* Current process index */ int *status; /* Request status */ }; /* * Execute RMA put operation for target buffer * * @buffer [in]: target buffer * @target [in]: pointer to target * @msglen [in]: message length * @dest [in]: 
destination id * @myrank [in]: current rank in UCC team * @ctx_id [in]: current context id * @ucc_worker [in]: UCC worker context * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ doca_error_t ucc_rma_put(void *buffer, void *target, size_t msglen, uint64_t dest, uint64_t myrank, uint64_t ctx_id, struct urom_worker_ucc *ucc_worker); /* * Execute RMA get operation on target buffer * * @buffer [in]: target buffer * @target [in]: pointer to target * @msglen [in]: message length * @dest [in]: destination id * @myrank [in]: current rank in UCC team * @ctx_id [in]: current context id * @ucc_worker [in]: UCC worker context * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ doca_error_t ucc_rma_get(void *buffer, void *target, size_t msglen, uint64_t dest, uint64_t myrank, uint64_t ctx_id, struct urom_worker_ucc *ucc_worker); /* * Execute UCP send operation * * @msg [in]: send message * @len [in]: message length * @myrank [in]: current rank in UCC team * @dest [in]: destination id * @ucc_worker [in]: UCC worker context * @req [out]: request result * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ doca_error_t ucc_send_nb(void *msg, size_t len, int64_t myrank, int64_t dest, struct urom_worker_ucc *ucc_worker, int *req); /* * Execute UCP recv operation * * @msg [in]: recv buffer * @len [in]: buffer length * @dest [in]: destination id * @ucc_worker [in]: UCC worker context * @req [out]: request result * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ doca_error_t ucc_recv_nb(void *msg, size_t len, int64_t dest, struct urom_worker_ucc *ucc_worker, int *req); /* * Execute RMA get host information * * @buffer [in]: target buffer * @target [in]: pointer to target * @msglen [in]: message length * @ctx_id [in]: context id * @packed_key [in]: packed key * @ucc_worker [in]: UCC worker context * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ doca_error_t ucc_rma_get_host(void *buffer, void *target, size_t msglen, uint64_t 
ctx_id, void *packed_key, struct urom_worker_ucc *ucc_worker); /* * Execute RMA put host information * * @buffer [in]: target buffer * @target [in]: pointer to target * @msglen [in]: message length * @ctx_id [in]: context id * @packed_key [in]: packed key * @ucc_worker [in]: UCC worker context * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ doca_error_t ucc_rma_put_host(void *buffer, void *target, size_t msglen, uint64_t ctx_id, void *packed_key, struct urom_worker_ucc *ucc_worker); /* * UCP endpoint error handling context * * @arg [in]: user argument * @ep [in]: EP handler * @ucs_status [in]: UCS status */ void urom_ep_err_cb(void *arg, ucp_ep_h ep, ucs_status_t ucs_status); /* * Get DOCA worker plugin interface for UCC plugin. * DOCA UROM worker will load the urom_plugin_get_iface symbol to get the UCC interface * * @iface [out]: Set DOCA UROM plugin interface for UCC * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ doca_error_t urom_plugin_get_iface(struct urom_plugin_iface *iface); /* * Get UCC plugin version, will be used to verify that the host and DPU plugin versions are compatible * * @version [out]: Set the UCC worker plugin version * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ doca_error_t urom_plugin_get_version(uint64_t *version); #endif /* WORKER_UCC_H_ */ openucx-ucc-ec0bc8a/contrib/doca_urom_ucc_plugin/dpu/worker_ucc_p2p.c0000664000175000017500000004454415133731560026415 0ustar alastairalastair/* * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES, ALL RIGHTS RESERVED. * * This software product is a proprietary product of NVIDIA CORPORATION & * AFFILIATES (the "Company") and all right, title, and interest in and to the * software product, including all associated intellectual property rights, are * and shall remain exclusively with the Company. * * This software product is governed by the End User License Agreement * provided with the software product. 
* */ #include #include #include "worker_ucc.h" #include "../common/urom_ucc.h" DOCA_LOG_REGISTER(UCC::DOCA_CL : WORKER_UCC_P2P); void urom_ep_err_cb(void *arg, ucp_ep_h ep, ucs_status_t ucs_status) { (void)arg; (void)ep; DOCA_LOG_ERR("Endpoint error detected, status: %s", ucs_status_string(ucs_status)); } /* * UCC worker EP lookup function * * @ucc_worker [in]: UCC worker context * @dest [in]: destination id * @ep [out]: set UCP endpoint * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t worker_ucc_ep_lookup(struct urom_worker_ucc *ucc_worker, uint64_t dest, ucp_ep_h *ep) { int ret; khint_t k; void *addr; ucp_ep_h new_ep; doca_error_t status; ucs_status_t ucs_status; ucp_ep_params_t ep_params; k = kh_get(ep, ucc_worker->ucp_data.eps, dest); if (k != kh_end(ucc_worker->ucp_data.eps)) { *ep = kh_value(ucc_worker->ucp_data.eps, k); return DOCA_SUCCESS; } /* Create new EP */ status = doca_urom_worker_domain_addr_lookup(ucc_worker->super, dest, &addr); if (status != DOCA_SUCCESS) { DOCA_LOG_ERR("Id not found in domain:: %#lx", dest); return DOCA_ERROR_NOT_FOUND; } ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS | UCP_EP_PARAM_FIELD_ERR_HANDLER | UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE; ep_params.err_handler.cb = urom_ep_err_cb; ep_params.err_handler.arg = NULL; ep_params.err_mode = UCP_ERR_HANDLING_MODE_PEER; ep_params.address = addr; ucs_status = ucp_ep_create(ucc_worker->ucp_data.ucp_worker, &ep_params, &new_ep); if (ucs_status != UCS_OK) { DOCA_LOG_ERR("ucp_ep_create() returned: %s", ucs_status_string(ucs_status)); return DOCA_ERROR_INITIALIZATION; } k = kh_put(ep, ucc_worker->ucp_data.eps, dest, &ret); if (ret <= 0) { return DOCA_ERROR_DRIVER; } kh_value(ucc_worker->ucp_data.eps, k) = new_ep; *ep = new_ep; DOCA_LOG_DBG("Created EP for dest: %#lx", dest); return DOCA_SUCCESS; } /* * UCC worker memh lookup function * * @ucc_worker [in]: UCC worker context * @dest [in]: destination id * @memh [out]: set memory handle * @return: 
DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t worker_ucc_memh_lookup(struct urom_worker_ucc *ucc_worker, uint64_t dest, ucp_mem_h *memh) { ucp_mem_map_params_t mmap_params = {0}; size_t memh_len = 0; int ret; khint_t k; void *mem_handle; ucp_mem_h memh_id; doca_error_t status; ucs_status_t ucs_status; k = kh_get(memh, ucc_worker->ucp_data.memh, dest); if (k != kh_end(ucc_worker->ucp_data.memh)) { *memh = kh_value(ucc_worker->ucp_data.memh, k); return DOCA_SUCCESS; } /* Lookup memory handle */ status = doca_urom_worker_domain_memh_lookup(ucc_worker->super, dest, 0, &memh_len, &mem_handle); if (status != DOCA_SUCCESS) { DOCA_LOG_ERR("Id not found in domain:: %#lx", dest); return DOCA_ERROR_NOT_FOUND; } mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_EXPORTED_MEMH_BUFFER; mmap_params.exported_memh_buffer = mem_handle; ucs_status = ucp_mem_map(ucc_worker->ucp_data.ucp_context, &mmap_params, &memh_id); if (ucs_status != UCS_OK) { DOCA_LOG_ERR("Failed to map packed memh %p", mem_handle); return DOCA_ERROR_NOT_FOUND; } k = kh_put(memh, ucc_worker->ucp_data.memh, dest, &ret); if (ret <= 0) { DOCA_LOG_ERR("Failed to add memh to hashtable map"); if (ucp_mem_unmap(ucc_worker->ucp_data.ucp_context, memh_id) != UCS_OK) { DOCA_LOG_ERR("Failed to unmap memh"); } return DOCA_ERROR_DRIVER; } kh_value(ucc_worker->ucp_data.memh, k) = memh_id; *memh = memh_id; DOCA_LOG_DBG("Assigned memh %p for dest: %#lx", memh_id, dest); return DOCA_SUCCESS; } /* * UCC worker memory key lookup function * * @ucc_worker [in]: UCC worker context * @dest [in]: destination id * @ep [in]: destination endpoint * @va [in]: memory host address * @ret_rkey [out]: set remote memory key * @return: DOCA_SUCCESS on success and DOCA_ERROR otherwise */ static doca_error_t worker_ucc_key_lookup(struct urom_worker_ucc *ucc_worker, uint64_t dest, ucp_ep_h ep, uint64_t va, void **ret_rkey) { khint_t k; int ret; void *packed_key; size_t packed_key_len; ucp_rkey_h rkey; doca_error_t status; 
ucs_status_t ucs_status; int seg; k = kh_get(rkeys, ucc_worker->ucp_data.rkeys, dest); if (k != kh_end(ucc_worker->ucp_data.rkeys)) { *ret_rkey = kh_value(ucc_worker->ucp_data.rkeys, k); return DOCA_SUCCESS; } status = doca_urom_worker_domain_seg_lookup(ucc_worker->super, dest, va, &seg); if (status != DOCA_SUCCESS) { DOCA_LOG_ERR("Id not found in domain: %#lx", dest); return DOCA_ERROR_NOT_FOUND; } status = doca_urom_worker_domain_mkey_lookup(ucc_worker->super, dest, seg, &packed_key_len, &packed_key); if (status != DOCA_SUCCESS) { DOCA_LOG_ERR("Id not found in domain: %#lx", dest); return DOCA_ERROR_NOT_FOUND; } ucs_status = ucp_ep_rkey_unpack(ep, packed_key, &rkey); if (ucs_status != UCS_OK) { return DOCA_ERROR_NOT_FOUND; } k = kh_put(rkeys, ucc_worker->ucp_data.rkeys, dest, &ret); if (ret <= 0) { return DOCA_ERROR_DRIVER; } kh_value(ucc_worker->ucp_data.rkeys, k) = rkey; *ret_rkey = rkey; DOCA_LOG_DBG("Assigned rkey for dest: %#lx", dest); return DOCA_SUCCESS; } /* * UCC send tag completion callback * * @request [in]: UCP send request * @status [in]: send task status * @user_data [in]: UCC data */ static void send_completion_cb(void *request, ucs_status_t status, void *user_data) { int *req = (int *)user_data; if (status != UCS_OK) { *req = -1; } else { *req = 1; } ucp_request_free(request); } /* * UCC recv tag completion callback * * @request [in]: UCP recv request * @status [in]: recv task status * @info [in]: recv task info * @user_data [in]: UCC data */ static void recv_completion_cb(void *request, ucs_status_t status, const ucp_tag_recv_info_t *info, void *user_data) { int *req = (int *)user_data; (void)info; if (status != UCS_OK) { *req = -1; } else { *req = 1; } ucp_request_free(request); } doca_error_t ucc_send_nb(void *msg, size_t len, int64_t myrank, int64_t dest, struct urom_worker_ucc *ucc_worker, int *req) { ucp_ep_h ep = NULL; ucp_request_param_t req_param = {0}; doca_error_t urom_status; ucs_status_ptr_t ucp_status; *req = 0; 
req_param.op_attr_mask = UCP_OP_ATTR_FIELD_DATATYPE | UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA; req_param.datatype = ucp_dt_make_contig(len); req_param.cb.send = send_completion_cb; req_param.user_data = (void *)req; urom_status = worker_ucc_ep_lookup(ucc_worker, dest, &ep); if (urom_status != DOCA_SUCCESS) { DOCA_LOG_ERR("Failed to send to %ld in UCC oob", dest); return DOCA_ERROR_NOT_FOUND; } /* Process tag send */ ucp_status = ucp_tag_send_nbx(ep, msg, 1, myrank, &req_param); if (ucp_status != UCS_OK) { if (UCS_PTR_IS_ERR(ucp_status)) { ucp_request_cancel(ucc_worker->ucp_data.ucp_worker, ucp_status); ucp_request_free(ucp_status); return DOCA_ERROR_NOT_FOUND; } } else { *req = 1; } return DOCA_SUCCESS; } doca_error_t ucc_recv_nb(void *msg, size_t len, int64_t dest, struct urom_worker_ucc *ucc_worker, int *req) { ucp_request_param_t req_param = {}; ucs_status_ptr_t ucp_status; *req = 0; req_param.op_attr_mask = UCP_OP_ATTR_FIELD_DATATYPE | UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA; req_param.datatype = ucp_dt_make_contig(len); req_param.cb.recv = recv_completion_cb; req_param.user_data = (void *)req; /* Process tag recv */ ucp_status = ucp_tag_recv_nbx(ucc_worker->ucp_data.ucp_worker, msg, 1, dest, 0xffff, &req_param); if (ucp_status != UCS_OK) { if (UCS_PTR_IS_ERR(ucp_status)) { ucp_request_cancel(ucc_worker->ucp_data.ucp_worker, ucp_status); ucp_request_free(ucp_status); return DOCA_ERROR_NOT_FOUND; } } else { *req = 1; } return DOCA_SUCCESS; } doca_error_t ucc_rma_put(void *buffer, void *target, size_t msglen, uint64_t dest, uint64_t myrank, uint64_t ctx_id, struct urom_worker_ucc *ucc_worker) { uint64_t rva = (uint64_t)target; ucp_request_param_t req_param = {0}; ucp_mem_h memh = NULL; ucp_ep_h ep; ucp_rkey_h rkey; doca_error_t urom_status; ucs_status_ptr_t ucp_status; if (dest == MAX_HOST_DEST_ID) { ep = ucc_worker->ucc_data[ctx_id].host; } else { urom_status = worker_ucc_ep_lookup(ucc_worker, dest, &ep); if (urom_status != 
DOCA_SUCCESS) { DOCA_LOG_ERR("Failed to find peer %ld to complete collective", dest); return DOCA_ERROR_NOT_FOUND; } } if (dest != MAX_HOST_DEST_ID) { urom_status = worker_ucc_memh_lookup(ucc_worker, dest, &memh); if (urom_status != DOCA_SUCCESS) DOCA_LOG_ERR("Failed to lookup key for peer %ld", dest); } if (dest == MAX_HOST_DEST_ID) { urom_status = worker_ucc_key_lookup(ucc_worker, myrank, ep, rva, (void **)&rkey); if (urom_status != DOCA_SUCCESS) { DOCA_LOG_ERR("Failed to lookup rkey for peer %ld", dest); return DOCA_ERROR_NOT_FOUND; } } else { urom_status = worker_ucc_key_lookup(ucc_worker, dest, ep, rva, (void **)&rkey); if (urom_status != DOCA_SUCCESS) { DOCA_LOG_ERR("Failed to lookup rkey for peer %ld", dest); return DOCA_ERROR_NOT_FOUND; } } if (memh != NULL) { req_param.op_attr_mask = UCP_OP_ATTR_FIELD_MEMH; req_param.memh = memh; } ucp_status = ucp_put_nbx(ep, buffer, msglen, rva, rkey, &req_param); if (ucp_status != UCS_OK) { if (UCS_PTR_IS_ERR(ucp_status)) { ucp_request_free(ucp_status); return DOCA_ERROR_NOT_FOUND; } while (ucp_request_check_status(ucp_status) == UCS_INPROGRESS) { ucp_worker_progress(ucc_worker->ucp_data.ucp_worker); } ucp_request_free(ucp_status); } return DOCA_SUCCESS; } doca_error_t ucc_rma_get(void *buffer, void *target, size_t msglen, uint64_t dest, uint64_t myrank, uint64_t ctx_id, struct urom_worker_ucc *ucc_worker) { ucp_ep_h ep = NULL; ucp_mem_h memh = NULL; ucp_rkey_h rkey = NULL; ucp_request_param_t req_param = {0}; uint64_t rva = (uint64_t)target; doca_error_t urom_status; ucs_status_ptr_t ucp_status; if (dest == MAX_HOST_DEST_ID) { ep = ucc_worker->ucc_data[ctx_id].host; } else { urom_status = worker_ucc_ep_lookup(ucc_worker, dest, &ep); if (urom_status != DOCA_SUCCESS) { DOCA_LOG_ERR("Failed to find peer %ld to complete collective", dest); return DOCA_ERROR_NOT_FOUND; } } if (dest != MAX_HOST_DEST_ID) { urom_status = worker_ucc_memh_lookup(ucc_worker, dest, &memh); if (urom_status != DOCA_SUCCESS) { DOCA_LOG_ERR("Failed to 
lookup key for peer %ld", dest); return DOCA_ERROR_NOT_FOUND; } } if (dest == MAX_HOST_DEST_ID) { urom_status = worker_ucc_key_lookup(ucc_worker, myrank, ep, rva, (void **)&rkey); if (urom_status != DOCA_SUCCESS) { DOCA_LOG_ERR("Failed to lookup rkey for peer %ld", dest); return DOCA_ERROR_NOT_FOUND; } } else { urom_status = worker_ucc_key_lookup(ucc_worker, dest, ep, rva, (void **)&rkey); if (urom_status != DOCA_SUCCESS) { DOCA_LOG_ERR("Failed to lookup rkey for peer %ld", dest); return DOCA_ERROR_NOT_FOUND; } } if (memh != NULL) { req_param.op_attr_mask = UCP_OP_ATTR_FIELD_MEMH; req_param.memh = memh; } ucp_status = ucp_get_nbx(ep, buffer, msglen, rva, rkey, &req_param); if (ucp_status != UCS_OK) { if (UCS_PTR_IS_ERR(ucp_status)) { ucp_request_free(ucp_status); ucp_rkey_destroy(rkey); return DOCA_ERROR_NOT_FOUND; } while (ucp_request_check_status(ucp_status) == UCS_INPROGRESS) { ucp_worker_progress(ucc_worker->ucp_data.ucp_worker); } ucp_request_free(ucp_status); } ucp_rkey_destroy(rkey); return DOCA_SUCCESS; } doca_error_t ucc_rma_get_host(void *buffer, void *target, size_t msglen, uint64_t ctx_id, void *packed_key, struct urom_worker_ucc *ucc_worker) { ucp_ep_h ep = NULL; ucp_rkey_h rkey = NULL; uint64_t rva = (uint64_t)target; ucp_request_param_t req_param = {0}; ucs_status_t ucs_status; ucs_status_ptr_t ucp_status; if (packed_key == NULL) { return DOCA_ERROR_INVALID_VALUE; } ep = ucc_worker->ucc_data[ctx_id].host; ucs_status = ucp_ep_rkey_unpack(ep, packed_key, &rkey); if (ucs_status != UCS_OK) { DOCA_LOG_ERR("Failed to unpack rkey"); return DOCA_ERROR_NOT_FOUND; } ucp_status = ucp_get_nbx(ep, buffer, msglen, rva, rkey, &req_param); if (UCS_OK != ucp_status) { if (UCS_PTR_IS_ERR(ucp_status)) { DOCA_LOG_ERR("Failed to perform ucp_get_nbx(): %s\n", ucs_status_string(UCS_PTR_STATUS(ucp_status))); ucp_request_free(ucp_status); ucp_rkey_destroy(rkey); return DOCA_ERROR_NOT_FOUND; } while (UCS_INPROGRESS == ucp_request_check_status(ucp_status)) { 
ucp_worker_progress(ucc_worker->ucp_data.ucp_worker); } if (UCS_PTR_IS_ERR(ucp_status)) { DOCA_LOG_ERR("Failed to perform ucp_get_nbx(): %s\n", ucs_status_string(UCS_PTR_STATUS(ucp_status))); ucp_request_free(ucp_status); ucp_rkey_destroy(rkey); return DOCA_ERROR_NOT_FOUND; } ucp_request_free(ucp_status); } ucp_rkey_destroy(rkey); return DOCA_SUCCESS; } doca_error_t ucc_rma_put_host(void *buffer, void *target, size_t msglen, uint64_t ctx_id, void *packed_key, struct urom_worker_ucc *ucc_worker) { ucp_ep_h ep = NULL; ucp_rkey_h rkey = NULL; uint64_t rva = (uint64_t)target; ucp_request_param_t req_param = {0}; ucs_status_t ucs_status; ucs_status_ptr_t ucp_status; if (packed_key == NULL) { return DOCA_ERROR_INVALID_VALUE; } ep = ucc_worker->ucc_data[ctx_id].host; ucs_status = ucp_ep_rkey_unpack(ep, packed_key, &rkey); if (ucs_status != UCS_OK) { DOCA_LOG_ERR("Failed to unpack rkey"); return DOCA_ERROR_NOT_FOUND; } ucp_status = ucp_put_nbx(ep, buffer, msglen, rva, rkey, &req_param); if (UCS_OK != ucp_status) { if (UCS_PTR_IS_ERR(ucp_status)) { DOCA_LOG_ERR("Failed to perform ucp_put_nbx(): %s\n", ucs_status_string(UCS_PTR_STATUS(ucp_status))); ucp_request_free(ucp_status); ucp_rkey_destroy(rkey); return DOCA_ERROR_NOT_FOUND; } while (UCS_INPROGRESS == ucp_request_check_status(ucp_status)) { ucp_worker_progress(ucc_worker->ucp_data.ucp_worker); } if (UCS_PTR_IS_ERR(ucp_status)) { DOCA_LOG_ERR("Failed to perform ucp_put_nbx(): %s\n", ucs_status_string(UCS_PTR_STATUS(ucp_status))); ucp_request_free(ucp_status); ucp_rkey_destroy(rkey); return DOCA_ERROR_NOT_FOUND; } ucp_request_free(ucp_status); } ucp_rkey_destroy(rkey); return DOCA_SUCCESS; } openucx-ucc-ec0bc8a/contrib/doca_urom_ucc_plugin/Makefile.am0000664000175000017500000000135515133731560024562 0ustar alastairalastair# # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# if HAVE_DOCA_UROM sources = \ common/urom_ucc.h \ dpu/worker_ucc_p2p.c \ dpu/worker_ucc.h \ dpu/worker_ucc.c plugindir = $(moduledir)/doca_plugins plugin_LTLIBRARIES = libucc_doca_urom_plugin.la libucc_doca_urom_plugin_la_SOURCES = $(sources) libucc_doca_urom_plugin_la_CPPFLAGS = $(AM_CPPFLAGS) $(BASE_CPPFLAGS) $(UCX_CPPFLAGS) $(DOCA_UROM_CPPFLAGS) libucc_doca_urom_plugin_la_CFLAGS = $(BASE_CFLAGS) libucc_doca_urom_plugin_la_LDFLAGS = -version-info $(SOVERSION) --as-needed $(UCX_LDFLAGS) $(DOCA_UROM_LDFLAGS) libucc_doca_urom_plugin_la_LIBADD = $(UCX_LIBADD) $(DOCA_UROM_LIBADD) $(UCC_TOP_BUILDDIR)/src/libucc.la endif openucx-ucc-ec0bc8a/contrib/doca_urom_ucc_plugin/DOCA_UROM_README.md0000664000175000017500000001214615133731560025435 0ustar alastairalastair# How to enable and run collectives on NVIDIA DPUs The CL_DOCA_UROM plugin enables collective offloads via DOCA and CL_DOCA_UROM feature. The plugin runs on the DPU and is initialized via the command channel from CL_DOCA_UROM. It leverages UCC collectives optimized for the DPU and facilitates efficient computation and communication overlap. While the plugin supports all UCC collectives via a copy-in/out method, the optimized algorithms are Allreduce and Alltoall/v (OpenSHMEM). ## Components and Dependencies ### Components 1. **Host Node:** - Runs the main UCC application and initiates communication with the DPU. - Executes scripts to run the UCC collective operations with DOCA integration. 2. **DPU:** - Runs the DOCA services. - Leverages UCC for data handling. - Communicates with the host node using the DOCA interface. ### Interaction - The host node communicates with the DPU via the DOCA interface. - Data and commands are sent from the host to the DPU, processed by UCC, and results are returned. 
### Dependencies - **Programming Languages:** C/C++, Shell scripting - **Libraries:** UCC, UCX (Unified Communication X) - **Tools:** DOCA SDK, OpenMPI - **Platforms:** Host node (x86 or ARM), DPU (NVIDIA BlueField) ### Interface Design - **APIs:** UCC collective operations with DOCA_UROM extensions. ## Build and Deployment For DOCA instructions please see: [DOCA - Get Started | NVIDIA Developer](https://docs.nvidia.com/doca/sdk/nvidia+doca+overview/index.html) ### Host Node 1. **Build Instructions:** ``` Checkout UCC ./autogen.sh ./configure --prefix=<ucc-install-dir> --with-ucx=<ucx-install-dir> --with-doca_urom=<doca-urom-install-dir> make -j install ``` ### DPU 1. **Build Instructions:** - Checkout the same UCC branch as on the host. - Use the same configure line as the host. ### Generate the following files in $TOP 1. **MPI HOST hostfile:** ```plaintext hostA slots=1 hostB slots=1 ``` 2. **MPI DPU hostfile:** ```plaintext dpuA slots=1 dpuB slots=1 ``` 3. **DPU servicefile:** ```plaintext hostA dpu dpuA 0000:04:00.0 mlx5_0:1 hostB dpu dpuB 0000:04:00.0 mlx5_0:1 ``` ## Run Instructions ### 1. Run on DPU: - **Script:** `run_doca_uromd.sh` ```bash #!/bin/bash TOP=<path-to-top-dir> OMPI_DIR=<path-to-ompi-install> DPU_HOST_FILE=<path-to-dpu-hostfile> NUMBER_OF_NODES=$(cat $DPU_HOST_FILE | grep -v '#' | wc -l) # Set plugin path export UROM_PLUGIN_PATH=<path-to-plugin-dir> export PATH=$OMPI_DIR/bin:$PATH export LD_LIBRARY_PATH=$OMPI_DIR/lib:$LD_LIBRARY_PATH export UCC_DIR=<path-to-ucc-install> export UCX_DIR=<path-to-ucx-install> export LD_LIBRARY_PATH=$UCC_DIR/lib:$UCX_DIR/lib:$LD_LIBRARY_PATH options="-x UCX_TLS=rc_x,tcp $options" mpirun --tag-output -np $NUMBER_OF_NODES -hostfile $DPU_HOST_FILE -x LD_LIBRARY_PATH=$LD_LIBRARY_PATH $options -x UROM_PLUGIN_PATH=$UROM_PLUGIN_PATH $TOP/doca/build-dpu/services/urom/doca_urom_daemon -l 10 --sdk-log-level 10 ``` ### 2.
Run on Host: - **Script:** `run_doca_urom_cl.sh` ```bash #!/bin/bash PPN=$1 if [ -z "$PPN" ]; then echo PPN not set, assuming PPN=1 PPN=1 fi TOP=<path-to-top-dir> NUMBER_OF_NODES=$(cat $TOP/hostfile | grep -v '#' | wc -l) OMPI_DIR=<path-to-ompi-install> UCX_DIR=<path-to-ucx-install> export LD_LIBRARY_PATH=$OMPI_DIR/lib:$UCX_DIR/lib:$UCX_DIR/lib/ucx:$TOP/doca/lib64:$TOP/doca_urom_ucc/install/host/lib64:$LD_LIBRARY_PATH OMB_DIR=<path-to-osu-micro-benchmarks> # DPU options options="$options -x UCX_NET_DEVICES=mlx5_0:1,mn0" options="$options -x DOCA_UROM_SERVICE_FILE=$TOP/servicefile" options="$options -x LD_LIBRARY_PATH=$LD_LIBRARY_PATH" # `UCC_CL_DOCA_UROM_PLUGIN_ENVS` takes a comma separated list of envs options="$options -x UCC_CL_DOCA_UROM_PLUGIN_ENVS=LD_LIBRARY_PATH=$TOP/arm/build-arm/ompi/lib:$TOP/arm/build-arm/ucx/lib:$TOP/arm/build-arm/ucc/lib:$TOP/doca_urom_ucc/install/dpu/lib64,UCX_LOG_LEVEL=ERROR" options="$options --mca coll_ucc_enable 1 --mca coll_ucc_priority 100" options="$options -x UCX_TLS=rc_x,tcp" validation="--validation" #Optional argument to run data validation in osu benchmark msg_low=1024 msg_high=$((1024*1024)) for coll in allreduce do $OMPI_DIR/bin/mpirun -np $((NUMBER_OF_NODES * PPN)) --map-by node -hostfile $TOP/hostfile $options --mca coll_ucc_cls doca_urom,basic --tag-output $OMB_DIR/osu_$coll $validation "-m $msg_low:$msg_high" -i 10 -x 5 done ``` ## Key Environment Variables ### Host Side - `UCC_CL_DOCA_UROM_PLUGIN_ENVS`: Ensures the plugin has the correct `LD_LIBRARY_PATH`. - `DOCA_UROM_SERVICE_FILE`: Maps between host and DPUs. ### DPU Side - `UROM_PLUGIN_PATH`: Path to the directory containing only `.so` files that are plugins. A plugin must have the symbols `urom_plugin_get_version` and `urom_plugin_get_iface`. ## Conclusion The CL_DOCA_UROM feature and the corresponding plugin integrate the UCC collective library with DOCA running on DPUs, leveraging optimized algorithms for the DPU. This README outlines the components, interaction, build, deployment, and execution instructions necessary to implement, run, and extend this feature.
openucx-ucc-ec0bc8a/contrib/ucc.conf0000664000175000017500000001611515133731560017767 0ustar alastairalastair# Default TLS configuration # We mostly use "negate" interface so that default TL config # never throws warnings if some TLs are not available # Currently compiled tls: ucp,cuda,nccl,sharp # Default for CL_BASIC: all except sharp,nccl. # cuda will silently disqualify itself for multinode teams # but will be used on a single node UCC_CL_BASIC_TLS=^sharp,nccl # Defaults for CL_HIER: set per SBGP # Sharp should be explicitly enabled UCC_CL_HIER_NODE_SBGP_TLS=^sharp,nccl # cuda is also disabled for NODE_LEADERS and NET UCC_CL_HIER_NODE_LEADERS_SBGP_TLS=^sharp,nccl,cuda UCC_CL_HIER_NET_SBGP_TLS=^sharp,nccl,cuda # FULL_SBGP is currently only used for hierarchical alltoall # with ucp sbgp on top UCC_CL_HIER_FULL_SBGP_TLS=ucp # Tuning sections, currently only supports TL/UCP #Intel Broadwell: [vendor=intel model=broadwell team_size=28 ppn=28 nnodes=1] UCC_TL_UCP_ALLREDUCE_KN_RADIX=2 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=7 UCC_TL_UCP_TUNE=allreduce:0-4k:@0#allreduce:4k-inf:@1 [vendor=intel model=broadwell team_size=2 ppn=1 nnodes=2] UCC_TL_UCP_ALLREDUCE_KN_RADIX=2 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=2 UCC_TL_UCP_TUNE=allreduce:0-128k:@0#allreduce:128k-inf:@1 [vendor=intel model=broadwell team_size=4 ppn=1 nnodes=4] UCC_TL_UCP_ALLREDUCE_KN_RADIX=4 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=4 UCC_TL_UCP_TUNE=allreduce:0-16k:@0#allreduce:16k-inf:@1 [vendor=intel model=broadwell team_size=8 ppn=1 nnodes=8] UCC_TL_UCP_ALLREDUCE_KN_RADIX=8 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=8 UCC_TL_UCP_TUNE=allreduce:0-4k:@0#allreduce:4k-inf:@1 #Intel Skylake: [vendor=intel model=skylake team_size=40 ppn=40 nnodes=1] UCC_TL_UCP_ALLREDUCE_KN_RADIX=2 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=0-128k:host:2,128k-inf:host:8 UCC_TL_UCP_TUNE=allreduce:0-2k:@0#allreduce:2k-inf:@1 [vendor=intel model=skylake team_size=2 ppn=1 nnodes=2] UCC_TL_UCP_ALLREDUCE_KN_RADIX=2 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=2 
UCC_TL_UCP_TUNE=allreduce:0-16k:@0#allreduce:16k-inf:@1 [vendor=intel model=skylake team_size=4 ppn=1 nnodes=4] UCC_TL_UCP_ALLREDUCE_KN_RADIX=4 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=4 UCC_TL_UCP_TUNE=allreduce:0-8k:@0#allreduce:8k-inf:@1 [vendor=intel model=skylake team_size=8 ppn=1 nnodes=8] UCC_TL_UCP_ALLREDUCE_KN_RADIX=0-8k:host:8,8k-inf:host:2 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=8 UCC_TL_UCP_TUNE=allreduce:0-4k:@0#allreduce:4k-inf:@1 [vendor=intel model=skylake team_size=32 ppn=1 nnodes=32] UCC_TL_UCP_ALLREDUCE_KN_RADIX=0-8k:host:8,8k-inf:host:2 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=8 UCC_TL_UCP_TUNE=allreduce:0-2k:@0#allreduce:2k-inf:@1 #Amd Rome: [vendor=amd model=rome team_size=128 ppn=128 nnodes=1] UCC_TL_UCP_ALLREDUCE_KN_RADIX=2 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=0-64k:host:4,64k-inf:host:8 UCC_TL_UCP_TUNE=allreduce:0-1k:@0#allreduce:1k-inf:@1 [vendor=amd model=rome team_size=2 ppn=1 nnodes=2] UCC_TL_UCP_ALLREDUCE_KN_RADIX=2 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=2 UCC_TL_UCP_TUNE=allreduce:0-256k:@0#allreduce:256k-inf:@1 [vendor=amd model=rome team_size=4 ppn=1 nnodes=4] UCC_TL_UCP_ALLREDUCE_KN_RADIX=4 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=4 UCC_TL_UCP_TUNE=allreduce:0-16k:@0#allreduce:16k-inf:@1 [vendor=amd model=rome team_size=8 ppn=1 nnodes=8] UCC_TL_UCP_ALLREDUCE_KN_RADIX=0-8k:host:8,8k-inf:host:2 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=8 UCC_TL_UCP_TUNE=allreduce:0-8k:@0#allreduce:8k-inf:@1 #NVIDIA Grace, Generic 1 node [vendor=nvidia model=grace nnodes=1] UCC_TL_UCP_ALLREDUCE_KN_RADIX=0-8:host:2,8-64:host:3,64-4k:host:2 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=4096-8192:host:96,8192-16384:host:3,16384-32768:host:8,32768-65536:host:16,65536-131072:host:32,131072-262144:host:2,262144-524288:host:3,524288-1048576:host:2 UCC_TL_UCP_TUNE=allreduce:0-4k:@0#allreduce:4k-inf:@1 #NVIDIA Grace, 2 socket (C2): [vendor=nvidia model=grace team_size=144 sock=72 nnodes=1] UCC_TL_UCP_ALLREDUCE_KN_RADIX=0-4k:host:2 
UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=4096-8192:host:96,8192-16384:host:4,16384-32768:host:6,32768-65536:host:18,65536-131072:host:32,131072-262144:host:72,262144-524288:host:3,524288-1048576:host:2 UCC_TL_UCP_TUNE=allreduce:0-4k:@0#allreduce:4k-inf:@1#allgather:37748736-inf:host:@1 [vendor=nvidia model=grace team_size=128 sock=64 nnodes=1] UCC_TL_UCP_ALLREDUCE_KN_RADIX=0-4k:host:2 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=4096-8192:host:72,8192-16384:host:4,16384-32768:host:8,32768-65536:host:16,65536-131072:host:32,131072-262144:host:64,262144-524288:host:3,524288-1048576:host:3 UCC_TL_UCP_TUNE=allreduce:0-4k:@0#allreduce:4k-inf:@1 [vendor=nvidia model=grace team_size=64 sock=32 nnodes=1] UCC_TL_UCP_ALLREDUCE_KN_RADIX=0-4k:host:2 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=4096-8192:host:2,8192-16384:host:4,16384-32768:host:8,32768-65536:host:16,65536-131072:host:32,131072-262144:host:3,262144-524288:host:3,524288-1048576:host:2 UCC_TL_UCP_TUNE=allreduce:0-4k:@0#allreduce:4k-inf:@1 [vendor=nvidia model=grace team_size=32 sock=16 nnodes=1] UCC_TL_UCP_ALLREDUCE_KN_RADIX=0-4k:host:2 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=4096-8192:host:2,8192-16384:host:4,16384-32768:host:8,32768-65536:host:16,65536-131072:host:3,131072-262144:host:2,262144-524288:host:2,524288-1048576:host:2 UCC_TL_UCP_TUNE=allreduce:0-4k:@0#allreduce:4k-inf:@1 [vendor=nvidia model=grace team_size=16 sock=8 nnodes=1] UCC_TL_UCP_ALLREDUCE_KN_RADIX=0-4k:host:2 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=4096-8192:host:2,8192-16384:host:8,16384-32768:host:8,32768-65536:host:2,65536-131072:host:2,131072-262144:host:2,262144-524288:host:2,524288-1048576:host:2 UCC_TL_UCP_TUNE=allreduce:0-4k:@0#allreduce:4k-inf:@1 #NVIDIA Grace, 1 socket (CG): [vendor=nvidia model=grace team_size=72 sock=72 nnodes=1] UCC_TL_UCP_ALLREDUCE_KN_RADIX=0-64:host:3,64-4k:host:2 
UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=4096-8192:host:6,8192-16384:host:6,16384-32768:host:6,32768-65536:host:16,65536-131072:host:32,131072-262144:host:48,262144-524288:host:2,524288-1048576:host:2 UCC_TL_UCP_TUNE=allreduce:0-4k:@0#allreduce:4k-inf:@1 [vendor=nvidia model=grace team_size=64 sock=64 nnodes=1] UCC_TL_UCP_ALLREDUCE_KN_RADIX=0-4k:host:2 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=4096-8192:host:2,8192-16384:host:4,16384-32768:host:8,32768-65536:host:18,65536-131072:host:32,131072-262144:host:48,262144-524288:host:2,524288-1048576:host:2 UCC_TL_UCP_TUNE=allreduce:0-4k:@0#allreduce:4k-inf:@1 [vendor=nvidia model=grace team_size=32 sock=32 nnodes=1] UCC_TL_UCP_ALLREDUCE_KN_RADIX=0-4k:host:2 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=4096-8192:host:2,8192-16384:host:4,16384-32768:host:8,32768-65536:host:16,65536-131072:host:144,131072-262144:host:2,262144-524288:host:2,524288-1048576:host:4 UCC_TL_UCP_TUNE=allreduce:0-4k:@0#allreduce:4k-inf:@1 [vendor=nvidia model=grace team_size=16 sock=16 nnodes=1] UCC_TL_UCP_ALLREDUCE_KN_RADIX=0-4k:host:2 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=4096-8192:host:2,8192-16384:host:4,16384-32768:host:8,32768-65536:host:16,65536-131072:host:3,131072-262144:host:2,262144-524288:host:4,524288-1048576:host:4 UCC_TL_UCP_TUNE=allreduce:0-4k:@0#allreduce:4k-inf:@1 [vendor=nvidia model=grace team_size=8 sock=8 nnodes=1] UCC_TL_UCP_ALLREDUCE_KN_RADIX=0-8k:host:2 UCC_TL_UCP_ALLREDUCE_SRA_KN_RADIX=8192-16384:host:2,16384-32768:host:4,32768-65536:host:16,65536-131072:host:18,131072-262144:host:6,262144-524288:host:96,524288-1048576:host:4 UCC_TL_UCP_TUNE=allreduce:0-8k:@0#allreduce:8k-inf:@1 openucx-ucc-ec0bc8a/contrib/Makefile.am0000664000175000017500000000016015133731560020373 0ustar alastairalastair# # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SUBDIRS = doca_urom_ucc_plugin openucx-ucc-ec0bc8a/CONTRIBUTING.md0000664000175000017500000000137715133731560017143 0ustar alastairalastair# Contributing to UCC ## Legal All code contributions are submitted via pull requests at the UCF project on Github. All commits containing contributions are made in accordance with either the corporate or individual contributor license agreement executed by yourself individually or by someone from your company, as applicable. By making the contribution(s) you acknowledge that they are made pursuant to such applicable agreement(s). For more information on guidelines for submissions and contributions please contact UCF at admin@ucfconsortium.org. * [Individual](https://www.openucx.org/wp-content/uploads/2019/05/ucx-individual-contributor-agrement.pdf) * [Corporate](https://www.openucx.org/wp-content/uploads/2019/05/ucx-corporate-contributor-agrement.pdf) openucx-ucc-ec0bc8a/AUTHORS0000664000175000017500000000346015133731560015755 0ustar alastairalastairAlex Margolin alex.margolin@huawei.com Alexey Rivkin arivkin@nvidia.com Anatoly Vildemanov anatolyv@nvidia.com Andrii Bilokur abilokur@nvidia.com Andy Lin 32576375+andylin-hao@users.noreply.github.com Arnaud Celermajer acelermajer@nvidia.com Artem Ryabov artemry@nvidia.com Boris Karasev boriska@nvidia.com Brad Settlemyer bws@deepcopy.org Ching-Hsiang Chu chchu@fb.com Daniel Pressler danielpr@nvidia.com Devendar Bureddy devendar@nvidia.com Edgar Gabriel Edgar.Gabriel@amd.com Evgeny Keidar ekeidar@nvidia.com Ferrol Aderholdt faderholdt@nvidia.com Geoffroy Vallee geoffroy@nvidia.com Hessam Mirsadeghi hmirsadeghi@nvidia.com Ilya Kryukov ikryukov@nvidia.com Jiri Kraus jkraus@nvidia.com Lior Paz liorpa@nvidia.com Mamzi Bayatpour mbayatpour@nvidia.com Manjunath Gorentla Venkata manjunath@nvidia.com Masaki Kozuki mkozuki@nvidia.com Mike Dubman mdubman@nvidia.com Nilesh M Negi nilesh.negi@amd.com Nick Sarkauskas nsarkauskas@nvidia.com Pavel Shamis (Pasha) shamisp@users.noreply.github.com 
Pedram Alizadeh pedram.alizadeh@amd.com Rob Bradford rob@robster.org.uk Sam Nordmann snordmann@nvidia.com Sergey Lebedev sergeyle@nvidia.com Shimmy Balsam sbalsam@nvidia.com Sourav Chakraborty sourav.chakraborty@nvidia.com Taekyung Heo taekyung@gatech.edu Tommy Janjusic tjanjusic@nvidia.com Valentin Petrov valentinp@nvidia.com Xiang Gao xgao@nvidia.com Yael Yacobovich yyacobovich@nvidia.com openucx-ucc-ec0bc8a/test/0000775000175000017500000000000015133731560015661 5ustar alastairalastairopenucx-ucc-ec0bc8a/test/mpi/0000775000175000017500000000000015133731560016446 5ustar alastairalastairopenucx-ucc-ec0bc8a/test/mpi/test_scatter.cc0000664000175000017500000000753615133731560021474 0ustar alastairalastair/** * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #include "test_mpi.h" #include "mpi_util.h" TestScatter::TestScatter(ucc_test_team_t &_team, TestCaseParams ¶ms) : TestCase(_team, UCC_COLL_TYPE_SCATTER, params) { int rank, size; size_t dt_size, single_rank_count; dt = params.dt; dt_size = ucc_dt_size(dt); single_rank_count = msgsize / dt_size; root = params.root; MPI_Comm_rank(team.comm, &rank); MPI_Comm_size(team.comm, &size); if (TEST_SKIP_NONE != skip_reduce(test_max_size < (msgsize * size), TEST_SKIP_MEM_LIMIT, team.comm)) { return; } if (rank == root) { UCC_CHECK(ucc_mc_alloc(&sbuf_mc_header, msgsize * size, mem_type)); sbuf = sbuf_mc_header->addr; if (inplace) { rbuf_mc_header = NULL; rbuf = NULL; } else { UCC_CHECK(ucc_mc_alloc(&rbuf_mc_header, msgsize, mem_type)); rbuf = rbuf_mc_header->addr; } } else { UCC_CHECK(ucc_mc_alloc(&rbuf_mc_header, msgsize, mem_type)); rbuf = rbuf_mc_header->addr; sbuf_mc_header = NULL; sbuf = NULL; } check_buf = ucc_malloc(msgsize * size, "check buf"); UCC_MALLOC_CHECK(check_buf); args.root = root; if (rank == root) { args.src.info.buffer = sbuf; args.src.info.count = single_rank_count * size; args.src.info.datatype = dt; args.src.info.mem_type = mem_type; if 
(!inplace) { args.dst.info.buffer = rbuf; args.dst.info.count = single_rank_count; args.dst.info.datatype = dt; args.dst.info.mem_type = mem_type; } } else { args.dst.info.buffer = rbuf; args.dst.info.count = single_rank_count; args.dst.info.datatype = dt; args.dst.info.mem_type = mem_type; } UCC_CHECK(set_input()); UCC_CHECK_SKIP(ucc_collective_init(&args, &req, team.team), test_skip); } ucc_status_t TestScatter::set_input(int iter_persistent) { size_t dt_size = ucc_dt_size(dt); size_t single_rank_count = msgsize / dt_size; size_t single_rank_size = single_rank_count * dt_size; int rank, size; MPI_Comm_rank(team.comm, &rank); MPI_Comm_size(team.comm, &size); if (rank == root) { init_buffer(sbuf, single_rank_count * size, dt, mem_type, rank * (iter_persistent + 1)); UCC_CHECK(ucc_mc_memcpy(check_buf, sbuf, single_rank_size * size, UCC_MEMORY_TYPE_HOST, mem_type)); } return UCC_OK; } ucc_status_t TestScatter::check() { size_t single_rank_count = msgsize / ucc_dt_size(dt); size_t single_rank_size = single_rank_count * ucc_dt_size(dt); MPI_Datatype mpi_dt = ucc_dt_to_mpi(dt); MPI_Request req; int size, rank, completed; MPI_Comm_size(team.comm, &size); MPI_Comm_rank(team.comm, &rank); MPI_Iscatter(check_buf, single_rank_count, mpi_dt, (rank == root) ? MPI_IN_PLACE : check_buf, single_rank_count, mpi_dt, root, team.comm, &req); do { MPI_Test(&req, &completed, MPI_STATUS_IGNORE); ucc_context_progress(team.ctx); } while(!completed); if (rank == root) { if (inplace) { return compare_buffers(sbuf, check_buf, single_rank_count * size, dt, mem_type); } else { return compare_buffers( rbuf, PTR_OFFSET(check_buf, single_rank_size * rank), single_rank_count, dt, mem_type); } } else { return compare_buffers(rbuf, check_buf, single_rank_count, dt, mem_type); } } openucx-ucc-ec0bc8a/test/mpi/test_mem_map.cc0000664000175000017500000003632315133731560021436 0ustar alastairalastair/** * Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* * See file LICENSE for terms. */ #include "test_mpi.h" #include "mpi_util.h" #include #include #include "components/mc/ucc_mc.h" class TestMemMap : public TestCase { private: void * test_buffer; size_t buffer_size; ucc_mem_map_mem_h memh; size_t memh_size; ucc_mem_map_mode_t mode; bool is_export_test; public: TestMemMap(ucc_test_team_t &_team, TestCaseParams ¶ms, ucc_mem_map_mode_t _mode = UCC_MEM_MAP_MODE_EXPORT) : TestCase(_team, UCC_COLL_TYPE_BARRIER, params), // Using barrier as placeholder test_buffer(nullptr), buffer_size(0), memh(nullptr), memh_size(0), mode(_mode), is_export_test(_mode == UCC_MEM_MAP_MODE_EXPORT) { buffer_size = params.msgsize; int rank; if (buffer_size == 0) { buffer_size = 1024 * 1024; // Default 1MB } if (skip_reduce(test_max_size < buffer_size, TEST_SKIP_MEM_LIMIT, team.comm)) { return; } UCC_CHECK(ucc_mc_alloc(&rbuf_mc_header, buffer_size, mem_type)); test_buffer = rbuf_mc_header->addr; UCC_MALLOC_CHECK(test_buffer); // Initialize buffer with rank-specific data MPI_Comm_rank(team.comm, &rank); memset(test_buffer, 0xAA + rank, buffer_size); } ~TestMemMap() { if (memh) { ucc_mem_unmap(&memh); } } ucc_status_t set_input(int iter_persistent = 0) override { int rank; MPI_Comm_rank(team.comm, &rank); // Initialize buffer with rank and iteration specific data memset(test_buffer, 0xAA + rank + iter_persistent, buffer_size); return UCC_OK; } ucc_status_t check() override { unsigned char *buf = (unsigned char *)test_buffer; unsigned char expected; int rank; size_t i; MPI_Comm_rank(team.comm, &rank); expected = 0xAA + rank; for (i = 0; i < buffer_size; i++) { if (buf[i] != expected) { return UCC_ERR_INVALID_PARAM; } } return UCC_OK; } void run(bool triggered) override { ucc_mem_map_params_t map_params; ucc_mem_map_t segment; int rank; MPI_Comm_rank(team.comm, &rank); /* Set up memory map parameters */ segment.address = test_buffer; segment.len = buffer_size; map_params.segments = &segment; map_params.n_segments = 1; /* Test memory map */ 
ucc_status_t status = ucc_mem_map(team.ctx, mode, &map_params, &memh_size, &memh); if (status != UCC_OK) { if (status == UCC_ERR_NOT_SUPPORTED || status == UCC_ERR_NOT_IMPLEMENTED) { test_skip = TEST_SKIP_NOT_SUPPORTED; return; } UCC_CHECK(status); } if (!memh) { std::cerr << "Rank " << rank << ": Memory handle is NULL" << std::endl; return; } if (memh_size == 0) { std::cerr << "Rank " << rank << ": Memory handle size is 0" << std::endl; return; } /* Verify data integrity after mapping */ UCC_CHECK(check()); /* Test unmap */ UCC_CHECK(ucc_mem_unmap(&memh)); if (memh != nullptr) { std::cerr << "Rank " << rank << ": Memory handle not NULL after unmap" << std::endl; return; } /* Verify data integrity after unmapping */ UCC_CHECK(check()); } std::string str() override { return std::string("mem_map mode=") + std::to_string(mode) + " team=" + team_str(team.type) + " buffer_size=" + std::to_string(buffer_size) + " mem_type=" + std::to_string(mem_type); } }; class TestMemMapExport : public TestMemMap { public: TestMemMapExport(ucc_test_team_t &_team, TestCaseParams ¶ms) : TestMemMap(_team, params, UCC_MEM_MAP_MODE_EXPORT) { } static std::shared_ptr init_single(ucc_test_team_t &_team, ucc_coll_type_t _type, TestCaseParams params); }; class TestMemMapImport : public TestMemMap { public: TestMemMapImport(ucc_test_team_t &_team, TestCaseParams ¶ms) : TestMemMap(_team, params, UCC_MEM_MAP_MODE_IMPORT) { } static std::shared_ptr init_single(ucc_test_team_t &_team, ucc_coll_type_t _type, TestCaseParams params); }; class TestMemMapStress : public TestCase { private: void *test_buffer; size_t buffer_size; std::vector memhs; int num_iterations = 10; public: TestMemMapStress(ucc_test_team_t &_team, TestCaseParams ¶ms) : TestCase(_team, UCC_COLL_TYPE_BARRIER, params), test_buffer(nullptr), buffer_size(0), num_iterations(10) { buffer_size = params.msgsize; if (buffer_size == 0) { buffer_size = 1024 * 1024; /* Default 1MB */ } if (skip_reduce(test_max_size < buffer_size, 
TEST_SKIP_MEM_LIMIT, team.comm)) { return; } UCC_CHECK(ucc_mc_alloc(&rbuf_mc_header, buffer_size, mem_type)); test_buffer = rbuf_mc_header->addr; UCC_MALLOC_CHECK(test_buffer); memhs.reserve(num_iterations); } ~TestMemMapStress() { for (auto memh : memhs) { if (memh) { ucc_mem_unmap(&memh); } } if (test_buffer) { ucc_free(test_buffer); } } ucc_status_t set_input(int iter_persistent = 0) override { int rank; MPI_Comm_rank(team.comm, &rank); // Initialize buffer with iteration-specific data memset(test_buffer, 0xBB + rank + iter_persistent, buffer_size); return UCC_OK; } ucc_status_t check() override { unsigned char *buf = (unsigned char *)test_buffer; unsigned char expected; size_t i; int rank; MPI_Comm_rank(team.comm, &rank); /* Verify buffer integrity */ expected = 0xBB + rank; for (i = 0; i < buffer_size; i++) { if (buf[i] != expected) { return UCC_ERR_INVALID_PARAM; } } return UCC_OK; } void run(bool triggered) override { ucc_mem_map_params_t map_params; ucc_mem_map_t segment; int rank; int i; MPI_Comm_rank(team.comm, &rank); /* Set up memory map parameters */ segment.address = test_buffer; segment.len = buffer_size; map_params.segments = &segment; map_params.n_segments = 1; /* Stress test: multiple map/unmap operations */ for (i = 0; i < num_iterations; i++) { ucc_mem_map_mem_h memh; size_t memh_size; /* Fill buffer with iteration-specific pattern */ memset(test_buffer, 0xCC + rank + i, buffer_size); ucc_status_t status = ucc_mem_map(team.ctx, UCC_MEM_MAP_MODE_EXPORT, &map_params, &memh_size, &memh); if (status != UCC_OK) { if (status == UCC_ERR_NOT_SUPPORTED || status == UCC_ERR_NOT_IMPLEMENTED) { test_skip = TEST_SKIP_NOT_SUPPORTED; return; } UCC_CHECK(status); } if (!memh) { std::cerr << "Rank " << rank << ": Memory handle is NULL in stress test" << std::endl; return; } if (memh_size == 0) { std::cerr << "Rank " << rank << ": Memory handle size is 0 in stress test" << std::endl; return; } /* Store memh for cleanup */ memhs.push_back(memh); /* Verify data 
integrity */ UCC_CHECK(check()); } /* Cleanup all memory handles */ for (auto &memh : memhs) { UCC_CHECK(ucc_mem_unmap(&memh)); } memhs.clear(); /* Final verification */ UCC_CHECK(check()); } std::string str() override { return std::string("mem_map_stress") + " team=" + team_str(team.type) + " buffer_size=" + std::to_string(buffer_size) + " iterations=" + std::to_string(num_iterations) + " mem_type=" + std::to_string(mem_type); } static std::shared_ptr init_single(ucc_test_team_t &_team, ucc_coll_type_t _type, TestCaseParams params); }; class TestMemMapMultiSize : public TestCase { private: std::vector buffer_sizes; std::vector test_buffers; std::vector memhs; public: TestMemMapMultiSize(ucc_test_team_t &_team, TestCaseParams ¶ms) : TestCase(_team, UCC_COLL_TYPE_BARRIER, params) { size_t i; /* Test different buffer sizes */ buffer_sizes = {1024, 4096, 65536, 1024 * 1024}; if (skip_reduce(test_max_size < buffer_sizes.back(), TEST_SKIP_MEM_LIMIT, team.comm)) { return; } test_buffers.resize(buffer_sizes.size()); memhs.resize(buffer_sizes.size()); /* Allocate buffers */ for (i = 0; i < buffer_sizes.size(); i++) { test_buffers[i] = ucc_malloc(buffer_sizes[i], "test buffer"); UCC_MALLOC_CHECK(test_buffers[i]); /* Initialize with size-specific pattern */ memset(test_buffers[i], 0xDD + i, buffer_sizes[i]); } } ~TestMemMapMultiSize() { for (auto memh : memhs) { if (memh) { ucc_mem_unmap(&memh); } } for (auto buf : test_buffers) { if (buf) { ucc_free(buf); } } } ucc_status_t set_input(int iter_persistent = 0) override { size_t i; int rank; MPI_Comm_rank(team.comm, &rank); /* Initialize all buffers with rank and iteration specific data */ for (i = 0; i < test_buffers.size(); i++) { memset(test_buffers[i], 0xDD + rank + iter_persistent, buffer_sizes[i]); } return UCC_OK; } ucc_status_t check() override { size_t i; size_t j; int rank; MPI_Comm_rank(team.comm, &rank); for (i = 0; i < test_buffers.size(); i++) { unsigned char *buf = (unsigned char *)test_buffers[i]; unsigned char 
expected = 0xDD + rank; for (j = 0; j < buffer_sizes[i]; j++) { if (buf[j] != expected) { return UCC_ERR_INVALID_PARAM; } } } return UCC_OK; } void run(bool triggered) override { ucc_mem_map_params_t map_params; ucc_mem_map_t segment; int rank; size_t i; ucc_mem_map_mem_h memh; size_t memh_size; MPI_Comm_rank(team.comm, &rank); /* Test memory mapping with different buffer sizes */ for (i = 0; i < buffer_sizes.size(); i++) { segment.address = test_buffers[i]; segment.len = buffer_sizes[i]; map_params.segments = &segment; map_params.n_segments = 1; memh = nullptr; memh_size = 0; ucc_status_t status = ucc_mem_map(team.ctx, UCC_MEM_MAP_MODE_EXPORT, &map_params, &memh_size, &memh); if (status != UCC_OK) { if (status == UCC_ERR_NOT_SUPPORTED || status == UCC_ERR_NOT_IMPLEMENTED) { test_skip = TEST_SKIP_NOT_SUPPORTED; return; } UCC_CHECK(status); } if (!memh) { std::cerr << "Rank " << rank << ": Memory handle is NULL in multi-size test" << std::endl; return; } if (memh_size == 0) { std::cerr << "Rank " << rank << ": Memory handle size is 0 in multi-size test" << std::endl; return; } /* Store memh for cleanup */ memhs[i] = memh; /* Verify data integrity */ UCC_CHECK(check()); } /* Cleanup all memory handles */ for (auto &memh : memhs) { UCC_CHECK(ucc_mem_unmap(&memh)); } /* Final verification */ UCC_CHECK(check()); } std::string str() override { return std::string("mem_map_multi_size") + " team=" + team_str(team.type) + " num_sizes=" + std::to_string(buffer_sizes.size()) + " mem_type=" + std::to_string(mem_type); } static std::shared_ptr init_single(ucc_test_team_t &_team, ucc_coll_type_t _type, TestCaseParams params); }; // Factory functions for creating test instances std::shared_ptr TestMemMapExport::init_single(ucc_test_team_t &_team, ucc_coll_type_t _type, TestCaseParams params) { return std::make_shared(_team, params); } std::shared_ptr TestMemMapImport::init_single(ucc_test_team_t &_team, ucc_coll_type_t _type, TestCaseParams params) { return std::make_shared(_team, 
/*
 * Abort the whole MPI job when @_call does not evaluate to UCC_OK.
 * @_call is evaluated exactly once. The do { } while (0) wrapper makes the
 * macro a single statement, so "if (x) UCC_CHECK(y); else ..." binds the
 * else to the caller's if. The invocation must end with a semicolon.
 */
#define UCC_CHECK(_call)                                                       \
    do {                                                                       \
        if (UCC_OK != (_call)) {                                               \
            std::cerr << "*** UCC TEST FAIL: " << STR(_call) << "\n";          \
            MPI_Abort(MPI_COMM_WORLD, -1);                                     \
        }                                                                      \
    } while (0)

/* Abort the whole MPI job when @_obj is NULL/false (failed allocation). */
#define UCC_MALLOC_CHECK(_obj)                                                 \
    do {                                                                       \
        if (!(_obj)) {                                                         \
            std::cerr << "*** UCC MALLOC FAIL \n";                             \
            MPI_Abort(MPI_COMM_WORLD, -1);                                     \
        }                                                                      \
    } while (0)

/*
 * Evaluate @_call once. UCC_ERR_NOT_SUPPORTED and UCC_ERR_NOT_IMPLEMENTED
 * are recorded in the lvalue @_skip_cause; any other non-OK status aborts
 * the MPI job. On UCC_OK @_skip_cause is left untouched.
 * The temporary has a macro-specific name so that it cannot shadow a
 * variable named "status" used inside @_call.
 */
#define UCC_CHECK_SKIP(_call, _skip_cause)                                     \
    do {                                                                       \
        ucc_status_t ucc_check_skip_status_ = (_call);                         \
        if (UCC_ERR_NOT_SUPPORTED == ucc_check_skip_status_) {                 \
            (_skip_cause) = TEST_SKIP_NOT_SUPPORTED;                           \
        } else if (UCC_ERR_NOT_IMPLEMENTED == ucc_check_skip_status_) {        \
            (_skip_cause) = TEST_SKIP_NOT_IMPLEMENTED;                         \
        } else if (UCC_OK != ucc_check_skip_status_) {                         \
            std::cerr << "*** UCC TEST FAIL: " << STR(_call) << "\n";          \
            MPI_Abort(MPI_COMM_WORLD, -1);                                     \
        }                                                                      \
    } while (0)
(cuda_err)) { \ std::cerr << "*** UCC TEST FAIL: " << STR(_call) << ": " \ << cudaGetErrorString(cuda_err) << "\n" ; \ MPI_Abort(MPI_COMM_WORLD, -1); \ } \ } #endif #ifdef HAVE_HIP #define HIP_CHECK(_call) { \ hipError_t hip_err = (_call); \ if (hipSuccess != (hip_err)) { \ std::cerr << "*** UCC TEST FAIL: " << STR(_call) << ": " \ << hipGetErrorString(hip_err) << "\n" ; \ MPI_Abort(MPI_COMM_WORLD, -1); \ } \ } #endif #define UCC_TEST_N_PERSISTENT 4 #define UCC_TEST_N_MEM_SEGMENTS 3 #define UCC_TEST_MEM_SEGMENT_SIZE (1 << 21) typedef enum { MEM_SEND_SEGMENT, MEM_RECV_SEGMENT, MEM_WORK_SEGMENT } ucc_test_mem_segments; typedef enum { TEAM_WORLD, TEAM_REVERSE, TEAM_SPLIT_HALF, TEAM_SPLIT_ODD_EVEN, TEAM_LAST } ucc_test_mpi_team_t; typedef enum { ROOT_SINGLE, ROOT_RANDOM, ROOT_ALL } ucc_test_mpi_root_t; typedef enum { TEST_FLAG_VSIZE_32BIT, TEST_FLAG_VSIZE_64BIT } ucc_test_vsize_flag_t; typedef enum { TEST_SKIP_NONE, TEST_SKIP_NOT_SUPPORTED, TEST_SKIP_NOT_IMPLEMENTED, TEST_SKIP_MEM_LIMIT, TEST_SKIP_LAST } test_skip_cause_t; #if defined(HAVE_CUDA) || defined(HAVE_HIP) typedef enum { TEST_SET_DEV_NONE, TEST_SET_DEV_LRANK, TEST_SET_DEV_LRANK_ROUND } test_set_gpu_device_t; #endif static inline const char* skip_str(test_skip_cause_t s) { switch(s) { case TEST_SKIP_MEM_LIMIT: return "maximum buffer size reached"; case TEST_SKIP_NOT_SUPPORTED: return "not supported"; case TEST_SKIP_NOT_IMPLEMENTED: return "not implemented"; default: return "unknown"; } } static inline const char* team_str(ucc_test_mpi_team_t t) { switch(t) { case TEAM_WORLD: return "world"; case TEAM_REVERSE: return "reverse"; case TEAM_SPLIT_HALF: return "half"; case TEAM_SPLIT_ODD_EVEN: return "odd_even"; default: break; } return NULL; } typedef struct ucc_test_mpi_data { int local_node_rank; } ucc_test_mpi_data_t; extern ucc_test_mpi_data_t ucc_test_mpi_data; int init_test_mpi_data(void); #if defined(HAVE_CUDA) || defined(HAVE_HIP) void set_gpu_device(test_set_gpu_device_t set_device); #endif int 
ucc_coll_inplace_supported(ucc_coll_type_t c); int ucc_coll_is_rooted(ucc_coll_type_t c); bool ucc_coll_has_datatype(ucc_coll_type_t c); typedef struct ucc_test_team { #ifdef HAVE_CUDA ucc_ee_h cuda_ee; cudaStream_t cuda_stream; #endif ucc_test_mpi_team_t type; MPI_Comm comm; ucc_team_h team; ucc_context_h ctx; ucc_test_team(ucc_test_mpi_team_t _type, MPI_Comm _comm, ucc_team_h _team, ucc_context_h _ctx) : type(_type), comm(_comm), team(_team), ctx(_ctx) { #ifdef HAVE_CUDA cuda_stream = nullptr; cuda_ee = nullptr; #endif }; #ifdef HAVE_CUDA ucc_status_t get_cuda_ee(ucc_ee_h *ee) { ucc_ee_params_t ee_params; if (!cuda_ee) { CUDA_CHECK(cudaStreamCreateWithFlags(&cuda_stream, cudaStreamNonBlocking)); ee_params.ee_type = UCC_EE_CUDA_STREAM; ee_params.ee_context_size = sizeof(cudaStream_t); ee_params.ee_context = cuda_stream; UCC_CHECK(ucc_ee_create(team, &ee_params, &cuda_ee)); } *ee = cuda_ee; return UCC_OK; } void free_cuda_ee() { if (cuda_ee) { UCC_CHECK(ucc_ee_destroy(cuda_ee)); CUDA_CHECK(cudaStreamDestroy(cuda_stream)); } } #endif ucc_status_t get_ee(ucc_ee_type_t ee_type, ucc_ee_h *ee) { switch (ee_type) { #ifdef HAVE_CUDA case UCC_EE_CUDA_STREAM: return get_cuda_ee(ee); #endif default: return UCC_ERR_NOT_SUPPORTED; } } void free_ee() { #ifdef HAVE_CUDA free_cuda_ee(); #endif } } ucc_test_team_t; struct TestCaseParams { size_t msgsize; bool inplace; bool persistent; bool local_registration; ucc_datatype_t dt; ucc_reduction_op_t op; ucc_memory_type_t mt; size_t max_size; int root; void **buffers; ucc_test_vsize_flag_t count_bits; ucc_test_vsize_flag_t displ_bits; }; class TestCase { protected: ucc_test_team_t team; ucc_memory_type_t mem_type; int root; size_t msgsize; bool inplace; bool persistent; bool local_registration; ucc_coll_req_h req; ucc_mem_map_mem_h src_memh; size_t src_memh_size; ucc_mem_map_mem_h dst_memh; size_t dst_memh_size; ucc_mc_buffer_header_t *sbuf_mc_header, *rbuf_mc_header; void *sbuf; void *rbuf; void *check_buf; MPI_Request 
progress_request; uint8_t progress_buf[1]; size_t test_max_size; ucc_datatype_t dt; int iter_persistent; public: ucc_coll_args_t args; void mpi_progress(void); test_skip_cause_t test_skip; static std::shared_ptr init_single( ucc_test_team_t &_team, ucc_coll_type_t _type, TestCaseParams params); static std::vector> init( ucc_test_team_t &_team, ucc_coll_type_t _type, int num_tests, TestCaseParams params); TestCase(ucc_test_team_t &_team, ucc_coll_type_t ct, TestCaseParams params); virtual ~TestCase(); virtual void run(bool triggered); virtual ucc_status_t set_input(int iter_persistent = 0) = 0; virtual ucc_status_t check() = 0; virtual std::string str(); virtual ucc_status_t test(); void wait(); void tc_progress_ctx(); test_skip_cause_t skip_reduce(test_skip_cause_t cause, MPI_Comm comm); test_skip_cause_t skip_reduce(int skip_cond, test_skip_cause_t cause, MPI_Comm comm); }; typedef std::tuple ucc_test_mpi_result_t; class UccTestMpi { ucc_thread_mode_t tm; ucc_context_h ctx; ucc_context_h onesided_ctx; ucc_lib_h lib; int nt; ucc_lib_h onesided_lib; bool inplace; bool persistent; ucc_test_mpi_root_t root_type; int root_value; int iterations; bool verbose; void * onesided_buffers[3]; size_t test_max_size; bool triggered; bool local_registration; void create_team(ucc_test_mpi_team_t t, bool is_onesided = false); void destroy_team(ucc_test_team_t &team); ucc_team_h create_ucc_team(MPI_Comm comm, bool is_onesided = false); std::vector msgsizes; std::vector mtypes; std::vector dtypes; std::vector ops; std::vector colls; std::vector gen_roots(ucc_test_team_t &team); std::vector counts_vsize; std::vector displs_vsize; std::vector exec_tests( std::vector> tcs, bool triggered, bool persistent); public: std::vector teams; std::vector onesided_teams; void run_all_at_team(ucc_test_team_t &team, std::vector &rst); std::vector results; UccTestMpi(int argc, char *argv[], ucc_thread_mode_t tm, int is_local, bool with_onesided); ~UccTestMpi(); void set_msgsizes(size_t min, size_t 
max, size_t power); void set_dtypes(std::vector &_dtypes); void set_colls(std::vector &_colls); void set_iter(int iter); void set_verbose(bool verbose); void set_ops(std::vector &_ops); void set_mtypes(std::vector &_mtypes); void set_inplace(bool _inplace) { inplace = _inplace; } void set_persistent(bool _persistent) { persistent = _persistent; } void set_triggered(bool _triggered) { triggered = _triggered; } void set_local_registration(bool _local_registration) { local_registration = _local_registration; } void set_count_vsizes(std::vector &_counts_vsize); void set_displ_vsizes(std::vector &_displs_vsize); void run_all(bool is_onesided = false); void set_root(ucc_test_mpi_root_t _root_type, int _root_value) { root_type = _root_type; root_value = _root_value; }; void set_max_size(size_t _max_size) { test_max_size = _max_size; } void set_num_tests(int num_tests) { nt = num_tests; } void create_teams(std::vector &test_teams, bool is_onesided = false); void progress_ctx() { ucc_context_progress(ctx); if (onesided_ctx) { ucc_context_progress(onesided_ctx); } } }; class TestAllgather : public TestCase { public: TestAllgather(ucc_test_team_t &team, TestCaseParams ¶ms); ucc_status_t set_input(int iter_persistent = 0) override; ucc_status_t check(); }; class TestAllgatherv : public TestCase { int *counts; int *displacements; public: TestAllgatherv(ucc_test_team_t &team, TestCaseParams ¶ms); ~TestAllgatherv(); ucc_status_t set_input(int iter_persistent = 0) override; ucc_status_t check() override; }; class TestAllreduce : public TestCase { ucc_reduction_op_t op; public: TestAllreduce(ucc_test_team_t &team, TestCaseParams ¶ms); ucc_status_t set_input(int iter_persistent = 0) override; ucc_status_t check(); std::string str(); }; class TestAlltoall : public TestCase { bool is_onesided; public: TestAlltoall(ucc_test_team_t &team, TestCaseParams ¶ms); ucc_status_t set_input(int iter_persistent = 0) override; ucc_status_t check(); }; class TestAlltoallv : public TestCase { size_t 
sncounts; size_t rncounts; int *scounts; int *sdispls; int *rcounts; int *rdispls; ucc_count_t *scounts64; ucc_count_t *sdispls64; ucc_count_t *rcounts64; ucc_count_t *rdispls64; ucc_test_vsize_flag_t count_bits; ucc_test_vsize_flag_t displ_bits; template void * mpi_counts_to_ucc(int *mpi_counts, size_t _ncount); public: TestAlltoallv(ucc_test_team_t &team, TestCaseParams ¶ms); ucc_status_t set_input(int iter_persistent = 0) override; ucc_status_t check(); std::string str(); ~TestAlltoallv(); }; class TestBarrier : public TestCase { ucc_status_t status; public: TestBarrier(ucc_test_team_t &team, TestCaseParams ¶ms); ucc_status_t set_input(int iter_persistent = 0) override; ucc_status_t check(); std::string str(); void run(bool triggered); ucc_status_t test(); }; class TestBcast : public TestCase { public: TestBcast(ucc_test_team_t &team, TestCaseParams ¶ms); ucc_status_t set_input(int iter_persistent = 0) override; ucc_status_t check(); }; class TestGather : public TestCase { public: TestGather(ucc_test_team_t &team, TestCaseParams ¶ms); ucc_status_t set_input(int iter_persistent = 0) override; ucc_status_t check(); }; class TestGatherv : public TestCase { uint32_t *counts; uint32_t *displacements; public: TestGatherv(ucc_test_team_t &team, TestCaseParams ¶ms); ucc_status_t set_input(int iter_persistent = 0) override; ucc_status_t check(); ~TestGatherv(); }; class TestReduce : public TestCase { ucc_reduction_op_t op; public: TestReduce(ucc_test_team_t &team, TestCaseParams ¶ms); ucc_status_t set_input(int iter_persistent = 0) override; ucc_status_t check(); std::string str(); }; class TestReduceScatter : public TestCase { ucc_reduction_op_t op; public: TestReduceScatter(ucc_test_team_t &team, TestCaseParams ¶ms); ucc_status_t set_input(int iter_persistent = 0) override; ~TestReduceScatter(); ucc_status_t check(); std::string str(); }; class TestReduceScatterv : public TestCase { ucc_reduction_op_t op; int * counts; public: TestReduceScatterv(ucc_test_team_t &team, 
TestCaseParams ¶ms); ucc_status_t set_input(int iter_persistent = 0) override; ~TestReduceScatterv(); ucc_status_t check(); std::string str(); }; class TestScatter : public TestCase { public: TestScatter(ucc_test_team_t &team, TestCaseParams ¶ms); ucc_status_t set_input(int iter_persistent = 0) override; ucc_status_t check(); }; class TestScatterv : public TestCase { uint32_t *counts; uint32_t *displacements; public: TestScatterv(ucc_test_team_t &team, TestCaseParams ¶ms); ucc_status_t set_input(int iter_persistent = 0) override; ucc_status_t check(); ~TestScatterv(); }; void init_buffer(void *buf, size_t count, ucc_datatype_t dt, ucc_memory_type_t mt, int value, int offset = 0); ucc_status_t compare_buffers(void *rst, void *expected, size_t count, ucc_datatype_t dt, ucc_memory_type_t mt); ucc_status_t divide_buffer(void *expected, size_t divider, size_t count, ucc_datatype_t dt); #endif openucx-ucc-ec0bc8a/test/mpi/main.cc0000664000175000017500000006332015133731560017705 0ustar alastairalastair/** * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) Advanced Micro Devices, Inc. 2023. ALL RIGHTS RESERVED. * * See file LICENSE for terms. 
/**
 * Split @value at every occurrence of @delimiter.
 *
 * @param value      NUL-terminated string to split.
 * @param delimiter  NUL-terminated separator.
 * @return           The pieces in order. Empty pieces are kept, so the
 *                   result always has (number of separators + 1) entries.
 *                   An empty @delimiter yields @value as the only piece.
 */
static std::vector<std::string> str_split(const char *value,
                                          const char *delimiter)
{
    std::vector<std::string> rst;
    const std::string        str(value);
    const std::string        delim(delimiter);
    size_t                   start = 0;
    size_t                   pos;

    /* find("") matches at every position and would never advance. */
    if (delim.empty()) {
        rst.push_back(str);
        return rst;
    }

    /* Scan forward with a moving offset instead of erasing from the front. */
    while ((pos = str.find(delim, start)) != std::string::npos) {
        rst.push_back(str.substr(start, pos - start));
        start = pos + delim.length();
    }
    rst.push_back(str.substr(start));
    return rst;
}
rst; } static ucc_test_mpi_team_t team_str_to_type(std::string team) { if (team == "world") { return TEAM_WORLD; } else if (team == "half") { return TEAM_SPLIT_HALF; } else if (team == "odd_even") { return TEAM_SPLIT_ODD_EVEN; } else if (team == "reverse") { return TEAM_REVERSE; } throw std::string("incorrect team type: ") + team; } static std::string team_type_to_str(ucc_test_mpi_team_t team) { switch (team) { case TEAM_WORLD: return "world"; case TEAM_SPLIT_HALF: return "half"; case TEAM_SPLIT_ODD_EVEN: return "odd_even"; case TEAM_REVERSE: return "reverse"; default: break; } throw std::string("incorrect team type: "); } static ucc_coll_type_t coll_str_to_type(std::string coll) { if (coll == "barrier") { return UCC_COLL_TYPE_BARRIER; } else if (coll == "allreduce") { return UCC_COLL_TYPE_ALLREDUCE; } else if (coll == "allgather") { return UCC_COLL_TYPE_ALLGATHER; } else if (coll == "allgatherv") { return UCC_COLL_TYPE_ALLGATHERV; } else if (coll == "bcast") { return UCC_COLL_TYPE_BCAST; } else if (coll == "reduce") { return UCC_COLL_TYPE_REDUCE; } else if (coll == "alltoall") { return UCC_COLL_TYPE_ALLTOALL; } else if (coll == "alltoallv") { return UCC_COLL_TYPE_ALLTOALLV; } else if (coll == "reduce_scatter") { return UCC_COLL_TYPE_REDUCE_SCATTER; } else if (coll == "reduce_scatterv") { return UCC_COLL_TYPE_REDUCE_SCATTERV; } else if (coll == "reduce") { return UCC_COLL_TYPE_REDUCE; } else if (coll == "gather") { return UCC_COLL_TYPE_GATHER; } else if (coll == "gatherv") { return UCC_COLL_TYPE_GATHERV; } else if (coll == "scatter") { return UCC_COLL_TYPE_SCATTER; } else if (coll == "scatterv") { return UCC_COLL_TYPE_SCATTERV; } else { throw std::string("incorrect coll type: ") + coll; } } static ucc_memory_type_t mtype_str_to_type(std::string mtype) { if (mtype == "host") { return UCC_MEMORY_TYPE_HOST; } else if (mtype == "cuda") { return UCC_MEMORY_TYPE_CUDA; } else if (mtype == "cudaManaged") { return UCC_MEMORY_TYPE_CUDA_MANAGED; } else if (mtype == "rocm") { 
return UCC_MEMORY_TYPE_ROCM; } throw std::string("incorrect memory type: ") + mtype; } static ucc_datatype_t dtype_str_to_type(std::string dtype) { if (dtype == "int8") { return UCC_DT_INT8; } else if (dtype == "uint8") { return UCC_DT_UINT8; } else if (dtype == "int16") { return UCC_DT_INT16; } else if (dtype == "uint16") { return UCC_DT_UINT16; } else if (dtype == "int32") { return UCC_DT_INT32; } else if (dtype == "uint32") { return UCC_DT_UINT32; } else if (dtype == "int64") { return UCC_DT_INT64; } else if (dtype == "uint64") { return UCC_DT_UINT64; } else if (dtype == "float32") { return UCC_DT_FLOAT32; } else if (dtype == "float64") { return UCC_DT_FLOAT64; } else if (dtype == "float128") { return UCC_DT_FLOAT128; } else if (dtype == "bfloat16") { return UCC_DT_BFLOAT16; } else if (dtype == "float16") { return UCC_DT_FLOAT16; } else if (dtype == "int128") { return UCC_DT_INT128; } else if (dtype == "uint128") { return UCC_DT_UINT128; } else if (dtype == "float32_complex") { return UCC_DT_FLOAT32_COMPLEX; } else if (dtype == "float64_complex") { return UCC_DT_FLOAT64_COMPLEX; } else if (dtype == "float128_complex") { return UCC_DT_FLOAT128_COMPLEX; } throw std::string("incorrect dtype: ") + dtype; } static ucc_reduction_op_t op_str_to_type(std::string op) { if (op == "sum") { return UCC_OP_SUM; } else if (op == "prod") { return UCC_OP_PROD; } else if (op == "max") { return UCC_OP_MAX; } else if (op == "min") { return UCC_OP_MIN; } else if (op == "land") { return UCC_OP_LAND; } else if (op == "lor") { return UCC_OP_LOR; } else if (op == "lxor") { return UCC_OP_LXOR; } else if (op == "band") { return UCC_OP_BAND; } else if (op == "bor") { return UCC_OP_BOR; } else if (op == "bxor") { return UCC_OP_BXOR; } else if (op == "avg") { return UCC_OP_AVG; } throw std::string("incorrect op: ") + op; } static ucc_test_vsize_flag_t bits_str_to_type(std::string vsize) { if (vsize == "32") { return TEST_FLAG_VSIZE_32BIT; } else if (vsize == "64") { return 
TEST_FLAG_VSIZE_64BIT; } throw std::string("incorrect vsize") + vsize; } static void process_msgrange(const char *arg) { auto tokens = str_split(arg, ":"); try { if (tokens.size() == 1) { msgrange[0] = std::stol(tokens[0]); msgrange[1] = msgrange[0]; msgrange[2] = 0; } else if (tokens.size() >= 2) { msgrange[0] = std::stol(tokens[0]); msgrange[1] = std::stol(tokens[1]); msgrange[2] = 2; if (tokens.size() == 3) { msgrange[2] = std::stol(tokens[2]); } } } catch (std::exception &e) { throw std::string("incorrect msgrange: ") + arg; } } static void process_inplace(const char *arg) { int value = std::stoi(arg); switch(value) { case 0: inplace = {false}; return; case 1: inplace = {true}; return; case 2: inplace = {false, true}; return; default: break; } throw std::string("incorrect inplace: ") + arg; } static void process_persistent(const char *arg) { int value = std::stoi(arg); switch(value) { case 0: persistent = {false}; return; case 1: persistent = {true}; return; case 2: persistent = {false, true}; return; default: break; } throw std::string("incorrect persistent: ") + arg; } static void process_triggered(const char *arg) { int value = std::stoi(arg); switch(value) { case 0: triggered = {false}; return; case 1: triggered = {true}; return; case 2: triggered = {false, true}; return; default: break; } throw std::string("incorrect triggered: ") + arg; } static void process_local_reg(const char *arg) { int value = std::stoi(arg); switch(value) { case 0: local_reg = {false}; return; case 1: local_reg = {true}; return; case 2: local_reg = {false, true}; return; default: break; } throw std::string("incorrect local_reg: ") + arg; } static void process_root(const char *arg) { auto tokens = str_split(arg, ":"); std::string root_type_str = tokens[0]; if (root_type_str == "all") { if (tokens.size() != 1) { goto err; } root_type = ROOT_ALL; } else if (root_type_str == "random") { if (tokens.size() != 2) { goto err; } root_type = ROOT_RANDOM; root_value = 
std::atoi(tokens[1].c_str()); } else if (root_type_str == "single") { if (tokens.size() != 2) { goto err; } root_type = ROOT_SINGLE; root_value = std::atoi(tokens[1].c_str()); } else { goto err; } return; err: throw std::string("incorrect root: ") + arg; } int init_rand_seed(int user_seed) { int rank, seed; MPI_Comm_rank(MPI_COMM_WORLD, &rank); if (0 > user_seed) { if (0 == rank) { seed = time(NULL) % 32768; } } else { seed = user_seed; } MPI_Bcast(&seed, 1, MPI_INT, 0, MPI_COMM_WORLD); if (0 != rank) { seed += rank; } return seed; } void print_info() { int world_rank; MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); if (world_rank) { return; } std::cout << "===== UCC MPI TEST INFO =======" << std::endl; std::cout <<"seed: " << std::to_string(test_rand_seed) << std::endl; std::cout <<"collectives: "; for (const auto &c : colls) { std::cout << ucc_coll_type_str(c); if (c != colls.back()) { std::cout << ", "; } else { std::cout << std::endl; } } std::cout <<"data types: "; for (const auto &d : dtypes) { std::cout << ucc_datatype_str(d); if (d != dtypes.back()) { std::cout << ", "; } else { std::cout << std::endl; } } std::cout <<"memory types: "; for (const auto &m : mtypes) { std::cout << ucc_mem_type_str(m); if (m != mtypes.back()) { std::cout << ", "; } else { std::cout << std::endl; } } std::cout <<"teams: "; for (const auto &t : teams) { std::cout << team_type_to_str(t); if (t != teams.back()) { std::cout << ", "; } else { std::cout << std::endl; } } } void ProcessArgs(int argc, char** argv) { const char *const short_opts = "c:t:m:d:o:M:I:P:N:r:s:C:D:i:Z:G:ThvS:O:L:"; const option long_opts[] = { {"colls", required_argument, nullptr, 'c'}, {"teams", required_argument, nullptr, 't'}, {"mtypes", required_argument, nullptr, 'M'}, {"dtypes", required_argument, nullptr, 'd'}, {"ops", required_argument, nullptr, 'o'}, {"msgsize", required_argument, nullptr, 'm'}, {"inplace", required_argument, nullptr, 'I'}, {"persistent", required_argument, nullptr, 'P'}, {"root", 
required_argument, nullptr, 'r'}, {"seed", required_argument, nullptr, 's'}, {"max_size", required_argument, nullptr, 'Z'}, {"count_bits", required_argument, nullptr, 'C'}, {"displ_bits", required_argument, nullptr, 'D'}, {"iter", required_argument, nullptr, 'i'}, {"thread-multiple", no_argument, nullptr, 'T'}, {"num_tests", required_argument, nullptr, 'N'}, {"triggered", required_argument, nullptr, 'G'}, {"verbose", no_argument, nullptr, 'v'}, #if defined(HAVE_CUDA) || defined(HAVE_HIP) {"set_device", required_argument, nullptr, 'S'}, #endif {"onesided", required_argument, nullptr, 'O'}, {"local_registration", required_argument, nullptr, 'L'}, {"help", no_argument, nullptr, 'h'}, {nullptr, no_argument, nullptr, 0} }; while (true) { const auto opt = getopt_long(argc, argv, short_opts, long_opts, nullptr); if (-1 == opt) break; switch (opt) { case 'c': colls = process_arg(optarg, coll_str_to_type); break; case 't': teams = process_arg(optarg, team_str_to_type); break; case 'M': mtypes = process_arg(optarg, mtype_str_to_type); break; case 'd': dtypes = process_arg(optarg, dtype_str_to_type); break; case 'o': ops = process_arg(optarg, op_str_to_type); break; case 'm': process_msgrange(optarg); break; case 'I': process_inplace(optarg); break; case 'P': process_persistent(optarg); break; case 'G': process_triggered(optarg); break; case 'L': process_local_reg(optarg); break; case 'r': process_root(optarg); break; case 's': test_rand_seed = std::stoi(optarg); break; case 'Z': test_max_size = std::stoi(optarg); break; case 'C': counts_vsize = process_arg(optarg, bits_str_to_type); break; case 'D': displs_vsize = process_arg(optarg, bits_str_to_type); break; case 'T': thread_mode = UCC_THREAD_MULTIPLE; break; case 'i': iterations = std::stoi(optarg); break; case 'N': num_tests = std::stoi(optarg); break; #if defined(HAVE_CUDA) || defined(HAVE_HIP) case 'S': test_gpu_set_device = (test_set_gpu_device_t)std::stoi(optarg); break; #endif case 'O': has_onesided = 
std::stoi(optarg); break; case 'v': verbose = true; break; case 'h': show_help = 1; break; case '?': // Unrecognized option default: throw std::string("unrecognized option"); } } } int main(int argc, char *argv[]) { int failed = 0; int total_done_skipped_failed[ucc_ilog2(UCC_COLL_TYPE_LAST) + 1][4]; std::chrono::steady_clock::time_point begin; int size, required, provided, completed, rank; UccTestMpi *test; MPI_Request req; std::string err; begin = std::chrono::steady_clock::now(); memset(total_done_skipped_failed, 0, sizeof(total_done_skipped_failed)); try { ProcessArgs(argc, argv); } catch (const std::string &s) { failed = 1; err = s; } required = (thread_mode == UCC_THREAD_SINGLE) ? MPI_THREAD_SINGLE : MPI_THREAD_MULTIPLE; MPI_Init_thread(&argc, &argv, required, &provided); if (provided != required) { std::cerr << "could not initialize MPI in thread multiple\n"; return 1; } MPI_Comm_size(MPI_COMM_WORLD, &size); MPI_Comm_rank(MPI_COMM_WORLD, &rank); if (!err.empty() || show_help) { if (rank == 0) { std::cerr << "ParseArgs error:" << err << "\n\n"; print_help(); } goto mpi_exit; } if (size < 2) { std::cerr << "test requires at least 2 ranks\n"; goto mpi_exit; } init_test_mpi_data(); #if defined(HAVE_CUDA) || defined(HAVE_HIP) set_gpu_device(test_gpu_set_device); #endif test = new UccTestMpi(argc, argv, thread_mode, 0, has_onesided); for (auto &m : mtypes) { if (UCC_MEMORY_TYPE_HOST != m && UCC_OK != ucc_mc_available(m)) { std::cerr << "requested memory type " << ucc_memory_type_names[m] << " is not supported " << std::endl; failed = -1; goto test_exit; } } test->create_teams(teams); if (has_onesided) { test->create_teams(teams, true); } test->set_verbose(verbose); test->set_iter(iterations); test->set_num_tests(num_tests); test->set_colls(colls); test->set_dtypes(dtypes); test->set_mtypes(mtypes); test->set_ops(ops); test->set_root(root_type, root_value); test->set_count_vsizes(counts_vsize); test->set_displ_vsizes(displs_vsize); 
test->set_msgsizes(msgrange[0],msgrange[1],msgrange[2]); test->set_max_size(test_max_size); test_rand_seed = init_rand_seed(test_rand_seed); print_info(); for (auto inpl : inplace) { for (auto pers : persistent) { for (auto trig: triggered) { for (auto lr : local_reg) { test->set_triggered(trig); test->set_inplace(inpl); test->set_persistent(pers); test->set_local_registration(lr); test->run_all(); } } } } if (has_onesided) { std::vector os_colls(onesided_colls.size()); std::vector::iterator it_start; std::sort(colls.begin(), colls.end()); std::sort(onesided_colls.begin(), onesided_colls.end()); it_start = std::set_intersection( colls.begin(), colls.end(), onesided_colls.begin(), onesided_colls.end(), os_colls.begin()); os_colls.resize(it_start - os_colls.begin()); test->set_colls(os_colls); for (auto inpl : inplace) { for (auto pers : persistent) { test->set_triggered(false); test->set_inplace(inpl); test->set_persistent(pers); test->run_all(true); } } } std::cout << std::flush; for (auto s : test->results) { int coll_num = ucc_ilog2(std::get<0>(s)); switch(std::get<1>(s)) { case UCC_OK: total_done_skipped_failed[coll_num][1]++; break; case UCC_ERR_NOT_IMPLEMENTED: case UCC_ERR_LAST: total_done_skipped_failed[coll_num][2]++; break; default: total_done_skipped_failed[coll_num][3]++; } total_done_skipped_failed[coll_num][0]++; } MPI_Iallreduce(MPI_IN_PLACE, total_done_skipped_failed, sizeof(total_done_skipped_failed)/sizeof(int), MPI_INT, MPI_MAX, MPI_COMM_WORLD, &req); do { MPI_Test(&req, &completed, MPI_STATUS_IGNORE); test->progress_ctx(); } while(!completed); if (0 == rank) { std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); ucc_coll_type_t coll_type; int num_all = 0, num_skipped = 0, num_done =0, num_failed = 0; std::ios iostate(nullptr); iostate.copyfmt(std::cout); std::cout << "\n===== UCC MPI TEST REPORT =====\n" << std::setw(22) << std::left << "collective" << std::setw(10) << std::right << "tests" << std::setw(10) << std::right 
<< "passed" << std::setw(10) << std::right << "failed" << std::setw(10) << std::right << "skipped" << std::endl; for (coll_type = (ucc_coll_type_t)1; coll_type < UCC_COLL_TYPE_LAST; coll_type = (ucc_coll_type_t)(coll_type << 1)) { int coll_num = ucc_ilog2(coll_type); if (total_done_skipped_failed[coll_num][0] == 0) { continue; } num_all += total_done_skipped_failed[coll_num][0]; num_done += total_done_skipped_failed[coll_num][1]; num_skipped += total_done_skipped_failed[coll_num][2]; num_failed += total_done_skipped_failed[coll_num][3]; std::cout << std::setw(22) << std::left << ucc_coll_type_str(coll_type) << std::setw(10) << std::right << total_done_skipped_failed[coll_num][0] << std::setw(10) << std::right << total_done_skipped_failed[coll_num][1] << std::setw(10) << std::right << total_done_skipped_failed[coll_num][3] << std::setw(10) << std::right << total_done_skipped_failed[coll_num][2] << std::endl; } std::cout << " \n===== UCC MPI TEST SUMMARY =====\n" << "total tests: " << num_all << "\n" << "passed: " << num_done << "\n" << "skipped: " << num_skipped << "\n" << "failed: " << num_failed << "\n" << "elapsed: " << std::chrono::duration_cast(end - begin).count() << "s" << std::endl; std::cout.copyfmt(iostate); /* check if all tests have been skipped */ if (num_all == num_skipped) { std::cout << "\n All tests have been skipped, indicating most likely " "a problem\n"; failed = 1; } if (num_failed != 0) { failed = 1; } } test_exit: delete test; mpi_exit: MPI_Finalize(); return failed; } int init_test_mpi_data(void) { MPI_Comm local_comm; MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &local_comm); MPI_Comm_rank(local_comm, &ucc_test_mpi_data.local_node_rank); MPI_Comm_free(&local_comm); return 0; } openucx-ucc-ec0bc8a/test/mpi/test_allgather.cc0000664000175000017500000000751215133731560021764 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include "core/ucc_dt.h" #include "test_mpi.h" #include "mpi_util.h" #include "ucc/api/ucc.h" TestAllgather::TestAllgather(ucc_test_team_t &_team, TestCaseParams ¶ms) : TestCase(_team, UCC_COLL_TYPE_ALLGATHER, params) { int rank, size; size_t dt_size, single_rank_count; dt = params.dt; dt_size = ucc_dt_size(dt); single_rank_count = msgsize / dt_size; MPI_Comm_rank(team.comm, &rank); MPI_Comm_size(team.comm, &size); if (TEST_SKIP_NONE != skip_reduce(test_max_size < (msgsize*size), TEST_SKIP_MEM_LIMIT, team.comm)) { return; } UCC_CHECK(ucc_mc_alloc(&rbuf_mc_header, msgsize * size, mem_type)); rbuf = rbuf_mc_header->addr; check_buf = ucc_malloc(msgsize * size, "check buf"); UCC_MALLOC_CHECK(check_buf); if (!inplace) { UCC_CHECK(ucc_mc_alloc(&sbuf_mc_header, msgsize, mem_type)); sbuf = sbuf_mc_header->addr; args.src.info.buffer = sbuf; args.src.info.count = single_rank_count; args.src.info.datatype = dt; args.src.info.mem_type = mem_type; } args.dst.info.buffer = rbuf; args.dst.info.count = single_rank_count * size; args.dst.info.datatype = dt; args.dst.info.mem_type = mem_type; if (local_registration) { ucc_mem_map_t segments[1]; ucc_mem_map_params_t mem_map_params; mem_map_params.n_segments = 1; mem_map_params.segments = segments; if (!inplace) { mem_map_params.segments[0].address = args.src.info.buffer; mem_map_params.segments[0].len = args.src.info.count * ucc_dt_size(args.src.info.datatype); UCC_CHECK(ucc_mem_map(team.ctx, UCC_MEM_MAP_MODE_EXPORT, &mem_map_params, &src_memh_size, &src_memh)); args.src_memh.local_memh = src_memh; args.mask |= UCC_COLL_ARGS_FIELD_MEM_MAP_SRC_MEMH; } mem_map_params.segments[0].address = args.dst.info.buffer; mem_map_params.segments[0].len = args.dst.info.count * ucc_dt_size(args.dst.info.datatype); UCC_CHECK(ucc_mem_map(team.ctx, UCC_MEM_MAP_MODE_EXPORT, &mem_map_params, &dst_memh_size, &dst_memh)); args.dst_memh.local_memh = dst_memh; args.mask |= UCC_COLL_ARGS_FIELD_MEM_MAP_DST_MEMH; } UCC_CHECK(set_input()); 
UCC_CHECK_SKIP(ucc_collective_init(&args, &req, team.team), test_skip); } ucc_status_t TestAllgather::set_input(int iter_persistent) { size_t dt_size = ucc_dt_size(dt); size_t single_rank_count = msgsize / dt_size; size_t single_rank_size = single_rank_count * dt_size; int rank; void *buf; this->iter_persistent = iter_persistent; MPI_Comm_rank(team.comm, &rank); if (inplace) { buf = PTR_OFFSET(rbuf, rank * single_rank_size); } else { buf = sbuf; } init_buffer(buf, single_rank_count, dt, mem_type, rank * (iter_persistent + 1)); return UCC_OK; } ucc_status_t TestAllgather::check() { size_t dt_size, single_rank_count; int size, i; MPI_Comm_size(team.comm, &size); single_rank_count = args.dst.info.count / size; dt_size = ucc_dt_size(dt); for (i = 0; i < size; i++) { init_buffer(PTR_OFFSET(check_buf, i * single_rank_count * dt_size), single_rank_count, dt, UCC_MEMORY_TYPE_HOST, i * (iter_persistent + 1)); } return compare_buffers(rbuf, check_buf, single_rank_count * size, dt, mem_type); } openucx-ucc-ec0bc8a/test/mpi/test_alltoallv.cc0000664000175000017500000001656715133731560022025 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include #include #include "test_mpi.h" #include "mpi_util.h" template void * TestAlltoallv::mpi_counts_to_ucc(int *mpi_counts, size_t _ncount) { void *ucc_counts = (T*)malloc(sizeof(T) * _ncount); for (size_t i = 0; i < _ncount; i++) { ((T*)ucc_counts)[i] = mpi_counts[i]; } return ucc_counts; } TestAlltoallv::TestAlltoallv(ucc_test_team_t &_team, TestCaseParams ¶ms) : TestCase(_team, UCC_COLL_TYPE_ALLTOALLV, params) { std::default_random_engine eng; size_t dt_size, count; int rank, nprocs, rank_count; bool is_onesided; void *work_buf; dt = params.dt; dt_size = ucc_dt_size(dt); count = msgsize / dt_size; sncounts = 0; rncounts = 0; scounts = NULL; sdispls = NULL; rcounts = NULL; rdispls = NULL; scounts64 = NULL; sdispls64 = NULL; rcounts64 = NULL; rdispls64 = NULL; count_bits = params.count_bits; displ_bits = params.displ_bits; is_onesided = (params.buffers != NULL); work_buf = NULL; std::uniform_int_distribution urd(count / 2, count); eng.seed(test_rand_seed); MPI_Comm_rank(team.comm, &rank); MPI_Comm_size(team.comm, &nprocs); if (TEST_SKIP_NONE != skip_reduce(test_max_size < (msgsize * nprocs), TEST_SKIP_MEM_LIMIT, team.comm)) { return; } args.mask = UCC_COLL_ARGS_FIELD_FLAGS; args.flags |= UCC_COLL_ARGS_FLAG_CONTIG_SRC_BUFFER | UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER; if (is_onesided) { args.mask |= UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER; args.flags |= UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS; } if (count_bits == TEST_FLAG_VSIZE_64BIT) { args.flags |= UCC_COLL_ARGS_FLAG_COUNT_64BIT; } if (displ_bits == TEST_FLAG_VSIZE_64BIT) { args.flags |= UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; } scounts = (int*)malloc(sizeof(*scounts) * nprocs); sdispls = (int*)malloc(sizeof(*sdispls) * nprocs); rcounts = (int*)malloc(sizeof(*rcounts) * nprocs); rdispls = (int*)malloc(sizeof(*rdispls) * nprocs); for (auto i = 0; i < nprocs; i++) { rank_count = urd(eng); scounts[i] = rank_count; } MPI_Alltoall((void*)scounts, 1, MPI_INT, (void*)rcounts, 1, MPI_INT, team.comm); sncounts = 0; 
rncounts = 0; for (auto i = 0; i < nprocs; i++) { assert((size_t)rcounts[i] <= count); sdispls[i] = sncounts; rdispls[i] = rncounts; sncounts += scounts[i]; rncounts += rcounts[i]; } if ((test_max_size < (sncounts * dt_size)) || (test_max_size < (rncounts * dt_size))) { test_skip = TEST_SKIP_MEM_LIMIT; } if (TEST_SKIP_NONE != skip_reduce(test_skip, team.comm)) { return; } check_buf = ucc_malloc(rncounts * dt_size, "check buf"); UCC_MALLOC_CHECK(check_buf); if (!is_onesided) { UCC_CHECK(ucc_mc_alloc(&sbuf_mc_header, sncounts * dt_size, mem_type)); UCC_CHECK(ucc_mc_alloc(&rbuf_mc_header, rncounts * dt_size, mem_type)); sbuf = sbuf_mc_header->addr; rbuf = rbuf_mc_header->addr; } else { sbuf = params.buffers[MEM_SEND_SEGMENT]; rbuf = params.buffers[MEM_RECV_SEGMENT]; work_buf = params.buffers[MEM_WORK_SEGMENT]; args.global_work_buffer = work_buf; } args.src.info_v.buffer = sbuf; args.src.info_v.datatype = dt; args.src.info_v.mem_type = mem_type; args.dst.info_v.buffer = rbuf; args.dst.info_v.datatype = dt; args.dst.info_v.mem_type = mem_type; if (TEST_FLAG_VSIZE_64BIT == count_bits || TEST_FLAG_VSIZE_64BIT == displ_bits) { if (msgsize % 64 != 0) { test_skip = TEST_SKIP_NOT_SUPPORTED; } } else { if (msgsize % 32 != 0) { test_skip = TEST_SKIP_NOT_SUPPORTED; } } if (TEST_SKIP_NONE != skip_reduce(test_skip, team.comm)) { return; } if (TEST_FLAG_VSIZE_64BIT == count_bits) { args.src.info_v.counts = scounts64 = (ucc_count_t*)mpi_counts_to_ucc(scounts, nprocs); args.dst.info_v.counts = rcounts64 = (ucc_count_t*)mpi_counts_to_ucc(rcounts, nprocs); } else { args.src.info_v.counts = (ucc_count_t*)scounts; args.dst.info_v.counts = (ucc_count_t*)rcounts; } if (TEST_FLAG_VSIZE_64BIT == displ_bits) { args.src.info_v.displacements = sdispls64 = (ucc_aint_t*)mpi_counts_to_ucc(sdispls, nprocs); args.dst.info_v.displacements = rdispls64 = (ucc_aint_t*)mpi_counts_to_ucc(rdispls, nprocs); } else { args.src.info_v.displacements = (ucc_aint_t*)sdispls; args.dst.info_v.displacements = 
(ucc_aint_t*)rdispls; } if (is_onesided) { MPI_Datatype datatype; size_t disp_size; void *ldisp; int alltoall_status; if (TEST_FLAG_VSIZE_64BIT == displ_bits) { datatype = MPI_LONG; disp_size = sizeof(uint64_t); } else { datatype = MPI_INT; disp_size = sizeof(uint32_t); } ldisp = ucc_calloc(nprocs, disp_size, "displacements"); UCC_MALLOC_CHECK(ldisp); alltoall_status = MPI_Alltoall(args.dst.info_v.displacements, 1, datatype, ldisp, 1, datatype, team.comm); if (MPI_SUCCESS != alltoall_status) { std::cerr << "*** MPI ALLTOALL FAILED" << std::endl; MPI_Abort(MPI_COMM_WORLD, -1); } args.dst.info_v.displacements = (ucc_aint_t *)ldisp; } UCC_CHECK(set_input()); UCC_CHECK_SKIP(ucc_collective_init(&args, &req, team.team), test_skip); } ucc_status_t TestAlltoallv::set_input(int iter_persistent) { int rank; this->iter_persistent = iter_persistent; MPI_Comm_rank(team.comm, &rank); init_buffer(sbuf, sncounts, dt, mem_type, rank * (iter_persistent + 1)); return UCC_OK; } TestAlltoallv::~TestAlltoallv() { free(scounts); free(sdispls); free(rcounts); free(rdispls); free(scounts64); free(sdispls64); free(rcounts64); free(rdispls64); } ucc_status_t TestAlltoallv::check() { MPI_Request req; int i, size, rank, completed; MPI_Comm_size(team.comm, &size); MPI_Comm_rank(team.comm, &rank); MPI_Ialltoall(sdispls, 1, MPI_INT, scounts, 1, MPI_INT, team.comm, &req); do { MPI_Test(&req, &completed, MPI_STATUS_IGNORE); ucc_context_progress(team.ctx); } while(!completed); for (i = 0; i < size; i++) { init_buffer(PTR_OFFSET(check_buf, rdispls[i] * ucc_dt_size(dt)), rcounts[i], dt, UCC_MEMORY_TYPE_HOST, i * (iter_persistent + 1), scounts[i]); } return compare_buffers(rbuf, check_buf, rncounts, dt, mem_type); } std::string TestAlltoallv::str() { return TestCase::str() + " counts=" + (count_bits == TEST_FLAG_VSIZE_64BIT ? "64" : "32") + " displs=" + (displ_bits == TEST_FLAG_VSIZE_64BIT ? 
"64" : "32"); } openucx-ucc-ec0bc8a/test/mpi/test_gather.cc0000664000175000017500000000654115133731560021274 0ustar alastairalastair/** * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #include "test_mpi.h" #include "mpi_util.h" TestGather::TestGather(ucc_test_team_t &_team, TestCaseParams ¶ms) : TestCase(_team, UCC_COLL_TYPE_GATHER, params) { int rank, size; size_t dt_size, single_rank_count; dt = params.dt; dt_size = ucc_dt_size(dt); single_rank_count = msgsize / dt_size; root = params.root; MPI_Comm_rank(team.comm, &rank); MPI_Comm_size(team.comm, &size); if (TEST_SKIP_NONE != skip_reduce(test_max_size < (msgsize * size), TEST_SKIP_MEM_LIMIT, team.comm)) { return; } if (rank == root) { UCC_CHECK(ucc_mc_alloc(&rbuf_mc_header, msgsize * size, mem_type)); rbuf = rbuf_mc_header->addr; if (inplace) { sbuf_mc_header = NULL; sbuf = NULL; } else { UCC_CHECK(ucc_mc_alloc(&sbuf_mc_header, msgsize, mem_type)); sbuf = sbuf_mc_header->addr; } } else { UCC_CHECK(ucc_mc_alloc(&sbuf_mc_header, msgsize, mem_type)); sbuf = sbuf_mc_header->addr; rbuf_mc_header = NULL; rbuf = NULL; } check_buf = ucc_malloc(msgsize*size, "check buf"); UCC_MALLOC_CHECK(check_buf); args.root = root; if (rank == root) { args.dst.info.buffer = rbuf; args.dst.info.count = single_rank_count * size; args.dst.info.datatype = dt; args.dst.info.mem_type = mem_type; if (!inplace) { args.src.info.buffer = sbuf; args.src.info.count = single_rank_count; args.src.info.datatype = dt; args.src.info.mem_type = mem_type; } } else { args.src.info.buffer = sbuf; args.src.info.count = single_rank_count; args.src.info.datatype = dt; args.src.info.mem_type = mem_type; } UCC_CHECK(set_input()); UCC_CHECK_SKIP(ucc_collective_init(&args, &req, team.team), test_skip); } ucc_status_t TestGather::set_input(int iter_persistent) { size_t dt_size = ucc_dt_size(dt); size_t single_rank_count = msgsize / dt_size; size_t single_rank_size = single_rank_count * dt_size; 
int rank; void *buf; this->iter_persistent = iter_persistent; MPI_Comm_rank(team.comm, &rank); if (rank == root) { if (inplace) { buf = PTR_OFFSET(rbuf, rank * single_rank_size); } else { buf = sbuf; } } else { buf = sbuf; } init_buffer(buf, single_rank_count, dt, mem_type, rank * (iter_persistent + 1)); return UCC_OK; } ucc_status_t TestGather::check() { int size, rank, i; size_t dt_size, single_rank_count; MPI_Comm_size(team.comm, &size); MPI_Comm_rank(team.comm, &rank); if (rank != root) { return UCC_OK; } dt_size = ucc_dt_size(dt); single_rank_count = msgsize / dt_size; for (i = 0; i < size; i++) { init_buffer(PTR_OFFSET(check_buf, i * single_rank_count * dt_size), single_rank_count, dt, UCC_MEMORY_TYPE_HOST, i * (iter_persistent + 1)); } return compare_buffers(rbuf, check_buf, single_rank_count * size, dt, mem_type); } openucx-ucc-ec0bc8a/test/mpi/mpi_util.h0000664000175000017500000000406715133731560020450 0ustar alastairalastair/** * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #ifndef MPI_UTIL_H #define MPI_UTIL_H #include "test_mpi.h" static inline MPI_Datatype ucc_dt_to_mpi(ucc_datatype_t dt) { switch (dt) { case UCC_DT_INT8: return MPI_INT8_T; case UCC_DT_UINT8: return MPI_UINT8_T; case UCC_DT_INT16: return MPI_INT16_T; case UCC_DT_UINT16: return MPI_UINT16_T; case UCC_DT_INT32: return MPI_INT32_T; case UCC_DT_UINT32: return MPI_UINT32_T; case UCC_DT_FLOAT32: return MPI_FLOAT; case UCC_DT_INT64: return MPI_INT64_T; case UCC_DT_UINT64: return MPI_UINT64_T; case UCC_DT_FLOAT64: return MPI_DOUBLE; case UCC_DT_FLOAT128: return MPI_LONG_DOUBLE; case UCC_DT_FLOAT32_COMPLEX: return MPI_C_FLOAT_COMPLEX; case UCC_DT_FLOAT64_COMPLEX: return MPI_C_DOUBLE_COMPLEX; case UCC_DT_FLOAT128_COMPLEX: return MPI_C_LONG_DOUBLE_COMPLEX; case UCC_DT_FLOAT16: case UCC_DT_INT128: case UCC_DT_UINT128: case UCC_DT_BFLOAT16: default: std::cerr << "Unsupported dt\n"; MPI_Abort(MPI_COMM_WORLD, -1); } return MPI_DATATYPE_NULL; } static inline MPI_Op ucc_op_to_mpi(ucc_reduction_op_t op) { switch(op) { case UCC_OP_SUM: return MPI_SUM; case UCC_OP_PROD: return MPI_PROD; case UCC_OP_MAX: return MPI_MAX; case UCC_OP_MIN: return MPI_MIN; case UCC_OP_LAND: return MPI_LAND; case UCC_OP_LOR: return MPI_LOR; case UCC_OP_LXOR: return MPI_LXOR; case UCC_OP_BAND: return MPI_BAND; case UCC_OP_BOR: return MPI_BOR; case UCC_OP_BXOR: return MPI_BXOR; case UCC_OP_MAXLOC: return MPI_MAXLOC; case UCC_OP_MINLOC: return MPI_MINLOC; default: std::cerr << "Unsupported op\n"; MPI_Abort(MPI_COMM_WORLD, -1); } return MPI_OP_NULL; } MPI_Comm create_mpi_comm(ucc_test_mpi_team_t t); #endif openucx-ucc-ec0bc8a/test/mpi/test_gatherv.cc0000664000175000017500000001067115133731560021461 0ustar alastairalastair/** * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include "test_mpi.h" #include "mpi_util.h" static void fill_counts_and_displacements(int size, int count, uint32_t *counts, uint32_t *displs) { int bias = count / 2; int i; counts[0] = count - bias; displs[0] = 0; for (i = 1; i < size - 1; i++) { if (i % 2 == 0) { counts[i] = count - bias; } else { counts[i] = count + bias; } displs[i] = displs[i - 1] + counts[i - 1]; } if (size % 2 == 0) { counts[size - 1] = count + bias; } else { counts[size - 1] = count; } displs[size - 1] = displs[size - 2] + counts[size - 2]; } TestGatherv::TestGatherv(ucc_test_team_t &_team, TestCaseParams ¶ms) : TestCase(_team, UCC_COLL_TYPE_GATHERV, params) { int rank, size; size_t dt_size, count; dt = params.dt; dt_size = ucc_dt_size(dt); count = msgsize / dt_size; root = params.root; counts = NULL; displacements = NULL; MPI_Comm_rank(team.comm, &rank); MPI_Comm_size(team.comm, &size); if (TEST_SKIP_NONE != skip_reduce(test_max_size < (msgsize * size), TEST_SKIP_MEM_LIMIT, team.comm)) { return; } counts = (uint32_t *) ucc_malloc(size * sizeof(uint32_t), "counts buf"); displacements = (uint32_t *) ucc_malloc(size * sizeof(uint32_t), "displacements buf"); fill_counts_and_displacements(size, count, counts, displacements); if (rank == root) { UCC_CHECK(ucc_mc_alloc(&rbuf_mc_header, count * size * dt_size, mem_type)); rbuf = rbuf_mc_header->addr; if (inplace) { sbuf_mc_header = NULL; sbuf = NULL; } else { UCC_CHECK(ucc_mc_alloc(&sbuf_mc_header, counts[rank] * dt_size, mem_type)); sbuf = sbuf_mc_header->addr; } } else { UCC_CHECK(ucc_mc_alloc(&sbuf_mc_header, counts[rank] * dt_size, mem_type)); sbuf = sbuf_mc_header->addr; rbuf_mc_header = NULL; rbuf = NULL; } check_buf = ucc_malloc(count * size * dt_size, "check buf"); UCC_MALLOC_CHECK(check_buf); args.root = root; if (rank == root) { args.dst.info_v.buffer = rbuf; args.dst.info_v.counts = (ucc_count_t*)counts; args.dst.info_v.displacements = (ucc_aint_t*)displacements; args.dst.info_v.datatype = dt; args.dst.info_v.mem_type = mem_type; if 
(!inplace) { args.src.info.buffer = sbuf; args.src.info.count = counts[rank]; args.src.info.datatype = dt; args.src.info.mem_type = mem_type; } } else { args.src.info.buffer = sbuf; args.src.info.count = counts[rank]; args.src.info.datatype = dt; args.src.info.mem_type = mem_type; } UCC_CHECK(set_input()); UCC_CHECK_SKIP(ucc_collective_init(&args, &req, team.team), test_skip); } ucc_status_t TestGatherv::set_input(int iter_persistent) { size_t dt_size = ucc_dt_size(dt); int rank; void *buf; this->iter_persistent = iter_persistent; MPI_Comm_rank(team.comm, &rank); if (rank == root) { if (inplace) { buf = PTR_OFFSET(rbuf, displacements[rank] * dt_size); } else { buf = sbuf; } } else { buf = sbuf; } init_buffer(buf, counts[rank], dt, mem_type, rank * (iter_persistent + 1)); return UCC_OK; } TestGatherv::~TestGatherv() { if (counts) { ucc_free(counts); } if (displacements) { ucc_free(displacements); } } ucc_status_t TestGatherv::check() { size_t count = msgsize / ucc_dt_size(dt); int size, rank, i; MPI_Comm_size(team.comm, &size); MPI_Comm_rank(team.comm, &rank); if (rank != root) { return UCC_OK; } for (i = 0; i < size; i++) { init_buffer(PTR_OFFSET(check_buf, displacements[i] * ucc_dt_size(dt)), counts[i], dt, UCC_MEMORY_TYPE_HOST, i * (iter_persistent + 1)); } return (rank != root) ? UCC_OK : compare_buffers(rbuf, check_buf, count * size, dt, mem_type); } openucx-ucc-ec0bc8a/test/mpi/test_bcast.cc0000664000175000017500000000343615133731560021116 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include "test_mpi.h" #include "mpi_util.h" TestBcast::TestBcast(ucc_test_team_t &_team, TestCaseParams ¶ms) : TestCase(_team, UCC_COLL_TYPE_BCAST, params) { int rank, size; size_t dt_size, count; dt = params.dt; dt_size = ucc_dt_size(dt); count = msgsize / dt_size; root = params.root; MPI_Comm_rank(team.comm, &rank); MPI_Comm_size(team.comm, &size); if (skip_reduce(test_max_size < msgsize, TEST_SKIP_MEM_LIMIT, team.comm)) { return; } check_buf = ucc_malloc(msgsize, "check buf"); UCC_MALLOC_CHECK(check_buf); UCC_CHECK(ucc_mc_alloc(&sbuf_mc_header, msgsize, mem_type)); sbuf = sbuf_mc_header->addr; args.src.info.buffer = sbuf; args.src.info.count = count; args.src.info.datatype = dt; args.src.info.mem_type = mem_type; args.root = root; UCC_CHECK(set_input()); UCC_CHECK_SKIP(ucc_collective_init(&args, &req, team.team), test_skip); } ucc_status_t TestBcast::set_input(int iter_persistent) { size_t dt_size = ucc_dt_size(dt); size_t count = msgsize / dt_size; int rank; this->iter_persistent = iter_persistent; MPI_Comm_rank(team.comm, &rank); if (rank == root) { init_buffer(sbuf, count, dt, mem_type, rank * (iter_persistent + 1)); } return UCC_OK; } ucc_status_t TestBcast::check() { size_t count = args.src.info.count; int rank; MPI_Comm_rank(team.comm, &rank); if (rank == root) { return UCC_OK; } init_buffer(check_buf, count, dt, UCC_MEMORY_TYPE_HOST, root * (iter_persistent + 1)); return compare_buffers(sbuf, check_buf, count, dt, mem_type); } openucx-ucc-ec0bc8a/test/mpi/test_barrier.cc0000664000175000017500000000464615133731560021454 0ustar alastairalastair/** * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include "test_mpi.h" TestBarrier::TestBarrier(ucc_test_team_t &team, TestCaseParams ¶ms) : TestCase(team, UCC_COLL_TYPE_BARRIER, params) { status = UCC_OK; UCC_CHECK(ucc_collective_init(&args, &req, team.team)); } ucc_status_t TestBarrier::set_input(int iter_persistent) /* NOLINT */ { return UCC_OK; } ucc_status_t TestBarrier::check() { return status; } std::string TestBarrier::str() { return std::string("tc=")+std::string(ucc_coll_type_str(args.coll_type)) + std::string(" team=") + std::string(team_str(team.type)); } ucc_status_t TestBarrier::test() { return UCC_OK; } void TestBarrier::run(bool triggered) { int completed = 1; int *recv = NULL; int rank, size; MPI_Request rreq; MPI_Comm_rank(team.comm, &rank); MPI_Comm_size(team.comm, &size); if (0 == rank) { recv = new int[size]; } srand(rank+1); /* random sleep 0 - 500 ms */ usleep((rand() % 500)*1000); for (int i = 0; i < size; i++) { if (0 == rank) { recv[i] = -1; completed = 0; MPI_Irecv(&recv[i], 1, MPI_INT, MPI_ANY_SOURCE, 123, team.comm, &rreq); } if (rank == i) { MPI_Ssend(&rank, 1, MPI_INT, 0, 123, team.comm); } UCC_CHECK(ucc_collective_post(req)); do { status = ucc_collective_test(req); if (status < 0) { std::cerr << "failure in collective test\n"; MPI_Abort(MPI_COMM_WORLD, -1); } ucc_context_progress(team.ctx); if (0 == rank && !completed) { MPI_Test(&rreq, &completed, MPI_STATUS_IGNORE); } mpi_progress(); } while(UCC_OK != status); while (0 == rank && !completed) { MPI_Test(&rreq, &completed, MPI_STATUS_IGNORE); mpi_progress(); } if (!persistent && i < size - 1) { UCC_CHECK(ucc_collective_finalize(req)); UCC_CHECK(ucc_collective_init(&args, &req, team.team)); } } MPI_Barrier(team.comm); if (0 == rank) { for (int i = 0; i < size; i++) { if (recv[i] != i) { status = UCC_ERR_NO_MESSAGE; break; } } } if (0 == rank) { delete[] recv; } } openucx-ucc-ec0bc8a/test/mpi/test_allgatherv.cc0000664000175000017500000000725615133731560022157 0ustar alastairalastair/** * Copyright (c) 2021-2024, NVIDIA CORPORATION 
& AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #include "test_mpi.h" #include "mpi_util.h" static void fill_counts_and_displacements(int size, int count, int *counts, int *displs) { int bias = count / 2; int i; counts[0] = count - bias; displs[0] = 0; for (i = 1; i < size - 1; i++) { if (i % 2 == 0) { counts[i] = count - bias; } else { counts[i] = count + bias; } displs[i] = displs[i - 1] + counts[i - 1]; } if (size % 2 == 0) { counts[size - 1] = count + bias; } else { counts[size - 1] = count; } displs[size - 1] = displs[size - 2] + counts[size - 2]; } TestAllgatherv::TestAllgatherv(ucc_test_team_t &_team, TestCaseParams ¶ms) : TestCase(_team, UCC_COLL_TYPE_ALLGATHERV, params) { int rank, size; size_t dt_size, count; dt = params.dt; dt_size = ucc_dt_size(dt); count = msgsize / dt_size; counts = NULL; displacements = NULL; MPI_Comm_rank(team.comm, &rank); MPI_Comm_size(team.comm, &size); if (TEST_SKIP_NONE != skip_reduce(test_max_size < (msgsize * size), TEST_SKIP_MEM_LIMIT, team.comm)) { return; } counts = (int *) ucc_malloc(size * sizeof(uint32_t), "counts buf"); UCC_MALLOC_CHECK(counts); displacements = (int *) ucc_malloc(size * sizeof(uint32_t), "displacements buf"); UCC_MALLOC_CHECK(displacements); UCC_CHECK(ucc_mc_alloc(&rbuf_mc_header, msgsize * size, mem_type)); rbuf = rbuf_mc_header->addr; check_buf = ucc_malloc(msgsize * size, "check buf"); UCC_MALLOC_CHECK(check_buf); fill_counts_and_displacements(size, count, counts, displacements); args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; args.flags |= UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER; if (!inplace) { UCC_CHECK(ucc_mc_alloc(&sbuf_mc_header, counts[rank] * dt_size, mem_type)); sbuf = sbuf_mc_header->addr; args.src.info.buffer = sbuf; args.src.info.datatype = dt; args.src.info.mem_type = mem_type; args.src.info.count = counts[rank]; } args.dst.info_v.buffer = rbuf; args.dst.info_v.counts = (ucc_count_t*)counts; args.dst.info_v.displacements = (ucc_aint_t*)displacements; args.dst.info_v.datatype 
= dt; args.dst.info_v.mem_type = mem_type; UCC_CHECK(set_input()); UCC_CHECK_SKIP(ucc_collective_init(&args, &req, team.team), test_skip); } ucc_status_t TestAllgatherv::set_input(int iter_persistent) { size_t dt_size = ucc_dt_size(dt); int rank; void *buf; this->iter_persistent = iter_persistent; MPI_Comm_rank(team.comm, &rank); if (inplace) { buf = PTR_OFFSET(rbuf, displacements[rank] * dt_size); } else { buf = sbuf; } init_buffer(buf, counts[rank], dt, mem_type, rank * (iter_persistent + 1)); return UCC_OK; } TestAllgatherv::~TestAllgatherv() { if (counts) { ucc_free(counts); } if (displacements) { ucc_free(displacements); } } ucc_status_t TestAllgatherv::check() { int total_count = 0; int size, i; MPI_Comm_size(team.comm, &size); for (i = 0 ; i < size; i++) { total_count += counts[i]; } for (i = 0; i < size; i++) { init_buffer(PTR_OFFSET(check_buf, displacements[i] * ucc_dt_size(dt)), counts[i], dt, UCC_MEMORY_TYPE_HOST, i * (iter_persistent + 1)); } return compare_buffers(rbuf, check_buf, total_count, dt, mem_type); } openucx-ucc-ec0bc8a/test/mpi/test_alltoall.cc0000664000175000017500000000723315133731560021625 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include "test_mpi.h" #include "mpi_util.h" TestAlltoall::TestAlltoall(ucc_test_team_t &_team, TestCaseParams ¶ms) : TestCase(_team, UCC_COLL_TYPE_ALLTOALL, params) { void* work_buf = nullptr; int rank, nprocs; size_t dt_size, single_rank_count; dt = params.dt; dt_size = ucc_dt_size(dt); single_rank_count = msgsize / dt_size; is_onesided = (params.buffers != nullptr); MPI_Comm_rank(team.comm, &rank); MPI_Comm_size(team.comm, &nprocs); if (TEST_SKIP_NONE != skip_reduce(test_max_size < (msgsize * nprocs), TEST_SKIP_MEM_LIMIT, team.comm)) { return; } if (!is_onesided) { UCC_CHECK(ucc_mc_alloc(&rbuf_mc_header, msgsize * nprocs, mem_type)); rbuf = rbuf_mc_header->addr; } else { sbuf = params.buffers[MEM_SEND_SEGMENT]; rbuf = params.buffers[MEM_RECV_SEGMENT]; work_buf = params.buffers[MEM_WORK_SEGMENT]; } check_buf = ucc_malloc(msgsize * nprocs, "check buf"); UCC_MALLOC_CHECK(check_buf); if (!inplace) { if (!is_onesided) { UCC_CHECK(ucc_mc_alloc(&sbuf_mc_header, msgsize * nprocs, mem_type)); sbuf = sbuf_mc_header->addr; } } if (is_onesided) { args.mask |= UCC_COLL_ARGS_FIELD_FLAGS | UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER; args.flags |= UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS; args.global_work_buffer = work_buf; } if (!inplace) { args.src.info.buffer = sbuf; args.src.info.count = single_rank_count * nprocs; args.src.info.datatype = dt; args.src.info.mem_type = mem_type; } args.dst.info.buffer = rbuf; args.dst.info.count = single_rank_count * nprocs; args.dst.info.datatype = dt; args.dst.info.mem_type = mem_type; UCC_CHECK(set_input()); UCC_CHECK_SKIP(ucc_collective_init(&args, &req, team.team), test_skip); } ucc_status_t TestAlltoall::set_input(int iter_persistent) { size_t dt_size = ucc_dt_size(dt); size_t single_rank_count = msgsize / dt_size; MPI_Request req; void * buf; int rank, nprocs, completed; this->iter_persistent = iter_persistent; MPI_Comm_rank(team.comm, &rank); MPI_Comm_size(team.comm, &nprocs); if (inplace) { buf = rbuf; } else { buf = sbuf; } 
init_buffer(buf, single_rank_count * nprocs, dt, mem_type, rank * (iter_persistent + 1)); UCC_CHECK(ucc_mc_memcpy(check_buf, buf, single_rank_count * nprocs * dt_size, UCC_MEMORY_TYPE_HOST, mem_type)); if (is_onesided && persistent) { MPI_Ibarrier(team.comm, &req); do { MPI_Test(&req, &completed, MPI_STATUS_IGNORE); ucc_context_progress(team.ctx); } while(!completed); } return UCC_OK; } ucc_status_t TestAlltoall::check() { int size, rank, i; size_t single_rank_count; MPI_Comm_rank(team.comm, &rank); MPI_Comm_size(team.comm, &size); single_rank_count = args.src.info.count / size; for (i = 0; i < size; i++) { init_buffer(PTR_OFFSET(check_buf, i * single_rank_count * ucc_dt_size(dt)), single_rank_count, dt, UCC_MEMORY_TYPE_HOST, i * (iter_persistent + 1), single_rank_count * rank); } return compare_buffers(rbuf, check_buf, single_rank_count * size, dt, mem_type); } openucx-ucc-ec0bc8a/test/mpi/test_case.cc0000664000175000017500000001557215133731560020741 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/

#include "test_mpi.h"

/*
 * NOTE(review): this copy of the file is damaged.  Template argument lists
 * appear to have been stripped (e.g. the `std::vector>` return type and the
 * bare `std::make_shared(...)` calls below), and inside TestCase::wait() a
 * large span of text is missing: the stream expression breaks off after
 * `" ("` and the text resumes in the middle of what looks like the
 * TestScatterv constructor.  The code is kept as found; restore it from the
 * upstream sources before building.
 */

/*
 * Create num_tests test cases of the given collective type via init_single().
 * If any single creation fails, the cases created so far are dropped and an
 * empty container is returned.
 */
std::vector> TestCase::init(ucc_test_team_t &_team, ucc_coll_type_t _type,
                            int num_tests, TestCaseParams params)
{
    std::vector> tcs;

    for (int i = 0; i < num_tests; i++) {
        auto tc = init_single(_team, _type, params);
        if (!tc) {
            /* all-or-nothing: one failure invalidates the whole batch */
            tcs.clear();
            return tcs;
        }
        tcs.push_back(tc);
    }
    return tcs;
}

/*
 * Create one test case object for the given collective type.  Unknown types
 * print a diagnostic and yield NULL.
 */
std::shared_ptr TestCase::init_single(ucc_test_team_t &_team,
                                      ucc_coll_type_t _type,
                                      TestCaseParams params)
{
    switch(_type) {
    case UCC_COLL_TYPE_ALLGATHER:
        return std::make_shared(_team, params);
    case UCC_COLL_TYPE_ALLGATHERV:
        return std::make_shared(_team, params);
    case UCC_COLL_TYPE_ALLREDUCE:
        return std::make_shared(_team, params);
    case UCC_COLL_TYPE_ALLTOALL:
        return std::make_shared(_team, params);
    case UCC_COLL_TYPE_ALLTOALLV:
        return std::make_shared(_team, params);
    case UCC_COLL_TYPE_BARRIER:
        return std::make_shared(_team, params);
    case UCC_COLL_TYPE_BCAST:
        return std::make_shared(_team, params);
    case UCC_COLL_TYPE_GATHER:
        return std::make_shared(_team, params);
    case UCC_COLL_TYPE_GATHERV:
        return std::make_shared(_team, params);
    case UCC_COLL_TYPE_REDUCE:
        return std::make_shared(_team, params);
    case UCC_COLL_TYPE_REDUCE_SCATTER:
        return std::make_shared(_team, params);
    case UCC_COLL_TYPE_REDUCE_SCATTERV:
        return std::make_shared(_team, params);
    case UCC_COLL_TYPE_SCATTER:
        return std::make_shared(_team, params);
    case UCC_COLL_TYPE_SCATTERV:
        return std::make_shared(_team, params);
    default:
        std::cerr << "collective type is not supported" << std::endl;
        break;
    }
    return NULL;
}

/*
 * Post the collective request.  With triggered == true the request is posted
 * through an execution engine obtained from the team: a compute-complete
 * event carrying the request is submitted, then one event is fetched from the
 * engine and acknowledged.  Otherwise the request is posted directly.
 */
void TestCase::run(bool triggered)
{
    if (triggered) {
        ucc_ee_h ee = nullptr;
        ucc_ev_t comp_ev, *post_ev;
        ucc_ee_type_t ee_type;

        /* only CUDA memory has an execution engine mapping here */
        if (mem_type == UCC_MEMORY_TYPE_CUDA) {
            ee_type = UCC_EE_CUDA_STREAM;
        } else {
            /* NOTE(review): ee_type stays unset on this path; this relies on
             * UCC_CHECK not returning for an error status - confirm */
            UCC_CHECK(UCC_ERR_NOT_SUPPORTED);
        }
        UCC_CHECK(team.get_ee(ee_type, &ee));
        comp_ev.ev_type = UCC_EVENT_COMPUTE_COMPLETE;
        comp_ev.ev_context = nullptr;
        comp_ev.ev_context_size = 0;
        comp_ev.req = req;
        UCC_CHECK(ucc_collective_triggered_post(ee, &comp_ev));
        UCC_CHECK(ucc_ee_get_event(ee, &post_ev));
        UCC_CHECK(ucc_ee_ack_event(ee, post_ev));
    } else {
        UCC_CHECK(ucc_collective_post(req));
    }
}

/* Poll the collective request once and return its current status. */
ucc_status_t TestCase::test()
{
    return ucc_collective_test(req);
}

/* Damaged: see the note at the top of this file.  The visible part progresses
 * MPI, polls the request and starts to report a negative status. */
void TestCase::wait()
{
    ucc_status_t status;
    do {
        mpi_progress();
        status = test();
        if (status < 0) {
            std::cerr << "error during coll test: " << ucc_status_string(status) << " ("<addr;
        if (inplace) {
            rbuf_mc_header = NULL;
            rbuf = NULL;
        } else {
            UCC_CHECK(ucc_mc_alloc(&rbuf_mc_header, counts[rank] * dt_size, mem_type));
            rbuf = rbuf_mc_header->addr;
        }
    } else {
        UCC_CHECK(ucc_mc_alloc(&rbuf_mc_header, counts[rank] * dt_size, mem_type));
        rbuf = rbuf_mc_header->addr;
        sbuf_mc_header = NULL;
        sbuf = NULL;
    }

    check_buf = ucc_malloc(count * size * dt_size, "check buf");
    UCC_MALLOC_CHECK(check_buf);

    args.root = root;
    if (rank == root) {
        /* the root describes the full vector layout of the source */
        args.src.info_v.buffer = sbuf;
        args.src.info_v.counts = (ucc_count_t*)counts;
        args.src.info_v.displacements = (ucc_aint_t*)displacements;
        args.src.info_v.datatype = dt;
        args.src.info_v.mem_type = mem_type;
        if (!inplace) {
            args.dst.info.buffer = rbuf;
            args.dst.info.count = counts[rank];
            args.dst.info.datatype = dt;
            args.dst.info.mem_type = mem_type;
        }
    } else {
        args.dst.info.buffer = rbuf;
        args.dst.info.count = counts[rank];
        args.dst.info.datatype = dt;
        args.dst.info.mem_type = mem_type;
    }
    UCC_CHECK(set_input());
    UCC_CHECK_SKIP(ucc_collective_init(&args, &req, team.team), test_skip);
}

/*
 * On the root, fill the whole send buffer (count * size elements) with the
 * root's pattern and copy it into the host reference buffer.  Other ranks do
 * nothing.  Always returns UCC_OK.
 */
ucc_status_t TestScatterv::set_input(int iter_persistent)
{
    size_t dt_size = ucc_dt_size(dt);
    size_t count = msgsize / dt_size;
    int rank, size;

    MPI_Comm_rank(team.comm, &rank);
    MPI_Comm_size(team.comm, &size);
    if (rank == root) {
        init_buffer(sbuf, count * size, dt, mem_type, rank * (iter_persistent + 1));
        UCC_CHECK(ucc_mc_memcpy(check_buf, sbuf, count * size * dt_size,
                                UCC_MEMORY_TYPE_HOST, mem_type));
    }
    return UCC_OK;
}

/* Release the count and displacement arrays, if they were allocated. */
TestScatterv::~TestScatterv()
{
    if (counts) {
        ucc_free(counts);
    }
    if (displacements) {
        ucc_free(displacements);
    }
}

/*
 * Produce the reference result with MPI_Iscatterv on the host reference
 * buffer (in place on the root) and compare it with the UCC result:
 *  - in-place root: the untouched send buffer against the reference buffer;
 *  - out-of-place root: rbuf against the root's own slice of the reference;
 *  - other ranks: rbuf against the data MPI delivered into check_buf.
 */
ucc_status_t TestScatterv::check()
{
    size_t dt_size = ucc_dt_size(dt);
    size_t count = msgsize / dt_size;
    MPI_Datatype mpi_dt = ucc_dt_to_mpi(dt);
    MPI_Request req;
    int size, rank, completed;

    MPI_Comm_size(team.comm, &size);
    MPI_Comm_rank(team.comm, &rank);

    MPI_Iscatterv(check_buf, (int *)counts, (int *)displacements, mpi_dt,
                  (rank == root) ? MPI_IN_PLACE : check_buf, counts[rank], mpi_dt,
                  root, team.comm, &req);
    do {
        MPI_Test(&req, &completed, MPI_STATUS_IGNORE);
        /* keep UCC progressing while waiting on MPI */
        ucc_context_progress(team.ctx);
    } while(!completed);

    if (rank == root) {
        if (inplace) {
            return compare_buffers(sbuf, check_buf, count * size, dt, mem_type);
        } else {
            return compare_buffers(
                rbuf, PTR_OFFSET(check_buf, displacements[rank] * dt_size),
                counts[rank], dt, mem_type);
        }
    } else {
        return compare_buffers(rbuf, check_buf, counts[rank], dt, mem_type);
    }
}
openucx-ucc-ec0bc8a/test/mpi/test_allreduce.cc0000664000175000017500000000622315133731560021757 0ustar alastairalastair/** * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms.
*/ #include "test_mpi.h" #include "mpi_util.h" TestAllreduce::TestAllreduce(ucc_test_team_t &_team, TestCaseParams ¶ms) : TestCase(_team, UCC_COLL_TYPE_ALLREDUCE, params) { size_t dt_size = ucc_dt_size(params.dt); size_t count = msgsize/dt_size; int rank; MPI_Comm_rank(team.comm, &rank); op = params.op; dt = params.dt; if (skip_reduce(test_max_size < msgsize, TEST_SKIP_MEM_LIMIT, team.comm)) { return; } UCC_CHECK(ucc_mc_alloc(&rbuf_mc_header, msgsize, mem_type)); rbuf = rbuf_mc_header->addr; check_buf = ucc_malloc(msgsize, "check buf"); UCC_MALLOC_CHECK(check_buf); if (!inplace) { UCC_CHECK(ucc_mc_alloc(&sbuf_mc_header, msgsize, mem_type)); sbuf = sbuf_mc_header->addr; args.src.info.buffer = sbuf; args.src.info.count = count; args.src.info.datatype = dt; args.src.info.mem_type = mem_type; } else { args.src.info.buffer = NULL; args.src.info.count = SIZE_MAX; args.src.info.datatype = (ucc_datatype_t)-1; args.src.info.mem_type = UCC_MEMORY_TYPE_UNKNOWN; } args.op = op; args.dst.info.buffer = rbuf; args.dst.info.count = count; args.dst.info.datatype = dt; args.dst.info.mem_type = mem_type; UCC_CHECK(set_input()); UCC_CHECK_SKIP(ucc_collective_init(&args, &req, team.team), test_skip); } ucc_status_t TestAllreduce::set_input(int iter_persistent) { size_t dt_size = ucc_dt_size(dt); size_t count = msgsize / dt_size; int rank; void *buf; MPI_Comm_rank(team.comm, &rank); if (inplace) { buf = rbuf; } else { buf = sbuf; } init_buffer(buf, count, dt, mem_type, rank * (iter_persistent + 1)); UCC_CHECK(ucc_mc_memcpy(check_buf, buf, count * dt_size, UCC_MEMORY_TYPE_HOST, mem_type)); return UCC_OK; } ucc_status_t TestAllreduce::check() { size_t dt_size = ucc_dt_size(dt); size_t count = msgsize / dt_size; MPI_Request req; int completed; ucc_status_t status; MPI_Iallreduce(MPI_IN_PLACE, check_buf, count, ucc_dt_to_mpi(dt), op == UCC_OP_AVG ? 
MPI_SUM : ucc_op_to_mpi(op), team.comm, &req); do { MPI_Test(&req, &completed, MPI_STATUS_IGNORE); ucc_context_progress(team.ctx); } while(!completed); if (op == UCC_OP_AVG) { status = divide_buffer(check_buf, team.team->size, count, dt); if (status != UCC_OK) { return status; } } return compare_buffers(rbuf, check_buf, count, dt, mem_type); } std::string TestAllreduce::str() { return std::string("tc=") + ucc_coll_type_str(args.coll_type) + " team=" + team_str(team.type) + " msgsize=" + std::to_string(msgsize) + " inplace=" + (inplace ? "1" : "0") + " persistent=" + (persistent ? "1" : "0") + " dt=" + ucc_datatype_str(dt) + " op=" + ucc_reduction_op_str(op); } openucx-ucc-ec0bc8a/test/mpi/mpi_util.cc0000664000175000017500000000273115133731560020602 0ustar alastairalastair/** * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #include "mpi_util.h" static MPI_Comm create_half_comm() { int world_rank, world_size; MPI_Comm new_comm; MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); MPI_Comm_size(MPI_COMM_WORLD, &world_size); MPI_Comm_split(MPI_COMM_WORLD, world_rank < world_size / 2, world_rank, &new_comm); return new_comm; } static MPI_Comm create_odd_even_comm() { int world_rank; MPI_Comm new_comm; MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); MPI_Comm_split(MPI_COMM_WORLD, world_rank % 2, world_rank, &new_comm); return new_comm; } static MPI_Comm create_reverse_comm() { int world_rank, world_size; MPI_Comm new_comm; MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); MPI_Comm_size(MPI_COMM_WORLD, &world_size); MPI_Comm_split(MPI_COMM_WORLD, 0, world_size - world_rank - 1, &new_comm); return new_comm; } MPI_Comm create_mpi_comm(ucc_test_mpi_team_t t) { MPI_Comm comm = MPI_COMM_NULL; switch(t) { case TEAM_WORLD: MPI_Comm_dup(MPI_COMM_WORLD, &comm); break; case TEAM_REVERSE: comm = create_reverse_comm(); break; case TEAM_SPLIT_HALF: comm = create_half_comm(); break; case TEAM_SPLIT_ODD_EVEN: comm = create_odd_even_comm(); 
break; default: break; } return comm; } openucx-ucc-ec0bc8a/test/mpi/Makefile.am0000664000175000017500000000237115133731560020505 0ustar alastairalastair# # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # $COPYRIGHT$ # # Additional copyrights may follow # # $HEADER$ # bin_PROGRAMS = ucc_test_mpi ucc_test_mpi_SOURCES = \ test_mpi.cc \ main.cc \ buffer.cc \ mpi_util.cc \ test_case.cc \ test_barrier.cc \ test_allreduce.cc \ test_allgather.cc \ test_allgatherv.cc \ test_bcast.cc \ test_alltoall.cc \ test_alltoallv.cc \ test_reduce_scatter.cc \ test_reduce_scatterv.cc \ test_reduce.cc \ test_gather.cc \ test_gatherv.cc \ test_scatter.cc \ test_scatterv.cc \ test_mem_map.cc CXX=$(MPICXX) LD=$(MPICXX) ucc_test_mpi_CPPFLAGS = $(BASE_CPPFLAGS) ucc_test_mpi_CXXFLAGS = $(BASE_CXXFLAGS) -std=gnu++11 ucc_test_mpi_LDFLAGS = -Wl,--rpath-link=${UCS_LIBDIR} ucc_test_mpi_LDADD = $(UCC_TOP_BUILDDIR)/src/libucc.la if HAVE_CUDA ucc_test_mpi_CPPFLAGS += $(CUDA_CPPFLAGS) ucc_test_mpi_LDFLAGS += $(CUDA_LDFLAGS) ucc_test_mpi_LDADD += $(CUDA_LIBS) endif if HAVE_ROCM ucc_test_mpi_CPPFLAGS += $(HIP_CPPFLAGS) ucc_test_mpi_CXXFLAGS += $(HIP_CXXFLAGS) ucc_test_mpi_LDFLAGS += $(HIP_LDFLAGS) ucc_test_mpi_LDADD += $(HIP_LIBS) endif openucx-ucc-ec0bc8a/test/mpi/test_mpi.cc0000664000175000017500000005443415133731560020613 0ustar alastairalastair/** * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include "test_mpi.h" #include "mpi_util.h" BEGIN_C_DECLS #include "utils/ucc_math.h" END_C_DECLS #include #include #include #include static ucc_status_t oob_allgather(void *sbuf, void *rbuf, size_t msglen, void *coll_info, void **req) { MPI_Comm comm = (MPI_Comm)(uintptr_t)coll_info; MPI_Request request; MPI_Iallgather(sbuf, msglen, MPI_BYTE, rbuf, msglen, MPI_BYTE, comm, &request); *req = (void *)(uintptr_t)request; return UCC_OK; } static ucc_status_t oob_allgather_test(void *req) { MPI_Request request = (MPI_Request)(uintptr_t)req; int completed; MPI_Test(&request, &completed, MPI_STATUS_IGNORE); return completed ? UCC_OK : UCC_INPROGRESS; } static ucc_status_t oob_allgather_free(void *req) { return UCC_OK; } UccTestMpi::UccTestMpi(int argc, char *argv[], ucc_thread_mode_t _tm, int is_local, bool with_onesided) { ucc_lib_config_h lib_config; ucc_context_config_h ctx_config; int size, rank; char *prev_env; ucc_mem_map_t segments[UCC_TEST_N_MEM_SEGMENTS]; MPI_Comm_size(MPI_COMM_WORLD, &size); MPI_Comm_rank(MPI_COMM_WORLD, &rank); /* Init ucc library */ ucc_lib_params_t lib_params = { .mask = UCC_LIB_PARAM_FIELD_THREAD_MODE, .thread_mode = _tm, /* .coll_types = coll_types, */ }; tm = _tm; //TODO check ucc provided /* Init ucc context for a specified UCC_TEST_TLS */ ucc_context_params_t ctx_params = {}; ucc_context_params_t onesided_ctx_params = {}; if (!is_local) { ctx_params.mask |= UCC_CONTEXT_PARAM_FIELD_OOB; ctx_params.oob.allgather = oob_allgather; ctx_params.oob.req_test = oob_allgather_test; ctx_params.oob.req_free = oob_allgather_free; ctx_params.oob.coll_info = (void*)(uintptr_t)MPI_COMM_WORLD; ctx_params.oob.n_oob_eps = size; ctx_params.oob.oob_ep = rank; if (with_onesided) { onesided_ctx_params = ctx_params; for (auto i = 0; i < UCC_TEST_N_MEM_SEGMENTS; i++) { onesided_buffers[i] = ucc_calloc(UCC_TEST_MEM_SEGMENT_SIZE, size, "onesided buffers"); UCC_MALLOC_CHECK(onesided_buffers[i]); segments[i].address = onesided_buffers[i]; segments[i].len = 
UCC_TEST_MEM_SEGMENT_SIZE * size; } onesided_ctx_params.mask |= UCC_CONTEXT_PARAM_FIELD_MEM_PARAMS; onesided_ctx_params.mem_params.segments = segments; onesided_ctx_params.mem_params.n_segments = UCC_TEST_N_MEM_SEGMENTS; } } if (!with_onesided) { for (auto i = 0; i < UCC_TEST_N_MEM_SEGMENTS; i++) { onesided_buffers[i] = NULL; } } UCC_CHECK(ucc_lib_config_read(NULL, NULL, &lib_config)); UCC_CHECK(ucc_init(&lib_params, lib_config, &lib)); ucc_lib_config_release(lib_config); UCC_CHECK(ucc_context_config_read(lib, NULL, &ctx_config)); UCC_CHECK(ucc_context_create(lib, &ctx_params, ctx_config, &ctx)); ucc_context_config_release(ctx_config); if (with_onesided) { prev_env = getenv("UCC_TL_UCP_TUNE"); setenv("UCC_TL_UCP_TUNE", "alltoall:0-inf:@onesided#alltoallv:0-inf:@onesided", 1); UCC_CHECK(ucc_lib_config_read(NULL, NULL, &lib_config)); UCC_CHECK(ucc_init(&lib_params, lib_config, &onesided_lib)); ucc_lib_config_release(lib_config); UCC_CHECK(ucc_context_config_read(onesided_lib, NULL, &ctx_config)); UCC_CHECK(ucc_context_create(onesided_lib, &onesided_ctx_params, ctx_config, &onesided_ctx)); ucc_context_config_release(ctx_config); if (prev_env) { putenv(prev_env); } else { unsetenv("UCC_TL_UCP_TUNE"); } } else { onesided_lib = nullptr; onesided_ctx = nullptr; } set_msgsizes(8, ((1ULL) << 21), 8); dtypes = {UCC_DT_INT16, UCC_DT_INT32, UCC_DT_INT64, UCC_DT_UINT16, UCC_DT_UINT32, UCC_DT_UINT64, UCC_DT_FLOAT32, UCC_DT_FLOAT64, UCC_DT_FLOAT128, UCC_DT_FLOAT32_COMPLEX, UCC_DT_FLOAT64_COMPLEX, UCC_DT_FLOAT128_COMPLEX}; ops = {UCC_OP_SUM, UCC_OP_MAX}; colls = {UCC_COLL_TYPE_BARRIER, UCC_COLL_TYPE_ALLREDUCE}; mtypes = {UCC_MEMORY_TYPE_HOST}; inplace = false; persistent = false; root_type = ROOT_RANDOM; root_value = 10; iterations = 1; triggered = false; local_registration = false; } void UccTestMpi::set_iter(int iter) { iterations = iter; } void UccTestMpi::set_verbose(bool _verbose) { verbose = _verbose; } void UccTestMpi::create_teams(std::vector &test_teams, bool is_onesided) 
{ int rank, size; MPI_Comm_rank(MPI_COMM_WORLD, &rank); MPI_Comm_size(MPI_COMM_WORLD, &size); for (auto &t : test_teams) { if (size < 4 && (t == TEAM_SPLIT_HALF || t == TEAM_SPLIT_ODD_EVEN)) { if (rank == 0) { std::cout << "size of the world=" << size << " is too small to create team " << team_str(t) << ", skipping ...\n"; } continue; } create_team(t, is_onesided); } } UccTestMpi::~UccTestMpi() { for (auto &t : teams) { destroy_team(t); } for (auto &t : onesided_teams) { destroy_team(t); } if (onesided_buffers[0]) { for (auto i = 0; i < UCC_TEST_N_MEM_SEGMENTS; i++) { ucc_free(onesided_buffers[i]); } UCC_CHECK(ucc_context_destroy(onesided_ctx)); UCC_CHECK(ucc_finalize(onesided_lib)); } UCC_CHECK(ucc_context_destroy(ctx)); UCC_CHECK(ucc_finalize(lib)); } ucc_team_h UccTestMpi::create_ucc_team(MPI_Comm comm, bool is_onesided) { ucc_context_h team_ctx = ctx; int rank, size; ucc_team_h team; ucc_team_params_t team_params; ucc_status_t status; MPI_Comm_rank(comm, &rank); MPI_Comm_size(comm, &size); /* Create UCC TEAM for comm world */ team_params.mask = UCC_TEAM_PARAM_FIELD_EP | UCC_TEAM_PARAM_FIELD_EP_RANGE | UCC_TEAM_PARAM_FIELD_OOB; team_params.oob.allgather = oob_allgather; team_params.oob.req_test = oob_allgather_test; team_params.oob.req_free = oob_allgather_free; team_params.oob.coll_info = (void*)(uintptr_t)comm; team_params.oob.n_oob_eps = size; team_params.oob.oob_ep = rank; team_params.ep = rank; team_params.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG; if (is_onesided) { team_params.mask |= UCC_TEAM_PARAM_FIELD_FLAGS; team_params.flags = UCC_TEAM_FLAG_COLL_WORK_BUFFER; team_ctx = onesided_ctx; } UCC_CHECK(ucc_team_create_post(&team_ctx, 1, &team_params, &team)); MPI_Request req; int tmp; int completed; MPI_Irecv(&tmp, 1, MPI_INT, rank, 123, comm, &req); while (UCC_INPROGRESS == (status = ucc_team_create_test(team))) { ucc_context_progress(team_ctx); MPI_Test(&req, &completed, MPI_STATUS_IGNORE); }; MPI_Send(&tmp, 1, MPI_INT, rank, 123, comm); MPI_Wait(&req, 
MPI_STATUS_IGNORE); if (status < 0) { std::cerr << "*** UCC TEST FAIL: ucc_team_create_test failed\n"; MPI_Abort(MPI_COMM_WORLD, -1); } return team; } void UccTestMpi::create_team(ucc_test_mpi_team_t t, bool is_onesided) { ucc_team_h team; MPI_Comm comm = create_mpi_comm(t); if (is_onesided) { MPI_Comm comm_dup; MPI_Comm_dup(comm, &comm_dup); team = create_ucc_team(comm_dup, true); onesided_teams.push_back(ucc_test_team_t(t, comm_dup, team, onesided_ctx)); } else { team = create_ucc_team(comm); teams.push_back(ucc_test_team_t(t, comm, team, ctx)); } } void UccTestMpi::destroy_team(ucc_test_team_t &team) { ucc_status_t status; team.free_ee(); while (UCC_INPROGRESS == (status = ucc_team_destroy(team.team))) {} if (UCC_OK != status) { std::cerr << "ucc_team_destroy failed\n"; } if (team.comm != MPI_COMM_WORLD) { MPI_Comm_free(&team.comm); } } void UccTestMpi::set_msgsizes(size_t min, size_t max, size_t power) { size_t m = min; msgsizes.clear(); while (m < max) { msgsizes.push_back(m); m *= power; } msgsizes.push_back(max); } void UccTestMpi::set_dtypes(std::vector &_dtypes) { dtypes = _dtypes; } void UccTestMpi::set_mtypes(std::vector &_mtypes) { mtypes = _mtypes; } void UccTestMpi::set_colls(std::vector &_colls) { colls = _colls; } void UccTestMpi::set_ops(std::vector &_ops) { ops = _ops; } int ucc_coll_reduce_supported(ucc_reduction_op_t op, ucc_datatype_t dt) { switch (dt) { case UCC_DT_INT8: case UCC_DT_INT16: case UCC_DT_INT32: case UCC_DT_INT64: case UCC_DT_INT128: case UCC_DT_UINT8: case UCC_DT_UINT16: case UCC_DT_UINT32: case UCC_DT_UINT64: case UCC_DT_UINT128: return (op != UCC_OP_AVG); case UCC_DT_FLOAT16: case UCC_DT_FLOAT32: case UCC_DT_FLOAT64: case UCC_DT_BFLOAT16: case UCC_DT_FLOAT128: return (op == UCC_OP_SUM || op == UCC_OP_PROD || op == UCC_OP_MAX || op == UCC_OP_MIN || op == UCC_OP_AVG); case UCC_DT_FLOAT32_COMPLEX: case UCC_DT_FLOAT64_COMPLEX: case UCC_DT_FLOAT128_COMPLEX: return (op == UCC_OP_SUM || op == UCC_OP_PROD || op == UCC_OP_AVG); default: 
return 0; } } int ucc_coll_inplace_supported(ucc_coll_type_t c) { switch(c) { case UCC_COLL_TYPE_BARRIER: case UCC_COLL_TYPE_BCAST: case UCC_COLL_TYPE_FANIN: case UCC_COLL_TYPE_FANOUT: /* remove alltoall [v] from here once it starts supporting inplace */ case UCC_COLL_TYPE_ALLTOALL: case UCC_COLL_TYPE_ALLTOALLV: /**/ return 0; default: return 1; } } bool ucc_coll_triggered_supported(ucc_memory_type_t mt) { if (mt == UCC_MEMORY_TYPE_CUDA) { return true; } return false; } int ucc_coll_is_rooted(ucc_coll_type_t c) { switch(c) { case UCC_COLL_TYPE_ALLREDUCE: case UCC_COLL_TYPE_ALLGATHER: case UCC_COLL_TYPE_ALLGATHERV: case UCC_COLL_TYPE_ALLTOALL: case UCC_COLL_TYPE_ALLTOALLV: case UCC_COLL_TYPE_BARRIER: case UCC_COLL_TYPE_REDUCE_SCATTER: case UCC_COLL_TYPE_REDUCE_SCATTERV: return 0; default: return 1; } } bool ucc_coll_has_memtype(ucc_coll_type_t c) { switch(c) { case UCC_COLL_TYPE_BARRIER: case UCC_COLL_TYPE_FANIN: case UCC_COLL_TYPE_FANOUT: return false; default: return true; } } bool ucc_coll_has_msgrange(ucc_coll_type_t c) { switch(c) { case UCC_COLL_TYPE_BARRIER: case UCC_COLL_TYPE_FANIN: case UCC_COLL_TYPE_FANOUT: return false; default: return true; } } bool ucc_coll_has_datatype(ucc_coll_type_t c) { switch(c) { case UCC_COLL_TYPE_BARRIER: case UCC_COLL_TYPE_FANIN: case UCC_COLL_TYPE_FANOUT: return false; default: return true; } } bool ucc_coll_has_op(ucc_coll_type_t c) { switch(c) { case UCC_COLL_TYPE_ALLREDUCE: case UCC_COLL_TYPE_REDUCE: case UCC_COLL_TYPE_REDUCE_SCATTER: case UCC_COLL_TYPE_REDUCE_SCATTERV: return true; default: return false; } } bool ucc_coll_has_bits(ucc_coll_type_t c) { switch(c) { case UCC_COLL_TYPE_ALLTOALLV: return true; default: return false; } } void UccTestMpi::set_count_vsizes(std::vector &_counts_vsize) { counts_vsize = _counts_vsize; } void UccTestMpi::set_displ_vsizes(std::vector &_displs_vsize) { displs_vsize = _displs_vsize; } #if defined(HAVE_CUDA) || defined(HAVE_HIP) test_set_gpu_device_t test_gpu_set_device = 
TEST_SET_DEV_NONE; #endif #if defined(HAVE_CUDA) || defined(HAVE_HIP) void set_gpu_device(test_set_gpu_device_t set_device) { int local_rank = ucc_test_mpi_data.local_node_rank; int gpu_dev_count; int device_id; if (set_device == TEST_SET_DEV_NONE) { return; } #if defined(HAVE_CUDA) CUDA_CHECK(cudaGetDeviceCount(&gpu_dev_count)); #elif defined(HAVE_HIP) HIP_CHECK(hipGetDeviceCount(&gpu_dev_count)); #endif switch (set_device) { case TEST_SET_DEV_LRANK: if(local_rank >= gpu_dev_count) { std::cerr << "*** UCC TEST FAIL: " << "not enough GPU devices on the node to map processes.\n"; MPI_Abort(MPI_COMM_WORLD, -1); } device_id = local_rank; break; case TEST_SET_DEV_LRANK_ROUND: device_id = local_rank % gpu_dev_count; break; case TEST_SET_DEV_NONE: default: return; } #if defined(HAVE_CUDA) CUDA_CHECK(cudaSetDevice(device_id)); // Force CUDA context creation for the device // Without this, cuCtxGetCurrent() in TL/CUDA will return NULL CUDA_CHECK(cudaFree(0)); #elif defined(HAVE_HIP) HIP_CHECK(hipSetDevice(device_id)); // Force HIP context creation for the device HIP_CHECK(hipFree(0)); #endif } #endif std::vector UccTestMpi::exec_tests( std::vector> tcs, bool triggered, bool persistent) { int n_persistent = persistent ? 
UCC_TEST_N_PERSISTENT : 1; int world_rank, num_done, i; ucc_status_t status; MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); std::vector rst; for (i = 0; i < n_persistent; i++) { for (auto tc: tcs) { if (TEST_SKIP_NONE == tc->test_skip) { if (verbose && 0 == world_rank) { if (triggered) { std::cout << "Triggered "<str() << std::endl; } else { std::cout << tc->str() << std::endl; } } if (tc->args.flags & UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS) { MPI_Barrier(MPI_COMM_WORLD); } tc->run(triggered); } else { if (verbose && 0 == world_rank) { std::cout << "SKIPPED: " << skip_str(tc->test_skip) << ": " << tc->str() << " " << std::endl; } rst.push_back(std::make_tuple(tc->args.coll_type, UCC_ERR_LAST)); return rst; } } do { num_done = 0; for (auto tc: tcs) { tc->mpi_progress(); status = tc->test(); if (status < 0) { std::cerr << "error during coll test: " << ucc_status_string(status) << " ("<tc_progress_ctx(); } } while (num_done != tcs.size()); for (auto tc: tcs) { status = tc->check(); tc->set_input(i + 1); if (UCC_OK != status) { std::cerr << "FAILURE in: " << tc->str() << std::endl; } rst.push_back(std::make_tuple(tc->args.coll_type, status)); } } return rst; } void UccTestMpi::run_all_at_team(ucc_test_team_t &team, std::vector &rst) { TestCaseParams params; params.max_size = test_max_size; params.inplace = inplace; params.persistent = persistent; params.local_registration = local_registration; for (auto i = 0; i < iterations; i++) { for (auto &c : colls) { std::vector roots = {0}; std::vector test_memtypes = {UCC_MEMORY_TYPE_LAST}; std::vector test_msgsizes = {0}; std::vector test_dtypes = {(ucc_datatype_t)-1}; std::vector test_ops = {(ucc_reduction_op_t)-1}; std::vector test_counts_vsize = {TEST_FLAG_VSIZE_64BIT}; std::vector test_displ_vsize = {TEST_FLAG_VSIZE_64BIT}; void **onesided_bufs; if (inplace && !ucc_coll_inplace_supported(c)) { continue; } if (ucc_coll_is_rooted(c)) { roots = gen_roots(team); } if (ucc_coll_has_memtype(c)) { test_memtypes = mtypes; } if 
(ucc_coll_has_msgrange(c)) { test_msgsizes = msgsizes; } if (ucc_coll_has_datatype(c)) { test_dtypes = dtypes; } if (ucc_coll_has_op(c)) { test_ops = ops; } if (ucc_coll_has_bits(c)) { test_counts_vsize = counts_vsize; test_displ_vsize = displs_vsize; } for (auto r : roots) { for (auto mt: test_memtypes) { if (triggered && !ucc_coll_triggered_supported(mt)) { rst.push_back(std::make_tuple(c, UCC_ERR_NOT_IMPLEMENTED)); continue; } if ((c == UCC_COLL_TYPE_ALLTOALL || c == UCC_COLL_TYPE_ALLTOALLV) && team.ctx != ctx) { /* onesided alltoall */ if (mt != UCC_MEMORY_TYPE_HOST) { continue; } else { onesided_bufs = onesided_buffers; } } else { onesided_bufs = nullptr; } for (auto m: test_msgsizes) { for (auto dt: test_dtypes) { for (auto op: test_ops) { if (ucc_coll_args_is_reduction(c) && !ucc_coll_reduce_supported(op, dt)) { continue; } if (mt != UCC_MEMORY_TYPE_HOST && (dt == UCC_DT_FLOAT128 || dt == UCC_DT_FLOAT128_COMPLEX)) { continue; } for (auto count_bits: test_counts_vsize) { for (auto displ_bits: test_displ_vsize) { params.root = r; params.mt = mt; params.msgsize = m; params.dt = dt; params.op = op; params.count_bits = count_bits; params.displ_bits = displ_bits; params.buffers = onesided_bufs; auto tcs = TestCase::init(team, c, nt, params); auto res = exec_tests(tcs, triggered, persistent); rst.insert(rst.end(), res.begin(), res.end()); } } } } } } } } } } typedef struct ucc_test_thread { pthread_t thread; int id; UccTestMpi * test; std::vector rst; } ucc_test_thread_t; static void *thread_start(void *arg) { ucc_test_thread_t *t = (ucc_test_thread_t *)arg; #if defined(HAVE_CUDA) || defined(HAVE_HIP) set_gpu_device(test_gpu_set_device); #endif t->test->run_all_at_team(t->test->teams[t->id], t->rst); return 0; } void UccTestMpi::run_all(bool is_onesided) { if (UCC_THREAD_MULTIPLE == tm) { int n_threads = teams.size(); std::vector threads(n_threads); void * ret; for (int i = 0; i < n_threads; i++) { threads[i].id = i; threads[i].test = this; 
pthread_create(&threads[i].thread, NULL, thread_start, (void *)&threads[i]); } for (int i = 0; i < n_threads; i++) { pthread_join(threads[i].thread, &ret); results.insert(results.end(), threads[i].rst.begin(), threads[i].rst.end()); } } else { if (!is_onesided) { for (auto &t : teams) { run_all_at_team(t, results); } } else { for (auto &t : onesided_teams) { run_all_at_team(t, results); } } } } std::vector UccTestMpi::gen_roots(ucc_test_team_t &team) { int size; std::vector _roots; MPI_Comm_size(team.comm, &size); std::default_random_engine eng; eng.seed(123); std::uniform_int_distribution urd(0, size-1); switch(root_type) { case ROOT_SINGLE: _roots = std::vector({ucc_min(root_value, size-1)}); break; case ROOT_RANDOM: _roots.resize(root_value); for (unsigned i = 0; i < _roots.size(); i++) { _roots[i] = urd(eng); } break; case ROOT_ALL: _roots.resize(size); std::iota(_roots.begin(), _roots.end(), 0); break; default: assert(0); } return _roots; } openucx-ucc-ec0bc8a/test/mpi/test_reduce_scatter.cc0000664000175000017500000000743215133731560023016 0ustar alastairalastair/** * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include "test_mpi.h" #include "mpi_util.h" TestReduceScatter::TestReduceScatter(ucc_test_team_t &_team, TestCaseParams ¶ms) : TestCase(_team, UCC_COLL_TYPE_REDUCE_SCATTER, params) { size_t dt_size = ucc_dt_size(params.dt); size_t count = msgsize / dt_size; int rank, comm_size; MPI_Comm_rank(team.comm, &rank); MPI_Comm_size(team.comm, &comm_size); op = params.op; dt = params.dt; if (skip_reduce(test_max_size < msgsize, TEST_SKIP_MEM_LIMIT, team.comm) || skip_reduce((count < comm_size), TEST_SKIP_NOT_SUPPORTED, team.comm)) { return; } check_buf = ucc_malloc(msgsize, "check buf"); UCC_MALLOC_CHECK(check_buf); count = count - (count % comm_size); msgsize = count * dt_size; if (inplace) { args.dst.info.count = count; UCC_CHECK(ucc_mc_alloc(&rbuf_mc_header, msgsize, mem_type)); rbuf = rbuf_mc_header->addr; } else { args.dst.info.count = count / comm_size; UCC_CHECK(ucc_mc_alloc(&rbuf_mc_header, msgsize / comm_size, mem_type)); UCC_CHECK(ucc_mc_alloc(&sbuf_mc_header, msgsize, mem_type)); rbuf = rbuf_mc_header->addr; sbuf = sbuf_mc_header->addr; args.src.info.buffer = sbuf; args.src.info.count = count; args.src.info.datatype = dt; args.src.info.mem_type = mem_type; } args.op = op; args.dst.info.buffer = rbuf; args.dst.info.datatype = dt; args.dst.info.mem_type = mem_type; UCC_CHECK(set_input()); UCC_CHECK_SKIP(ucc_collective_init(&args, &req, team.team), test_skip); } ucc_status_t TestReduceScatter::set_input(int iter_persistent) { size_t dt_size = ucc_dt_size(dt); size_t count = msgsize / dt_size; void *buf; int rank; MPI_Comm_rank(team.comm, &rank); if (inplace) { buf = rbuf; } else { buf = sbuf; } init_buffer(buf, count, dt, mem_type, rank * (iter_persistent + 1)); UCC_CHECK(ucc_mc_memcpy(check_buf, buf, count * dt_size, UCC_MEMORY_TYPE_HOST, mem_type)); return UCC_OK; } ucc_status_t TestReduceScatter::check() { ucc_status_t status; int comm_rank, comm_size, completed; size_t block_size, block_count; MPI_Request req; MPI_Comm_rank(team.comm, &comm_rank); 
MPI_Comm_size(team.comm, &comm_size); block_size = msgsize / comm_size; block_count = block_size / ucc_dt_size(dt); MPI_Ireduce_scatter_block(MPI_IN_PLACE, check_buf, block_count, ucc_dt_to_mpi(dt), op == UCC_OP_AVG ? MPI_SUM : ucc_op_to_mpi(op), team.comm, &req); do { MPI_Test(&req, &completed, MPI_STATUS_IGNORE); ucc_context_progress(team.ctx); } while(!completed); if (op == UCC_OP_AVG) { status = divide_buffer(check_buf, team.team->size, block_count, dt); if (status != UCC_OK) { return status; } } if (inplace) { return compare_buffers(PTR_OFFSET(rbuf, comm_rank * block_size), check_buf, block_count, dt, mem_type); } return compare_buffers(rbuf, check_buf, block_count, dt, mem_type); } TestReduceScatter::~TestReduceScatter() {} std::string TestReduceScatter::str() { return std::string("tc=")+ucc_coll_type_str(args.coll_type) + " team=" + team_str(team.type) + " msgsize=" + std::to_string(msgsize) + " inplace=" + (inplace ? "1" : "0") + " persistent=" + (persistent ? "1" : "0") + " dt=" + ucc_datatype_str(dt) + " op=" + ucc_reduction_op_str(op); } openucx-ucc-ec0bc8a/test/mpi/test_reduce_scatterv.cc0000664000175000017500000001026115133731560023176 0ustar alastairalastair/** * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include "test_mpi.h" #include "mpi_util.h" TestReduceScatterv::TestReduceScatterv(ucc_test_team_t &_team, TestCaseParams ¶ms) : TestCase(_team, UCC_COLL_TYPE_REDUCE_SCATTERV, params) { size_t dt_size = ucc_dt_size(params.dt); size_t count = msgsize / dt_size; int rank, comm_size; counts = NULL; MPI_Comm_rank(team.comm, &rank); MPI_Comm_size(team.comm, &comm_size); op = params.op; dt = params.dt; if (skip_reduce(test_max_size < msgsize, TEST_SKIP_MEM_LIMIT, team.comm)) { return; } counts = (int *)ucc_malloc(comm_size * sizeof(uint32_t), "counts buf"); UCC_MALLOC_CHECK(counts); size_t left = count; size_t total = 0; for (int i = 0; i < comm_size; i++) { size_t c = 2 + i * 2; if (left < c) { c = left; } if (i == comm_size - 1) { counts[i] = left; } else { counts[i] = c; } if (left > 0) { left -= c; } total += counts[i]; } ucc_assert(total == count); check_buf = ucc_malloc(msgsize, "check buf"); UCC_MALLOC_CHECK(check_buf); if (inplace) { UCC_CHECK(ucc_mc_alloc(&rbuf_mc_header, msgsize, mem_type)); rbuf = rbuf_mc_header->addr; } else { UCC_CHECK(ucc_mc_alloc(&rbuf_mc_header, counts[rank] * dt_size, mem_type)); UCC_CHECK(ucc_mc_alloc(&sbuf_mc_header, msgsize, mem_type)); rbuf = rbuf_mc_header->addr; sbuf = sbuf_mc_header->addr; args.src.info.buffer = sbuf; args.src.info.count = count; args.src.info.datatype = dt; args.src.info.mem_type = mem_type; } args.op = op; args.dst.info_v.counts = (ucc_count_t *)counts; args.dst.info_v.buffer = rbuf; args.dst.info_v.datatype = dt; args.dst.info_v.mem_type = mem_type; UCC_CHECK(set_input()); UCC_CHECK_SKIP(ucc_collective_init(&args, &req, team.team), test_skip); } ucc_status_t TestReduceScatterv::set_input(int iter_persistent) { size_t dt_size = ucc_dt_size(dt); size_t count = msgsize / dt_size; void * buf; int rank; MPI_Comm_rank(team.comm, &rank); if (inplace) { buf = rbuf; } else { buf = sbuf; } init_buffer(buf, count, dt, mem_type, rank * (iter_persistent + 1)); UCC_CHECK(ucc_mc_memcpy(check_buf, buf, count * dt_size, 
UCC_MEMORY_TYPE_HOST, mem_type)); return UCC_OK; } ucc_status_t TestReduceScatterv::check() { ucc_status_t status; int comm_rank, comm_size, completed; MPI_Request req; size_t offset; MPI_Comm_rank(team.comm, &comm_rank); MPI_Comm_size(team.comm, &comm_size); MPI_Ireduce_scatter(MPI_IN_PLACE, check_buf, counts, ucc_dt_to_mpi(dt), op == UCC_OP_AVG ? MPI_SUM : ucc_op_to_mpi(op), team.comm, &req); do { MPI_Test(&req, &completed, MPI_STATUS_IGNORE); ucc_context_progress(team.ctx); } while (!completed); if (op == UCC_OP_AVG) { status = divide_buffer(check_buf, team.team->size, counts[comm_rank], dt); if (status != UCC_OK) { return status; } } if (inplace) { offset = 0; for (int i = 0; i < comm_rank; i++) { offset += counts[i]; } return compare_buffers(PTR_OFFSET(rbuf, offset * ucc_dt_size(dt)), check_buf, counts[comm_rank], dt, mem_type); } return compare_buffers(rbuf, check_buf, counts[comm_rank], dt, mem_type); } TestReduceScatterv::~TestReduceScatterv() { if (counts) { ucc_free(counts); } } std::string TestReduceScatterv::str() { return std::string("tc=") + ucc_coll_type_str(args.coll_type) + " team=" + team_str(team.type) + " msgsize=" + std::to_string(msgsize) + " inplace=" + (inplace ? "1" : "0") + " persistent=" + (persistent ? "1" : "0") + " dt=" + ucc_datatype_str(dt) + " op=" + ucc_reduction_op_str(op); } openucx-ucc-ec0bc8a/test/mpi/test_reduce.cc0000664000175000017500000000624315133731560021270 0ustar alastairalastair/** * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ #include "test_mpi.h" #include "mpi_util.h" TestReduce::TestReduce(ucc_test_team_t &_team, TestCaseParams ¶ms) : TestCase(_team, UCC_COLL_TYPE_REDUCE, params) { size_t dt_size = ucc_dt_size(params.dt); size_t count = msgsize/dt_size; int rank; MPI_Comm_rank(team.comm, &rank); dt = params.dt; op = params.op; root = params.root; if (skip_reduce(test_max_size < msgsize, TEST_SKIP_MEM_LIMIT, team.comm)) { return; } check_buf = ucc_malloc(msgsize, "check buf"); UCC_MALLOC_CHECK(check_buf); if (rank == root) { UCC_CHECK(ucc_mc_alloc(&rbuf_mc_header, msgsize, mem_type)); rbuf = rbuf_mc_header->addr; args.dst.info.buffer = rbuf; args.dst.info.count = count; args.dst.info.datatype = dt; args.dst.info.mem_type = mem_type; } if ((rank != root) || (!inplace)) { UCC_CHECK(ucc_mc_alloc(&sbuf_mc_header, msgsize, mem_type)); sbuf = sbuf_mc_header->addr; } args.op = op; args.src.info.buffer = sbuf; args.src.info.count = count; args.src.info.datatype = dt; args.src.info.mem_type = mem_type; args.root = root; UCC_CHECK(set_input()); UCC_CHECK_SKIP(ucc_collective_init(&args, &req, team.team), test_skip); } ucc_status_t TestReduce::set_input(int iter_persistent) { size_t dt_size = ucc_dt_size(dt); size_t count = msgsize / dt_size; int rank; void *buf; MPI_Comm_rank(team.comm, &rank); if (inplace && rank == root) { buf = rbuf; } else { buf = sbuf; } init_buffer(buf, count, dt, mem_type, rank * (iter_persistent + 1)); UCC_CHECK(ucc_mc_memcpy(check_buf, buf, count * dt_size, UCC_MEMORY_TYPE_HOST, mem_type)); return UCC_OK; } ucc_status_t TestReduce::check() { ucc_status_t status; size_t count = args.src.info.count; int rank, completed; MPI_Request req; MPI_Comm_rank(team.comm, &rank); MPI_Ireduce((root == rank) ? MPI_IN_PLACE : check_buf, check_buf, count, ucc_dt_to_mpi(dt), op == UCC_OP_AVG ? 
MPI_SUM : ucc_op_to_mpi(op), root, team.comm, &req); do { MPI_Test(&req, &completed, MPI_STATUS_IGNORE); ucc_context_progress(team.ctx); } while(!completed); if (rank == root && op == UCC_OP_AVG) { status = divide_buffer(check_buf, team.team->size, count, dt); if (status != UCC_OK) { return status; } } return (rank != root) ? UCC_OK : compare_buffers(rbuf, check_buf, count, dt, mem_type); } std::string TestReduce::str() { return std::string("tc=")+ucc_coll_type_str(args.coll_type) + " team=" + team_str(team.type) + " msgsize=" + std::to_string(msgsize) + " inplace=" + (inplace ? "1" : "0") + " persistent=" + (persistent ? "1" : "0") + " dt=" + ucc_datatype_str(dt) + " op=" + ucc_reduction_op_str(op) + " root=" + std::to_string(root); } openucx-ucc-ec0bc8a/test/mpi/buffer.cc0000664000175000017500000001655215133731560020237 0ustar alastairalastair/* * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #include #include #include BEGIN_C_DECLS #include "components/mc/ucc_mc.h" END_C_DECLS #include "test_mpi.h" #include #define TEST_MPI_FP_EPSILON 1e-5 template void init_buffer_host(void *buf, size_t count, int _value) { T *ptr = (T *)buf; for (size_t i = 0; i < count; i++) { ptr[i] = (T)((_value + i + 1) % 128); } } void init_buffer(void *_buf, size_t count, ucc_datatype_t dt, ucc_memory_type_t mt, int value, int offset) { void *buf = NULL; if (mt == UCC_MEMORY_TYPE_CUDA || mt == UCC_MEMORY_TYPE_ROCM) { buf = ucc_malloc(count * ucc_dt_size(dt), "buf"); UCC_MALLOC_CHECK(buf); } else if (mt == UCC_MEMORY_TYPE_HOST || mt == UCC_MEMORY_TYPE_CUDA_MANAGED) { buf = _buf; } else { std::cerr << "Unsupported mt\n"; MPI_Abort(MPI_COMM_WORLD, -1); } value += offset; switch(dt) { case UCC_DT_INT8: init_buffer_host(buf, count, value); break; case UCC_DT_UINT8: init_buffer_host(buf, count, value); break; case UCC_DT_INT16: init_buffer_host(buf, count, value); break; case UCC_DT_UINT16: init_buffer_host(buf, count, value); 
break; case UCC_DT_INT32: init_buffer_host(buf, count, value); break; case UCC_DT_UINT32: init_buffer_host(buf, count, value); break; case UCC_DT_INT64: init_buffer_host(buf, count, value); break; case UCC_DT_UINT64: init_buffer_host(buf, count, value); break; case UCC_DT_FLOAT32: init_buffer_host(buf, count, value); break; case UCC_DT_FLOAT64: init_buffer_host(buf, count, value); break; case UCC_DT_FLOAT128: init_buffer_host(buf, count, value); break; case UCC_DT_FLOAT32_COMPLEX: init_buffer_host(buf, count, value); break; case UCC_DT_FLOAT64_COMPLEX: init_buffer_host(buf, count, value); break; case UCC_DT_FLOAT128_COMPLEX: init_buffer_host(buf, count, value); break; default: std::cerr << "Unsupported dt\n"; MPI_Abort(MPI_COMM_WORLD, -1); break; } if (UCC_MEMORY_TYPE_HOST != mt && UCC_MEMORY_TYPE_CUDA_MANAGED != mt) { UCC_CHECK(ucc_mc_memcpy(_buf, buf, count * ucc_dt_size(dt), mt, UCC_MEMORY_TYPE_HOST)); ucc_free(buf); } } template static inline bool is_equal(T a, T b, T epsilon) { return fabs(a - b) <= ((fabs(a) < fabs(b) ? 
fabs(b) : fabs(a)) * epsilon); } static inline bool is_equal_complex(float _Complex a, float _Complex b, double epsilon) { return (is_equal(crealf(a), crealf(b), (float)epsilon) && is_equal(cimagf(a), cimagf(b), (float)epsilon)); } static inline bool is_equal_complex(double _Complex a, double _Complex b, double epsilon) { return (is_equal(creal(a), creal(b), epsilon) && is_equal(cimag(a), cimag(b), epsilon)); } static inline bool is_equal_complex(long double _Complex a, long double _Complex b, double epsilon) { return (is_equal(creall(a), creall(b), (long double)epsilon) && is_equal(cimagl(a), cimagl(b), (long double)epsilon)); } template ucc_status_t compare_buffers_fp(T *b1, T *b2, size_t count) { T epsilon = (T)TEST_MPI_FP_EPSILON; for (size_t i = 0; i < count; i++) { if (!is_equal(b1[i], b2[i], epsilon)) { return UCC_ERR_NO_MESSAGE; } } return UCC_OK; } template ucc_status_t compare_buffers_complex(T *b1, T *b2, size_t count) { double epsilon = (double)TEST_MPI_FP_EPSILON; for (size_t i = 0; i < count; i++) { if (!is_equal_complex(b1[i], b2[i], epsilon)) { return UCC_ERR_NO_MESSAGE; } } return UCC_OK; } ucc_status_t compare_buffers(void *_rst, void *expected, size_t count, ucc_datatype_t dt, ucc_memory_type_t mt) { ucc_status_t status = UCC_ERR_NO_MESSAGE; ucc_mc_buffer_header_t *rst_mc_header; void *rst = NULL; if (UCC_MEMORY_TYPE_HOST == mt || mt == UCC_MEMORY_TYPE_CUDA_MANAGED) { rst = _rst; } else if (UCC_MEMORY_TYPE_CUDA == mt || UCC_MEMORY_TYPE_ROCM == mt) { UCC_ALLOC_COPY_BUF(rst_mc_header, UCC_MEMORY_TYPE_HOST, _rst, mt, count * ucc_dt_size(dt)); rst = rst_mc_header->addr; } else { std::cerr << "Unsupported mt\n"; MPI_Abort(MPI_COMM_WORLD, -1); } if (dt == UCC_DT_FLOAT32) { status = compare_buffers_fp((float*)rst, (float*)expected, count); } else if (dt == UCC_DT_FLOAT64) { status = compare_buffers_fp((double*)rst, (double*)expected, count); } else if (dt == UCC_DT_FLOAT128) { status = compare_buffers_fp( (long double *)rst, (long double *)expected, 
count); } else if (dt == UCC_DT_FLOAT32_COMPLEX) { status = compare_buffers_complex( (float _Complex *)rst, (float _Complex *)expected, count); } else if (dt == UCC_DT_FLOAT64_COMPLEX) { status = compare_buffers_complex( (double _Complex *)rst, (double _Complex *)expected, count); } else if (dt == UCC_DT_FLOAT128_COMPLEX) { status = compare_buffers_complex( (long double _Complex *)rst, (long double _Complex *)expected, count); } else { status = memcmp(rst, expected, count*ucc_dt_size(dt)) ? UCC_ERR_NO_MESSAGE : UCC_OK; } if (UCC_MEMORY_TYPE_HOST != mt && UCC_MEMORY_TYPE_CUDA_MANAGED != mt) { UCC_CHECK(ucc_mc_free(rst_mc_header)); } return status; } template void divide_buffers_fp(T *b, size_t divider, size_t count) { for (size_t i = 0; i < count; i++) { b[i] = b[i] / (double)divider; } } ucc_status_t divide_buffer(void *expected, size_t divider, size_t count, ucc_datatype_t dt) { if (dt == UCC_DT_FLOAT32) { divide_buffers_fp((float *)expected, divider, count); } else if (dt == UCC_DT_FLOAT64) { divide_buffers_fp((double *)expected, divider, count); } else if (dt == UCC_DT_FLOAT128) { divide_buffers_fp((long double *)expected, divider, count); } else if (dt == UCC_DT_FLOAT32_COMPLEX) { divide_buffers_fp((float _Complex *)expected, divider, count); } else if (dt == UCC_DT_FLOAT64_COMPLEX) { divide_buffers_fp((double _Complex *)expected, divider, count); } else if (dt == UCC_DT_FLOAT128_COMPLEX) { divide_buffers_fp( (long double _Complex *)expected, divider, count); } else { std::cerr << "Unsupported dt for avg\n"; return UCC_ERR_NO_MESSAGE; } return UCC_OK; } openucx-ucc-ec0bc8a/test/gtest/0000775000175000017500000000000015133731560017007 5ustar alastairalastairopenucx-ucc-ec0bc8a/test/gtest/tl/0000775000175000017500000000000015133731560017426 5ustar alastairalastairopenucx-ucc-ec0bc8a/test/gtest/tl/mlx5/0000775000175000017500000000000015133731560020313 5ustar 
alastairalastairopenucx-ucc-ec0bc8a/test/gtest/tl/mlx5/test_tl_mlx5_wqe.cc0000664000175000017500000003636615133731560024137 0ustar alastairalastair/** * Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * See file LICENSE for terms. */ #include "test_tl_mlx5_wqe.h" #include "utils/arch/cpu.h" #include #include // Rounds up a given integer to the closet power of two static int roundUpToPowerOfTwo(int a) { int b = 1; while (b < a) { b *= 2; } return b; } // Returns whether a matrix fits ConnectX-7 HW limitations static bool doesMatrixFit(int nrows, int ncols, int elem_size) { // Compute the matrix size as is done by ConnectX-7 HW int matrix_size = nrows * std::max(128, roundUpToPowerOfTwo(ncols) * roundUpToPowerOfTwo(std::max(elem_size, 8))); return matrix_size <= pow(2, 13) //= 8Kb && elem_size <= 128 && nrows <= 64 && ncols <= 64; } UCC_TEST_P(test_tl_mlx5_transpose, transposeWqe) { int nrows = std::get<0>(GetParam()); int ncols = std::get<1>(GetParam()); int elem_size = sizeof(DT) * std::get<2>(GetParam()); int completions_num = 0; DT src[nrows][ncols * elem_size], dst[ncols][nrows * elem_size]; struct ibv_wc wcs[1]; struct ibv_mr *src_mr, *dst_mr; int i, j, k; if (!is_cx7_vendor_id()) { GTEST_SKIP() << "The test needs CX7"; } // Skips if do not match HW limitations if (!doesMatrixFit(nrows, ncols, elem_size)) { GTEST_SKIP(); } for (i = 0; i < nrows; i++) { for (j = 0; j < ncols; j++) { for (k = 0; k < elem_size; k++) { src[i][j * elem_size + k] = (i * nrows * elem_size + j * elem_size + k) % 256; dst[j][i * elem_size + k] = 0; } } } src_mr = ibv_reg_mr(pd, src, nrows * ncols * elem_size, IBV_ACCESS_LOCAL_WRITE); GTEST_ASSERT_NE(nullptr, src_mr); dst_mr = ibv_reg_mr(pd, dst, nrows * ncols * elem_size, IBV_ACCESS_LOCAL_WRITE); GTEST_ASSERT_NE(nullptr, dst_mr); ibv_wr_start(qp.qp_ex); post_transpose(qp.qp, src_mr->lkey, dst_mr->rkey, (uintptr_t)src, (uintptr_t)dst, elem_size, ncols, nrows, IBV_SEND_SIGNALED); 
GTEST_ASSERT_EQ(ibv_wr_complete(qp.qp_ex), 0); while (!completions_num) { completions_num = ibv_poll_cq(cq, 1, wcs); } GTEST_ASSERT_EQ(completions_num, 1); GTEST_ASSERT_EQ(wcs[0].status, IBV_WC_SUCCESS); for (i = 0; i < nrows; i++) { for (j = 0; j < ncols; j++) { for (k = 0; k < elem_size; k++) { GTEST_ASSERT_EQ(src[i][j * elem_size + k], dst[j][i * elem_size + k]); } } } GTEST_ASSERT_EQ(ibv_dereg_mr(src_mr), UCC_OK); GTEST_ASSERT_EQ(ibv_dereg_mr(dst_mr), UCC_OK); } INSTANTIATE_TEST_SUITE_P(, test_tl_mlx5_transpose, ::testing::Combine(::testing::Values(1, 7, 32, 64), ::testing::Values(1, 5, 32, 64), ::testing::Values(1, 3, 8, 128))); UCC_TEST_P(test_tl_mlx5_rdma_write, RdmaWriteWqe) { struct ibv_sge sg; struct ibv_send_wr wr; bufsize = GetParam(); buffers_init(); CHECK_TEST_STATUS(); memset(&sg, 0, sizeof(sg)); sg.addr = (uintptr_t)src; sg.length = bufsize; sg.lkey = src_mr->lkey; memset(&wr, 0, sizeof(wr)); wr.wr_id = 0; wr.sg_list = &sg; wr.num_sge = 1; wr.opcode = IBV_WR_RDMA_WRITE; wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_FENCE; wr.next = NULL; wr.wr.rdma.remote_addr = (uintptr_t)dst; wr.wr.rdma.rkey = dst_mr->rkey; // This work request is posted with wr_id = 0 GTEST_ASSERT_EQ(ibv_post_send(qp.qp, &wr, NULL), 0); wait_for_completion(); CHECK_TEST_STATUS(); validate_buffers(); } UCC_TEST_P(test_tl_mlx5_rdma_write, CustomRdmaWriteWqe) { bufsize = GetParam(); buffers_init(); CHECK_TEST_STATUS(); ibv_wr_start(qp.qp_ex); post_rdma_write(qp.qp, qpn, nullptr, (uintptr_t)src, bufsize, src_mr->lkey, (uintptr_t)dst, dst_mr->rkey, IBV_SEND_SIGNALED | IBV_SEND_FENCE, 0); GTEST_ASSERT_EQ(ibv_wr_complete(qp.qp_ex), 0); wait_for_completion(); CHECK_TEST_STATUS(); validate_buffers(); } INSTANTIATE_TEST_SUITE_P(, test_tl_mlx5_rdma_write, ::testing::Values(1, 31, 128, 1024)); UCC_TEST_P(test_tl_mlx5_dm, MemcpyToDeviceMemory) { bufsize = GetParam(); buffers_init(); CHECK_TEST_STATUS(); if (!dm_ptr) { return; } if (bufsize % 4 != 0) { GTEST_SKIP() << "for memcpy involving 
device memory, buffer size " << "must be a multiple of 4"; } GTEST_ASSERT_EQ(ibv_memcpy_to_dm(dm_ptr, 0, (void *)src, bufsize), 0); GTEST_ASSERT_EQ(ibv_memcpy_from_dm((void *)dst, dm_ptr, 0, bufsize), 0); validate_buffers(); } UCC_TEST_P(test_tl_mlx5_dm, RdmaToDeviceMemory) { struct ibv_sge sg; struct ibv_send_wr wr; bufsize = GetParam(); buffers_init(); CHECK_TEST_STATUS(); if (!dm_ptr) { return; } // RDMA write from host source to device memory memset(&sg, 0, sizeof(sg)); sg.addr = (uintptr_t)src; sg.length = bufsize; sg.lkey = src_mr->lkey; memset(&wr, 0, sizeof(wr)); wr.wr_id = 0; wr.sg_list = &sg; wr.num_sge = 1; wr.opcode = IBV_WR_RDMA_WRITE; wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_FENCE; wr.next = NULL; wr.wr.rdma.remote_addr = (uintptr_t)0; wr.wr.rdma.rkey = dm_mr->rkey; GTEST_ASSERT_EQ(ibv_post_send(qp.qp, &wr, NULL), 0); wait_for_completion(); CHECK_TEST_STATUS(); // RDMA write from device memory to host destination memset(&sg, 0, sizeof(sg)); sg.addr = (uintptr_t)0; sg.length = bufsize; sg.lkey = dm_mr->lkey; memset(&wr, 0, sizeof(wr)); wr.wr_id = 0; wr.sg_list = &sg; wr.num_sge = 1; wr.opcode = IBV_WR_RDMA_WRITE; wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_FENCE; wr.next = NULL; wr.wr.rdma.remote_addr = (uintptr_t)dst; wr.wr.rdma.rkey = dst_mr->rkey; GTEST_ASSERT_EQ(ibv_post_send(qp.qp, &wr, NULL), 0); wait_for_completion(); CHECK_TEST_STATUS(); validate_buffers(); } UCC_TEST_P(test_tl_mlx5_dm, CustomRdmaToDeviceMemory) { bufsize = GetParam(); buffers_init(); CHECK_TEST_STATUS(); if (!dm_ptr) { return; } // RDMA write from host source to device memory ibv_wr_start(qp.qp_ex); post_rdma_write(qp.qp, qpn, nullptr, (uintptr_t)src, bufsize, src_mr->lkey, (uintptr_t)0, dm_mr->rkey, IBV_SEND_SIGNALED | IBV_SEND_FENCE, 0); GTEST_ASSERT_EQ(ibv_wr_complete(qp.qp_ex), 0); wait_for_completion(); CHECK_TEST_STATUS(); // RDMA write from device memory to host destination ibv_wr_start(qp.qp_ex); post_rdma_write(qp.qp, qpn, nullptr, (uintptr_t)0, bufsize, 
dm_mr->lkey, (uintptr_t)dst, dst_mr->rkey, IBV_SEND_SIGNALED | IBV_SEND_FENCE, 0); GTEST_ASSERT_EQ(ibv_wr_complete(qp.qp_ex), 0); wait_for_completion(); CHECK_TEST_STATUS(); validate_buffers(); } INSTANTIATE_TEST_SUITE_P(, test_tl_mlx5_dm, ::testing::Values(1, 12, 31, 32, 8192, 8193, 32768, 65536)); UCC_TEST_P(test_tl_mlx5_wait_on_data, waitOnDataWqe) { uint64_t wait_on_value = std::get<0>(GetParam()); uint64_t init_ctrl_value = std::get<1>(GetParam()); uint64_t buffer[3]; volatile uint64_t *ctrl, *src, *dst; int completions_num; struct ibv_wc wcs[1]; struct ibv_mr * buffer_mr; struct ibv_sge sg; struct ibv_send_wr wr; if (!is_cx7_vendor_id()) { GTEST_SKIP() << "The test needs CX7"; } memset(buffer, 0, 3 * sizeof(uint64_t)); buffer_mr = ibv_reg_mr(pd, buffer, 3 * sizeof(uint64_t), IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); GTEST_ASSERT_NE(nullptr, buffer_mr); ctrl = &buffer[0]; src = &buffer[1]; dst = &buffer[2]; *ctrl = init_ctrl_value; memset(&sg, 0, sizeof(sg)); sg.addr = (uintptr_t)src; sg.length = sizeof(uint64_t); sg.lkey = buffer_mr->lkey; memset(&wr, 0, sizeof(wr)); wr.wr_id = 0; wr.sg_list = &sg; wr.num_sge = 1; wr.opcode = IBV_WR_RDMA_WRITE; wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_FENCE; wr.next = NULL; wr.wr.rdma.remote_addr = (uintptr_t)dst; wr.wr.rdma.rkey = buffer_mr->rkey; // This work request is posted with wr_id = 1 GTEST_ASSERT_EQ(post_wait_on_data(qp.qp, wait_on_value, buffer_mr->lkey, (uintptr_t)ctrl, nullptr), UCC_OK); // This work request is posted with wr_id = 0 GTEST_ASSERT_EQ(ibv_post_send(qp.qp, &wr, NULL), 0); sleep(1); *src = 0xdeadbeef; //memory barrier ucc_memory_cpu_fence(); *ctrl = wait_on_value; while (1) { completions_num = ibv_poll_cq(cq, 1, wcs); if (completions_num != 0) { GTEST_ASSERT_EQ(completions_num, 1); GTEST_ASSERT_EQ(wcs[0].status, IBV_WC_SUCCESS); if (wcs[0].wr_id == 0) { break; } } } //validation GTEST_ASSERT_EQ(*dst, *src); GTEST_ASSERT_EQ(ibv_dereg_mr(buffer_mr), UCC_OK); } INSTANTIATE_TEST_SUITE_P( , 
test_tl_mlx5_wait_on_data, ::testing::Combine(::testing::Values(1, 1024, 1025, 0xF0F30F00, 0xFFFFFFFF), ::testing::Values(0, 0xF0F30F01))); UCC_TEST_P(test_tl_mlx5_umr_wqe, umrWqe) { uint16_t nbr_srcs = std::get<0>(GetParam()); uint32_t bytes_count = std::get<1>(GetParam()); uint32_t repeat_count = std::get<2>(GetParam()); uint32_t bytes_skip = std::get<3>(GetParam()); int src_size = (bytes_count + bytes_skip) * repeat_count; int dst_size = bytes_count * nbr_srcs * repeat_count; int send_mem_access_flags = 0; void *umr_entries_buf = nullptr; int recv_mem_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE; DT src[nbr_srcs][src_size], dst[dst_size]; struct ibv_mr * src_mr[nbr_srcs], *dst_mr, *umr_entries_mr; struct mlx5dv_mkey * umr_mkey; struct mlx5dv_mkey_init_attr umr_mkey_init_attr; size_t umr_buf_size; struct mlx5dv_mr_interleaved mkey_entries[nbr_srcs]; struct ibv_wc wcs[1]; int i, j, completions_num, src_index, repetition_count, offset; // Setup src and dst buffers for (i = 0; i < nbr_srcs; i++) { for (j = 0; j < src_size; j++) { src[i][j] = (i * (src_size + 7) + j) % 255; } src_mr[i] = ibv_reg_mr(pd, src[i], src_size, send_mem_access_flags); GTEST_ASSERT_NE(nullptr, src_mr[i]); } memset(dst, 0, dst_size); dst_mr = ibv_reg_mr(pd, dst, dst_size, recv_mem_access_flags); GTEST_ASSERT_NE(nullptr, dst_mr); // UMR umr_buf_size = ucc_align_up( sizeof(struct mlx5_wqe_umr_repeat_ent_seg) * (nbr_srcs + 1), 64); GTEST_ASSERT_EQ(ucc_posix_memalign(&umr_entries_buf, 2048, umr_buf_size), 0); umr_entries_mr = ibv_reg_mr(pd, umr_entries_buf, umr_buf_size, send_mem_access_flags); GTEST_ASSERT_NE(nullptr, umr_entries_mr); memset(&umr_mkey_init_attr, 0, sizeof(umr_mkey_init_attr)); umr_mkey_init_attr.pd = pd; umr_mkey_init_attr.create_flags = MLX5DV_MKEY_INIT_ATTR_FLAGS_INDIRECT; umr_mkey_init_attr.max_entries = nbr_srcs + 1; //+1 for the "repeat block" umr_mkey = mlx5dv_create_mkey(&umr_mkey_init_attr); GTEST_ASSERT_NE(nullptr, umr_mkey); 
GTEST_ASSERT_GE(umr_mkey_init_attr.max_entries, nbr_srcs + 1); for (i = 0; i < nbr_srcs; i++) { mkey_entries[i].addr = (uintptr_t)src[i]; mkey_entries[i].bytes_count = bytes_count; mkey_entries[i].bytes_skip = bytes_skip; mkey_entries[i].lkey = src_mr[i]->lkey; } post_umr(umr_qp.qp, umr_mkey, send_mem_access_flags, repeat_count, nbr_srcs, mkey_entries, (uint32_t)umr_entries_mr->lkey, umr_entries_buf); completions_num = 0; while (!completions_num) { completions_num = ibv_poll_cq(cq, 1, wcs); } GTEST_ASSERT_EQ(completions_num, 1); GTEST_ASSERT_EQ(wcs[0].status, IBV_WC_SUCCESS); GTEST_ASSERT_EQ(wcs[0].wr_id, 0); // RDMA Write ibv_wr_start(qp.qp_ex); post_rdma_write(qp.qp, qpn, nullptr, (uintptr_t)0, dst_size, umr_mkey->lkey, (uintptr_t)dst, dst_mr->rkey, IBV_SEND_SIGNALED | IBV_SEND_FENCE, 0); GTEST_ASSERT_EQ(ibv_wr_complete(qp.qp_ex), 0); completions_num = 0; while (!completions_num) { completions_num = ibv_poll_cq(cq, 1, wcs); } GTEST_ASSERT_EQ(completions_num, 1); GTEST_ASSERT_EQ(wcs[0].status, IBV_WC_SUCCESS); GTEST_ASSERT_EQ(wcs[0].wr_id, 0); // Verification for (i = 0; i < dst_size; i++) { src_index = (i / bytes_count) % nbr_srcs; repetition_count = (i / bytes_count) / nbr_srcs; offset = repetition_count * (bytes_count + bytes_skip) + (i % bytes_count); GTEST_ASSERT_EQ(dst[i], src[src_index][offset]); } // Tear down GTEST_ASSERT_EQ(0, mlx5dv_destroy_mkey(umr_mkey)); GTEST_ASSERT_EQ(ibv_dereg_mr(umr_entries_mr), UCC_OK); ucc_free(umr_entries_buf); for (i = 0; i < nbr_srcs; i++) { GTEST_ASSERT_EQ(ibv_dereg_mr(src_mr[i]), UCC_OK); } GTEST_ASSERT_EQ(ibv_dereg_mr(dst_mr), UCC_OK); } INSTANTIATE_TEST_SUITE_P(, test_tl_mlx5_umr_wqe, ::testing::Combine(::testing::Values(1, 129, 1024), ::testing::Values(5, 64), ::testing::Values(1, 3, 16), ::testing::Values(0, 7))); UCC_TEST_P(test_tl_mlx5_dm_alloc_reg, DeviceMemoryAllocation) { size_t buf_size = std::get<0>(GetParam()); size_t buf_num = std::get<1>(GetParam()); struct ibv_dm *ptr = nullptr; struct ibv_mr *mr = nullptr; 
ucc_status_t status; status = dm_alloc_reg(ctx, pd, 0, buf_size, &buf_num, &ptr, &mr, &lib); if (status == UCC_ERR_NO_MEMORY || status == UCC_ERR_NO_RESOURCE) { GTEST_SKIP() << "cannot allocate " << buf_num << " chunk(s) of size " << buf_size << " in device memory"; } GTEST_ASSERT_EQ(status, UCC_OK); ibv_dereg_mr(mr); ibv_free_dm(ptr); } INSTANTIATE_TEST_SUITE_P( , test_tl_mlx5_dm_alloc_reg, ::testing::Combine(::testing::Values(1, 2, 1024, 8191, 8192, 8193, 32768, 65536, 262144), ::testing::Values(UCC_ULUNITS_AUTO, 1, 3, 8))); openucx-ucc-ec0bc8a/test/gtest/tl/mlx5/test_tl_mlx5.h0000664000175000017500000000276615133731560023122 0ustar alastairalastair/** * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * See file LICENSE for terms. */ #ifndef TEST_TL_MLX5_H #define TEST_TL_MLX5_H #include #include "common/test_ucc.h" #include "components/tl/mlx5/tl_mlx5.h" #include "components/tl/mlx5/tl_mlx5_dm.h" #include "components/tl/mlx5/tl_mlx5_ib.h" #define CHECK_TEST_STATUS() \ if (Test::HasFatalFailure() || Test::IsSkipped()) { \ return; \ } typedef ucc_status_t (*ucc_tl_mlx5_create_ibv_ctx_fn_t)( char **ib_devname, struct ibv_context **ctx, ucc_base_lib_t *lib); typedef int (*ucc_tl_mlx5_get_active_port_fn_t)(struct ibv_context *ctx); class test_tl_mlx5 : public ucc::test { protected: void *tl_mlx5_so_handle; public: ucc_base_lib_t lib; ucc_tl_mlx5_create_ibv_ctx_fn_t create_ibv_ctx; ucc_tl_mlx5_get_active_port_fn_t get_active_port; struct ibv_port_attr port_attr; struct ibv_context * ctx; struct ibv_pd * pd; struct ibv_cq * cq; int port; test_tl_mlx5(); virtual ~test_tl_mlx5(); virtual void SetUp() override; // Check for Mellanox/NVIDIA vendor ID (0x02c9) and CX7 (MT4129) vendor_part_id bool is_cx7_vendor_id() const { struct ibv_device_attr device_attr; ibv_query_device(ctx, &device_attr); return device_attr.vendor_id == 0x02c9 && device_attr.vendor_part_id == 4129; } }; #endif 
openucx-ucc-ec0bc8a/test/gtest/tl/mlx5/test_tl_mlx5_wqe.h0000664000175000017500000001641415133731560023771 0ustar alastairalastair/** * Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * See file LICENSE for terms. */ #include "test_tl_mlx5.h" #include "test_tl_mlx5_qps.h" #include "components/tl/mlx5/tl_mlx5_wqe.h" #define DT uint8_t typedef ucc_status_t (*ucc_tl_mlx5_post_rdma_fn_t)( struct ibv_qp *qp, uint32_t qpn, struct ibv_ah *ah, uintptr_t src_mkey_addr, size_t len, uint32_t src_mr_lkey, uintptr_t dst_addr, uint32_t dst_mr_key, int send_flags, uint64_t wr_id); typedef ucc_status_t (*ucc_tl_mlx5_post_transpose_fn_t)( struct ibv_qp *qp, uint32_t src_mr_lkey, uint32_t dst_mr_key, uintptr_t src_mkey_addr, uintptr_t dst_addr, uint32_t element_size, uint16_t ncols, uint16_t nrows, int send_flags); typedef ucc_status_t (*ucc_tl_mlx5_post_wait_on_data_fn_t)(struct ibv_qp *qp, uint64_t value, uint32_t lkey, uintptr_t addr, void *task_ptr); typedef ucc_status_t (*ucc_tl_mlx5_post_umr_fn_t)( struct ibv_qp *qp, struct mlx5dv_mkey *dv_mkey, uint32_t access_flags, uint32_t repeat_count, uint16_t num_entries, struct mlx5dv_mr_interleaved *data, uint32_t ptr_mkey, void *ptr_address); typedef ucc_status_t (*ucc_tl_mlx5_dm_alloc_reg_fn_t)( struct ibv_context *ib_ctx, struct ibv_pd *pd, int dm_host, size_t buf_size, size_t *buf_num_p, struct ibv_dm **ptr, struct ibv_mr **mr, ucc_base_lib_t *lib); // (msgsize) using RdmaWriteParams = int; // (buf_size) using DmParams = int; // (nrows, ncols, element_size) using TransposeParams = std::tuple; // (nbr_srcs, bytes_count, repeat_count, bytes_skip) using UmrParams = std::tuple; // (buffer_size, buffer_nums) using AllocDmParams = std::tuple; // (wait_on_value, init_ctrl_value) using WaitOnDataParams = std::tuple; class test_tl_mlx5_wqe : public test_tl_mlx5_rc_qp { public: ucc_tl_mlx5_post_rdma_fn_t post_rdma_write; ucc_tl_mlx5_post_transpose_fn_t post_transpose; ucc_tl_mlx5_post_wait_on_data_fn_t 
post_wait_on_data; ucc_tl_mlx5_post_umr_fn_t post_umr; void SetUp() { test_tl_mlx5_rc_qp::SetUp(); CHECK_TEST_STATUS(); post_rdma_write = (ucc_tl_mlx5_post_rdma_fn_t)dlsym( tl_mlx5_so_handle, "ucc_tl_mlx5_post_rdma"); ASSERT_EQ(nullptr, dlerror()); post_transpose = (ucc_tl_mlx5_post_transpose_fn_t)dlsym( tl_mlx5_so_handle, "ucc_tl_mlx5_post_transpose"); ASSERT_EQ(nullptr, dlerror()); post_wait_on_data = (ucc_tl_mlx5_post_wait_on_data_fn_t)dlsym( tl_mlx5_so_handle, "ucc_tl_mlx5_post_wait_on_data"); ASSERT_EQ(nullptr, dlerror()); post_umr = (ucc_tl_mlx5_post_umr_fn_t)dlsym(tl_mlx5_so_handle, "ucc_tl_mlx5_post_umr"); ASSERT_EQ(nullptr, dlerror()); create_qp(); CHECK_TEST_STATUS(); connect_qp_loopback(); CHECK_TEST_STATUS(); create_umr_qp(); } }; class test_tl_mlx5_transpose : public test_tl_mlx5_wqe, public ::testing::WithParamInterface { }; class test_tl_mlx5_wait_on_data : public test_tl_mlx5_wqe, public ::testing::WithParamInterface { }; class test_tl_mlx5_umr_wqe : public test_tl_mlx5_wqe, public ::testing::WithParamInterface { }; class test_tl_mlx5_rdma_write : public test_tl_mlx5_wqe, public ::testing::WithParamInterface { public: DT *src = nullptr; DT *dst = nullptr; struct ibv_mr *src_mr = nullptr; struct ibv_mr *dst_mr = nullptr; int bufsize; void buffers_init() { src = (DT *)malloc(bufsize); GTEST_ASSERT_NE(src, nullptr); dst = (DT *)malloc(bufsize); GTEST_ASSERT_NE(dst, nullptr); for (int i = 0; i < bufsize; i++) { src[i] = i % 256; dst[i] = 0; } src_mr = ibv_reg_mr(pd, src, bufsize, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); GTEST_ASSERT_NE(nullptr, src_mr); dst_mr = ibv_reg_mr(pd, dst, bufsize, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); GTEST_ASSERT_NE(nullptr, dst_mr); } void wait_for_completion() { int completions_num = 0; struct ibv_wc wcs[1]; while (!completions_num) { completions_num = ibv_poll_cq(cq, 1, wcs); } GTEST_ASSERT_EQ(completions_num, 1); GTEST_ASSERT_EQ(wcs[0].status, IBV_WC_SUCCESS); } void validate_buffers() { for (int i 
= 0; i < bufsize; i++) { GTEST_ASSERT_EQ(src[i], dst[i]); } } void TearDown() { if (src_mr != nullptr) { GTEST_ASSERT_EQ(ibv_dereg_mr(src_mr), UCC_OK); } if (dst_mr != nullptr) { GTEST_ASSERT_EQ(ibv_dereg_mr(dst_mr), UCC_OK); } if (src != nullptr) { free(src); } if (dst != nullptr) { free(dst); } } }; class test_tl_mlx5_dm : public test_tl_mlx5_rdma_write { public: struct ibv_dm *dm_ptr = nullptr; struct ibv_mr *dm_mr = nullptr; struct ibv_alloc_dm_attr dm_attr; void buffers_init() { test_tl_mlx5_rdma_write::buffers_init(); CHECK_TEST_STATUS(); struct ibv_device_attr_ex attr; memset(&attr, 0, sizeof(attr)); GTEST_ASSERT_EQ(ibv_query_device_ex(ctx, NULL, &attr), 0); if (attr.max_dm_size < bufsize) { if (!attr.max_dm_size) { GTEST_SKIP() << "device doesn't support dm allocation"; } else { GTEST_SKIP() << "the requested buffer size (=" << bufsize << ") for device memory should be less than " << attr.max_dm_size; } } memset(&dm_attr, 0, sizeof(dm_attr)); dm_attr.length = bufsize; dm_ptr = ibv_alloc_dm(ctx, &dm_attr); if (!dm_ptr) { GTEST_SKIP() << "device cannot allocate a buffer of size " << bufsize; } dm_mr = ibv_reg_dm_mr(pd, dm_ptr, 0, dm_attr.length, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_ZERO_BASED); GTEST_ASSERT_NE(dm_mr, nullptr); } void TearDown() { if (dm_mr) { ibv_dereg_mr(dm_mr); } if (dm_ptr) { ibv_free_dm(dm_ptr); } test_tl_mlx5_rdma_write::TearDown(); } }; class test_tl_mlx5_dm_alloc_reg : public test_tl_mlx5_wqe, public ::testing::WithParamInterface { public: ucc_tl_mlx5_dm_alloc_reg_fn_t dm_alloc_reg; void SetUp() { test_tl_mlx5_wqe::SetUp(); CHECK_TEST_STATUS(); dm_alloc_reg = (ucc_tl_mlx5_dm_alloc_reg_fn_t)dlsym( tl_mlx5_so_handle, "ucc_tl_mlx5_dm_alloc_reg"); ASSERT_EQ(nullptr, dlerror()); } }; openucx-ucc-ec0bc8a/test/gtest/tl/mlx5/test_tl_mlx5_qps.h0000664000175000017500000001256115133731560023777 0ustar alastairalastair/** * Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* See file LICENSE for terms. */ #include "test_tl_mlx5.h" class test_tl_mlx5_qp : public test_tl_mlx5 { public: ucc_tl_mlx5_ib_qp_conf_t qp_conf; virtual void SetUp() { test_tl_mlx5::SetUp(); CHECK_TEST_STATUS(); qp_conf.qp_rnr_retry = 7; qp_conf.qp_rnr_timer = 20; qp_conf.qp_retry_cnt = 7; qp_conf.qp_timeout = 18; qp_conf.qp_max_atomic = 1; qp_conf.qp_sl = 1; } }; typedef ucc_status_t (*ucc_tl_mlx5_create_umr_qp_fn_t)( struct ibv_context *ctx, struct ibv_pd *pd, struct ibv_cq *cq, int ib_port, struct ibv_qp **qp, ucc_tl_mlx5_ib_qp_conf_t *qp_conf, ucc_base_lib_t *lib); typedef ucc_status_t (*ucc_tl_mlx5_qp_connect_fn_t)( struct ibv_qp *qp, uint32_t qp_num, uint16_t lid, int port, ucc_tl_mlx5_ib_qp_conf_t *qp_conf, ucc_base_lib_t *lib); typedef ucc_status_t (*ucc_tl_mlx5_create_rc_qp_fn_t)( struct ibv_context *ctx, struct ibv_pd *pd, struct ibv_cq *cq, int tx_depth, ucc_tl_mlx5_qp_t *qp, uint32_t *qpn, ucc_base_lib_t *lib); class test_tl_mlx5_rc_qp : public test_tl_mlx5_qp { public: ucc_tl_mlx5_qp_t qp = {}; ucc_tl_mlx5_qp_t umr_qp = {}; uint32_t qpn; ucc_tl_mlx5_ib_qp_conf_t umr_qp_conf; int tx_depth; ucc_tl_mlx5_create_rc_qp_fn_t create_rc_qp; ucc_tl_mlx5_qp_connect_fn_t qp_connect; ucc_tl_mlx5_create_umr_qp_fn_t create_rc_umr_qp; virtual void SetUp() { qp.qp = NULL; umr_qp.qp = NULL; tx_depth = 4; test_tl_mlx5_qp::SetUp(); CHECK_TEST_STATUS(); umr_qp_conf = qp_conf; create_rc_qp = (ucc_tl_mlx5_create_rc_qp_fn_t)dlsym( tl_mlx5_so_handle, "ucc_tl_mlx5_create_rc_qp"); ASSERT_EQ(nullptr, dlerror()); qp_connect = (ucc_tl_mlx5_qp_connect_fn_t)dlsym( tl_mlx5_so_handle, "ucc_tl_mlx5_qp_connect"); ASSERT_EQ(nullptr, dlerror()); create_rc_umr_qp = (ucc_tl_mlx5_create_umr_qp_fn_t)dlsym( tl_mlx5_so_handle, "ucc_tl_mlx5_create_umr_qp"); ASSERT_EQ(nullptr, dlerror()); } ~test_tl_mlx5_rc_qp() { if (qp.qp) { ibv_destroy_qp(qp.qp); } if (umr_qp.qp) { ibv_destroy_qp(umr_qp.qp); } } void create_qp() { GTEST_ASSERT_EQ(create_rc_qp(ctx, pd, cq, tx_depth, &qp, &qpn, &lib), UCC_OK); 
} void create_umr_qp() { GTEST_ASSERT_EQ( create_rc_umr_qp(ctx, pd, cq, port, &umr_qp.qp, &umr_qp_conf, &lib), UCC_OK); } void connect_qp_loopback() { GTEST_ASSERT_EQ( qp_connect(qp.qp, qpn, port_attr.lid, port, &qp_conf, &lib), UCC_OK); }; }; typedef ucc_status_t (*ucc_tl_mlx5_init_dct_fn_t)( struct ibv_pd *pd, struct ibv_context *ctx, struct ibv_cq *cq, struct ibv_srq *srq, uint8_t port_num, struct ibv_qp **dct_qp, uint32_t *qpn, ucc_tl_mlx5_ib_qp_conf_t *qp_conf, ucc_base_lib_t *lib); typedef ucc_status_t (*ucc_tl_mlx5_init_dci_fn_t)( ucc_tl_mlx5_dci_t *dci, struct ibv_pd *pd, struct ibv_context *ctx, struct ibv_cq *cq, uint8_t port_num, int tx_depth, ucc_tl_mlx5_ib_qp_conf_t *qp_conf, ucc_base_lib_t *lib); typedef ucc_status_t (*ucc_tl_mlx5_create_ah_fn_t)(struct ibv_pd * pd, uint16_t lid, uint8_t port_num, struct ibv_ah **ah_ptr, ucc_base_lib_t *lib); class test_tl_mlx5_dc : public test_tl_mlx5_qp { public: struct ibv_qp *dct_qp = nullptr; struct ibv_ah *ah = nullptr; struct ibv_srq *srq = nullptr; ucc_tl_mlx5_dci_t dci = {}; uint32_t dct_qpn; ucc_tl_mlx5_init_dct_fn_t init_dct; ucc_tl_mlx5_init_dci_fn_t init_dci; ucc_tl_mlx5_create_ah_fn_t create_ah; virtual void SetUp() { struct ibv_srq_init_attr srq_attr; test_tl_mlx5_qp::SetUp(); CHECK_TEST_STATUS(); init_dct = (ucc_tl_mlx5_init_dct_fn_t)dlsym(tl_mlx5_so_handle, "ucc_tl_mlx5_init_dct"); ASSERT_EQ(nullptr, dlerror()); init_dci = (ucc_tl_mlx5_init_dci_fn_t)dlsym(tl_mlx5_so_handle, "ucc_tl_mlx5_init_dci"); ASSERT_EQ(nullptr, dlerror()); create_ah = (ucc_tl_mlx5_create_ah_fn_t)dlsym(tl_mlx5_so_handle, "ucc_tl_mlx5_create_ah"); ASSERT_EQ(nullptr, dlerror()); memset(&srq_attr, 0, sizeof(struct ibv_srq_init_attr)); srq_attr.attr.max_wr = 1; srq_attr.attr.max_sge = 1; srq = ibv_create_srq(pd, &srq_attr); EXPECT_NE(nullptr, srq); dct_qp = NULL; } ~test_tl_mlx5_dc() { if (ah) { ibv_destroy_ah(ah); } if (dct_qp) { ibv_destroy_qp(dct_qp); } if (dci.dci_qp) { ibv_destroy_qp(dci.dci_qp); } if (srq) { 
ibv_destroy_srq(srq); } } }; openucx-ucc-ec0bc8a/test/gtest/tl/mlx5/test_tl_mlx5_qps.cc0000664000175000017500000000165015133731560024132 0ustar alastairalastair/** * Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * See file LICENSE for terms. */ #include "test_tl_mlx5_qps.h" UCC_TEST_F(test_tl_mlx5_rc_qp, create) { create_qp(); } UCC_TEST_F(test_tl_mlx5_rc_qp, create_umr) { create_umr_qp(); } UCC_TEST_F(test_tl_mlx5_rc_qp, connect_loopback) { create_qp(); CHECK_TEST_STATUS(); connect_qp_loopback(); } UCC_TEST_F(test_tl_mlx5_dc, init_dct) { ucc_status_t status; status = init_dct(pd, ctx, cq, srq, port, &dct_qp, &dct_qpn, &qp_conf, &lib); GTEST_ASSERT_EQ(UCC_OK, status); } UCC_TEST_F(test_tl_mlx5_dc, init_dci) { ucc_status_t status; status = init_dci(&dci, pd, ctx, cq, port, 4, &qp_conf, &lib); GTEST_ASSERT_EQ(UCC_OK, status); } UCC_TEST_F(test_tl_mlx5_dc, create_ah) { ucc_status_t status; status = create_ah(pd, port_attr.lid, port, &ah, &lib); GTEST_ASSERT_EQ(UCC_OK, status); } openucx-ucc-ec0bc8a/test/gtest/tl/mlx5/test_tl_mlx5.cc0000664000175000017500000000342715133731560023253 0ustar alastairalastair/** * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * See file LICENSE for terms. 
*/ #include "test_tl_mlx5.h" test_tl_mlx5::test_tl_mlx5() { tl_mlx5_so_handle = NULL; ctx = NULL; pd = NULL; cq = NULL; } void test_tl_mlx5::SetUp() { char * devname = NULL; ucc_status_t status; ASSERT_EQ(UCC_OK, ucc_constructor()); ucc_strncpy_safe(lib.log_component.name, "GTEST_MLX5", sizeof(lib.log_component.name)); lib.log_component.log_level = UCC_LOG_LEVEL_ERROR; std::string path = std::string(ucc_global_config.component_path) + "/libucc_tl_mlx5.so"; tl_mlx5_so_handle = dlopen(path.c_str(), RTLD_NOW); if (!tl_mlx5_so_handle) { GTEST_SKIP() << "cannot open ucc_tl_mlx5 library" ; } create_ibv_ctx = (ucc_tl_mlx5_create_ibv_ctx_fn_t)dlsym( tl_mlx5_so_handle, "ucc_tl_mlx5_create_ibv_ctx"); ASSERT_EQ(nullptr, dlerror()); get_active_port = (ucc_tl_mlx5_get_active_port_fn_t)dlsym( tl_mlx5_so_handle, "ucc_tl_mlx5_get_active_port"); ASSERT_EQ(nullptr, dlerror()); status = create_ibv_ctx(&devname, &ctx, &lib); if (UCC_OK != status) { GTEST_SKIP() << "no ib devices"; } port = get_active_port(ctx); ASSERT_GE(port, 0); ASSERT_EQ(ibv_query_port(ctx, port, &port_attr), 0); pd = ibv_alloc_pd(ctx); ASSERT_NE(nullptr, pd); cq = ibv_create_cq(ctx, 8, NULL, NULL, 0); ASSERT_NE(nullptr, cq); } test_tl_mlx5::~test_tl_mlx5() { if (cq) { ibv_destroy_cq(cq); } if (pd) { ibv_dealloc_pd(pd); } if (ctx) { ibv_close_device(ctx); } if (tl_mlx5_so_handle) { dlclose(tl_mlx5_so_handle); } } openucx-ucc-ec0bc8a/test/gtest/tl/tl_test.cc0000664000175000017500000000034115133731560021411 0ustar alastairalastair/** * Copyright (C) Huawei Technologies Co., Ltd. 2020. All rights reserved. * See file LICENSE for terms. */ #include class test_tl : public ucc::test {}; UCC_TEST_F(test_tl, dummy_test) { } openucx-ucc-ec0bc8a/test/gtest/utils/0000775000175000017500000000000015133731560020147 5ustar alastairalastairopenucx-ucc-ec0bc8a/test/gtest/utils/test_lock_free_queue.cc0000664000175000017500000000622615133731560024660 0ustar alastairalastair /** * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. * See file LICENSE for terms. */ extern "C" { #include "utils/ucc_lock_free_queue.h" #include "utils/ucc_atomic.h" #include "utils/ucc_malloc.h" #include #include } #include #include #define NUM_ITERS 5000000 typedef struct ucc_test_queue { ucc_lf_queue_t lf_queue; int64_t test_sum; uint32_t elems_num; uint32_t active_producers_threads; uint32_t memory_err; } ucc_test_queue_t; void *producer_thread(void *arg) { ucc_test_queue_t *test = (ucc_test_queue_t *)arg; for (int j = 0; j < NUM_ITERS; j++) { ucc_lf_queue_elem_t *elem = (ucc_lf_queue_elem_t *)ucc_malloc(sizeof(ucc_lf_queue_elem_t)); ucc_lf_queue_init_elem(elem); if (!elem) { ucc_atomic_add32(&test->memory_err, 1); goto exit; } ucc_lf_queue_enqueue(&test->lf_queue, elem); ucc_atomic_add64((uint64_t *)&test->test_sum, (uint64_t)elem); ucc_atomic_add32(&test->elems_num, 1); } exit: ucc_atomic_sub32(&test->active_producers_threads,1); return 0; } void *consumer_thread(void *arg) { ucc_test_queue_t *test = (ucc_test_queue_t *)arg; while(test->active_producers_threads || test->elems_num){ ucc_lf_queue_elem_t *elem = ucc_lf_queue_dequeue(&test->lf_queue, 1); if (elem) { ucc_atomic_sub64((uint64_t *)&test->test_sum, (uint64_t)elem); ucc_atomic_sub32(&test->elems_num, 1); ucc_free(elem); } } return 0; } class test_lf_queue : public ucc::test { public: ucc_test_queue_t test; int i; std::vector producers_threads; std::vector consumers_threads; int lf_test(int num_of_producers, int num_of_consumers); }; int test_lf_queue::lf_test(int num_of_producers, int num_of_consumers){ producers_threads.resize(num_of_producers); consumers_threads.resize(num_of_consumers); memset(&test, 0, sizeof(ucc_test_queue_t)); ucc_lf_queue_init(&test.lf_queue); for (i = 0; i < num_of_producers; i++) { ucc_atomic_add32(&test.active_producers_threads, 1); pthread_create(&producers_threads[i], NULL, &producer_thread, (void *)&test); } for (i = 0; i < num_of_consumers; i++) { pthread_create(&consumers_threads[i], NULL, 
&consumer_thread, (void *)&test); } for (i = 0; i < num_of_producers; i++) { pthread_join(producers_threads[i], NULL); } for (i = 0; i < num_of_consumers; i++) { pthread_join(consumers_threads[i], NULL); } ucc_lf_queue_destroy(&test.lf_queue); if (test.memory_err) { return 1; } if (test.test_sum) { return 1; } return 0; } UCC_TEST_F(test_lf_queue, oneProducerOneConsumer) { EXPECT_EQ(lf_test(1, 1), 0); } UCC_TEST_F(test_lf_queue, oneProducerManyConsumers) { EXPECT_EQ(lf_test(1, 7), 0); } UCC_TEST_F(test_lf_queue, manyProducersManyConsumers) { EXPECT_EQ(lf_test(7, 7), 0); } openucx-ucc-ec0bc8a/test/gtest/utils/test_ep_map.cc0000664000175000017500000001063715133731560022765 0ustar alastairalastair/** * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * See file LICENSE for terms. */ extern "C" { #include "utils/ucc_coll_utils.h" } #include #include #include class EpMap { public: ucc_ep_map_t map; ucc_rank_t * array; EpMap(){}; EpMap(ucc_ep_map_t _map) { array = NULL; map = _map; }; EpMap(uint64_t full, bool reverse = false) { array = NULL; if (reverse) { map = ucc_ep_map_create_reverse(full); } else { map.type = UCC_EP_MAP_FULL; map.ep_num = full; } }; EpMap(uint64_t start, int64_t stride, uint64_t num, uint64_t full) { map.type = ((num == full) && (stride == 1)) ? 
UCC_EP_MAP_FULL : UCC_EP_MAP_STRIDED; map.ep_num = num; map.strided.start = start; map.strided.stride = stride; array = NULL; }; EpMap(const std::vector &ranks, uint64_t full, int need_free = 0) { array = (ucc_rank_t*)malloc(sizeof(ucc_rank_t) * ranks.size()); memcpy(array, ranks.data(), ranks.size() * sizeof(ucc_rank_t)); map = ucc_ep_map_from_array(&array, ranks.size(), full, need_free); } ~EpMap() { if (array) { free(array); } } friend bool operator==(const EpMap &lhs, const EpMap &rhs) { if ((lhs.map.type != rhs.map.type) || (lhs.map.ep_num != rhs.map.ep_num)) { return false; } switch(lhs.map.type) { case UCC_EP_MAP_FULL: return true; case UCC_EP_MAP_STRIDED: return (lhs.map.strided.start == rhs.map.strided.start) && (lhs.map.strided.stride == rhs.map.strided.stride); default: break; } return false; } }; class test_ep_map : public ucc::test {}; UCC_TEST_F(test_ep_map, from_array) { // Full contiguous map EXPECT_EQ(EpMap({1,2,3,4,5,6,7,8,9,10}, 10), EpMap(10)); // Strided contiguous map EXPECT_EQ(EpMap({1,31,61,91,121}, 150), EpMap(1, 30, 5, 150)); // Strided negative EXPECT_EQ(EpMap({100,90,80,70,60}, 150), EpMap(100, -10, 5, 150)); } UCC_TEST_F(test_ep_map, from_array_free) { /* strided pattern not found - array is not released */ EXPECT_NE((void*)NULL, EpMap({1, 5, 6, 8, 11}, 10, 1).array); /* strided pattern found - array is released */ EXPECT_EQ((void*)NULL, EpMap({2, 4, 6, 8, 10}, 10, 1).array); /* FULL pattern found - array is released */ EXPECT_EQ((void*)NULL, EpMap({1, 2, 3, 4, 5}, 5, 1).array); } UCC_TEST_F(test_ep_map, reverse) { const int size = 10; ucc_ep_map_t map = ucc_ep_map_create_reverse(size); for (int i = 0; i < size; i++) { EXPECT_EQ(size - 1 - i, ucc_ep_map_eval(map, i)); } } UCC_TEST_F(test_ep_map, nested) { auto map1 = EpMap(100); //full map, size 100 auto map2 = EpMap(0, 2, 50, 100); // submap even only auto map3 = EpMap(1, 2, 25, 50); // submap odd only from 50 ucc_ep_map_t nested1, nested2; EXPECT_EQ(UCC_OK, 
ucc_ep_map_create_nested(&map1.map, &map2.map, &nested1)); EXPECT_EQ(50, nested1.ep_num); for (int i = 0; i < nested1.ep_num; i++) { EXPECT_EQ(0 + i * 2, ucc_ep_map_eval(nested1, i)); } EXPECT_EQ(UCC_OK, ucc_ep_map_create_nested(&nested1, &map3.map, &nested2)); EXPECT_EQ(25, nested2.ep_num); for (int i = 0; i < nested2.ep_num; i++) { EXPECT_EQ(2 + i * 4, ucc_ep_map_eval(nested2, i)); } ucc_ep_map_destroy_nested(&nested1); ucc_ep_map_destroy_nested(&nested2); } class test_ep_map_inv : public test_ep_map { public: void check_inv(EpMap map) { ucc_ep_map_t inv; EXPECT_EQ(UCC_OK, ucc_ep_map_create_inverse(map.map, &inv, 0)); for (int i = 0; i < map.map.ep_num; i++) { EXPECT_EQ(i, ucc_ep_map_eval(inv, ucc_ep_map_eval(map.map, i))); } ucc_ep_map_destroy(&inv); }; }; UCC_TEST_F(test_ep_map_inv, contig) { /* reverse of FULL */ check_inv(EpMap(10)); /* reverse of INVERSE */ check_inv(EpMap(10, true)); } UCC_TEST_F(test_ep_map_inv, strided) { /* stride positive */ check_inv(EpMap(1, 30, 5, 150)); /* stride negative */ check_inv(EpMap(100, -10, 5, 150)); } UCC_TEST_F(test_ep_map_inv, random) { check_inv(EpMap({4, 0, 1, 2, 3}, 5)); } openucx-ucc-ec0bc8a/test/gtest/utils/test_math.cc0000664000175000017500000000233415133731560022450 0ustar alastairalastair/** * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * See file LICENSE for terms. */ extern "C" { #include "utils/ucc_math.h" } #include using floatParams = float; class test_floats_cast : public ucc::test, public ::testing::WithParamInterface { }; UCC_TEST_P(test_floats_cast, start_with_float) { float p = GetParam(); uint16_t bfloat16Val; float32tobfloat16(p, &bfloat16Val); // All values in test_floats_cast use only first 16 bits of the 32. rest are 0. 
EXPECT_EQ(p, bfloat16tofloat32(&bfloat16Val)); } INSTANTIATE_TEST_CASE_P(, test_floats_cast, ::testing::Values(-4.26941244675e+18, -6.95441104568e+13, 2.015625)); using bfloat16Params = uint16_t; class test_bfloats16_cast : public ucc::test, public ::testing::WithParamInterface { }; UCC_TEST_P(test_bfloats16_cast, start_with_bfloat16) { uint16_t p = GetParam(); uint16_t res; float32tobfloat16(bfloat16tofloat32(&p), &res); EXPECT_EQ(p, res); } INSTANTIATE_TEST_CASE_P(, test_bfloats16_cast, ::testing::Values(31000, 400, 17, 13569, 0)); openucx-ucc-ec0bc8a/test/gtest/utils/test_parser.cc0000664000175000017500000000460615133731560023017 0ustar alastairalastair/** * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * See file LICENSE for terms. */ extern "C" { #include "utils/ucc_parser.h" #include "utils/ucc_datastruct.h" } #include #include class test_parse_mrange : public ucc::test { public: ucc_mrange_uint_t *p; test_parse_mrange() { p = (ucc_mrange_uint_t *) ucc_malloc(sizeof(ucc_mrange_uint_t)); } ~test_parse_mrange() { ucc_free(p); } }; UCC_TEST_F(test_parse_mrange, check_valid) { std::string str = "0-4K:host:8,auto"; size_t msgsize1 = 1024, msgsize2 = 8192; EXPECT_EQ(1, ucc_config_sscanf_uint_ranged(str.c_str(), p, NULL)); EXPECT_EQ(8, ucc_mrange_uint_get(p, msgsize1, UCC_MEMORY_TYPE_HOST)); EXPECT_EQ(UCC_UUNITS_AUTO, ucc_mrange_uint_get(p, msgsize2, UCC_MEMORY_TYPE_HOST)); ucc_mrange_uint_destroy(p); } UCC_TEST_F(test_parse_mrange, check_invalid) { std::string str = "0-4K:host:8:8"; EXPECT_EQ(0, ucc_config_sscanf_uint_ranged(str.c_str(), p, NULL)); ucc_mrange_uint_destroy(p); str = "0-4K:host:a"; EXPECT_EQ(0, ucc_config_sscanf_uint_ranged(str.c_str(), p, NULL)); ucc_mrange_uint_destroy(p); str = "0-4K:gpu:8"; EXPECT_EQ(0, ucc_config_sscanf_uint_ranged(str.c_str(), p, NULL)); ucc_mrange_uint_destroy(p); str = "0-f:host:8"; EXPECT_EQ(0, ucc_config_sscanf_uint_ranged(str.c_str(), p, NULL)); ucc_mrange_uint_destroy(p); } 
UCC_TEST_F(test_parse_mrange, check_range_multiple) { std::string str = "0-4K:host:8,4k-inf:host:10,0-4k:cuda:7,0-4k:cuda_managed:6,auto"; size_t msgsize1 = 1024, msgsize2 = 8192; EXPECT_EQ(1, ucc_config_sscanf_uint_ranged(str.c_str(), p, NULL)); EXPECT_EQ(8, ucc_mrange_uint_get(p, msgsize1, UCC_MEMORY_TYPE_HOST)); EXPECT_EQ(10, ucc_mrange_uint_get(p, msgsize2, UCC_MEMORY_TYPE_HOST)); EXPECT_EQ(7, ucc_mrange_uint_get(p, msgsize1, UCC_MEMORY_TYPE_CUDA)); EXPECT_EQ(UCC_UUNITS_AUTO, ucc_mrange_uint_get(p, msgsize2, UCC_MEMORY_TYPE_CUDA)); EXPECT_EQ(6, ucc_mrange_uint_get(p, msgsize1, UCC_MEMORY_TYPE_CUDA_MANAGED)); EXPECT_EQ(UCC_UUNITS_AUTO, ucc_mrange_uint_get(p, msgsize2, UCC_MEMORY_TYPE_CUDA_MANAGED)); ucc_mrange_uint_destroy(p); } openucx-ucc-ec0bc8a/test/gtest/utils/test_cfg_file.cc0000664000175000017500000000633015133731560023255 0ustar alastairalastair/** * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ extern "C" { #include "utils/ucc_parser.h" #include "core/ucc_global_opts.h" } #include typedef struct ucc_gtest_config_base { int foo; } ucc_gtest_config_base_t; typedef struct ucc_gtest_config { ucc_gtest_config_base_t super; int bar; int boo; } ucc_gtest_config_t; static ucc_config_field_t ucc_gtest_config_table_base[] = { {"FOO", "1", "gtest base config variable", ucc_offsetof(ucc_gtest_config_base_t, foo), UCC_CONFIG_TYPE_INT}, {NULL} }; static ucc_config_field_t ucc_gtest_config_table[] = { {"", "", NULL, ucc_offsetof(ucc_gtest_config_t, super), UCC_CONFIG_TYPE_TABLE(ucc_gtest_config_table_base)}, {"BAR", "1", "gtest config variable", ucc_offsetof(ucc_gtest_config_t, bar), UCC_CONFIG_TYPE_INT}, {"BOO", "1", "gtest config variable", ucc_offsetof(ucc_gtest_config_t, boo), UCC_CONFIG_TYPE_INT}, {NULL} }; UCC_CONFIG_DECLARE_TABLE(ucc_gtest_config_table, "GTEST TABLE", "CFG_", ucc_gtest_config_t); class test_cfg_file : public ucc::test { public: ucc_file_config_t *file_cfg; ucc_gtest_config_t cfg; 
std::string test_dir; test_cfg_file() { file_cfg = NULL; test_dir = std::string(GTEST_UCC_TOP_SRCDIR) + "/test/gtest/utils/"; } ~test_cfg_file() { if (file_cfg) { ucc_release_file_config(file_cfg); } ucc_config_parser_release_opts(&cfg, ucc_gtest_config_table); } void init_cfg() { ucc_status_t status; std::swap(ucc_global_config.file_cfg, file_cfg); status = ucc_config_parser_fill_opts( &cfg, UCC_CONFIG_GET_TABLE(ucc_gtest_config_table), "GTEST_UCC_", 1); std::swap(ucc_global_config.file_cfg, file_cfg); EXPECT_EQ(UCC_OK, status); }; }; UCC_TEST_F(test_cfg_file, parse_existing) { std::string filename = test_dir + "ucc_test.conf"; EXPECT_EQ(UCC_OK, ucc_parse_file_config(filename.c_str(), &file_cfg)); } UCC_TEST_F(test_cfg_file, parse_non_existing) { std::string filename = test_dir + "ucc_test_nonexisting.conf"; EXPECT_EQ(UCC_ERR_NOT_FOUND, ucc_parse_file_config(filename.c_str(), &file_cfg)); } /* Checks options are applied from cfg file */ UCC_TEST_F(test_cfg_file, opts_applied) { std::string filename = test_dir + "ucc_test.conf"; EXPECT_EQ(UCC_OK, ucc_parse_file_config(filename.c_str(), &file_cfg)); init_cfg(); EXPECT_EQ(10, cfg.super.foo); EXPECT_EQ(20, cfg.bar); EXPECT_EQ(1, cfg.boo); } /* Checks that options set via env var have preference over cfg file */ UCC_TEST_F(test_cfg_file, env_preference) { std::string filename = test_dir + "ucc_test.conf"; setenv("GTEST_UCC_CFG_BAR", "123", 1); EXPECT_EQ(UCC_OK, ucc_parse_file_config(filename.c_str(), &file_cfg)); init_cfg(); unsetenv("GTEST_UCC_CFG_BAR"); EXPECT_EQ(10, cfg.super.foo); /* Expected value is 123 from env rather than 20 from file */ EXPECT_EQ(123, cfg.bar); EXPECT_EQ(1, cfg.boo); } openucx-ucc-ec0bc8a/test/gtest/utils/test_string.cc0000664000175000017500000000543415133731560023031 0ustar alastairalastair/** * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * See file LICENSE for terms. 
*/ extern "C" { #include "utils/ucc_string.h" #include "utils/ucc_malloc.h" } #include #include #include using Param = std::vector; class test_string : public ucc::test, public ::testing::WithParamInterface { public: std::string str; void make_str(Param &p); }; void test_string::make_str(Param &p) { unsigned i; str = ""; for (i = 1; i < p.size() - 1; i++) { str += p[i] + p[0]; } str += p[i]; } UCC_TEST_P(test_string, split) { auto p = GetParam(); const char *delim = p[0].c_str(); make_str(p); char **split = ucc_str_split(str.c_str(), delim); EXPECT_NE(nullptr, split); EXPECT_EQ(ucc_str_split_count(split), p.size() - 1); for (auto i = 0; i < ucc_str_split_count(split); i++) { EXPECT_EQ(p[i + 1], split[i]); } ucc_str_split_free(split); } INSTANTIATE_TEST_CASE_P( , test_string, ::testing::Values(std::vector({",", "aaa", "bbb", "ccc"}), std::vector({":", "a", "b", "c", "d", "e"}), std::vector({" ", "a", "bb"}), std::vector({"...", "aaaaaa", "b", "ccccc"}))); UCC_TEST_F(test_string, split_no_delim) { std::string delim = ":"; std::string s = "string with no delimiter"; char ** split = ucc_str_split(s.c_str(), delim.c_str()); EXPECT_NE(nullptr, split); EXPECT_EQ(ucc_str_split_count(split), 1); EXPECT_EQ(s, split[0]); ucc_str_split_free(split); } UCC_TEST_F(test_string, find_last) { const char* str1 = "aaa bbb aaa ccc aaa ddd"; const char* str2 = "/tmp/build/ucc_tl_ucp/install/lib/ucc_tl_ucp.so"; const char *s; s = ucc_strstr_last(str1, "aaa"); EXPECT_EQ(16, (ptrdiff_t)s - (ptrdiff_t)str1); s = ucc_strstr_last(str2, "ucc_tl_ucp"); EXPECT_EQ(strlen(str2) - strlen("ucc_tl_ucp.so"), (ptrdiff_t)s - (ptrdiff_t)str2); s = ucc_strstr_last(str1, "fff"); EXPECT_EQ(NULL, s); s = ucc_strstr_last(str2, "/tmp"); EXPECT_EQ(str2, s); } UCC_TEST_F(test_string, concat) { char *rst; EXPECT_EQ(UCC_OK, ucc_str_concat("aaa", "bbb", &rst)); EXPECT_EQ("aaabbb", std::string(rst)); ucc_free(rst); EXPECT_EQ(UCC_OK, ucc_str_concat("aaabbbccc", "d", &rst)); EXPECT_EQ("aaabbbcccd", std::string(rst)); 
ucc_free(rst); EXPECT_EQ(UCC_OK, ucc_str_concat("aaa", "", &rst)); EXPECT_EQ("aaa", std::string(rst)); ucc_free(rst); EXPECT_EQ(UCC_OK, ucc_str_concat("", "aaa", &rst)); EXPECT_EQ("aaa", std::string(rst)); ucc_free(rst); } openucx-ucc-ec0bc8a/test/gtest/utils/ucc_test.conf0000664000175000017500000000032515133731560022627 0ustar alastairalastair# Some random comment UCC_FOO = 10 ## Tests setting global variable GTEST_UCC_CFG_BAR = 20 ;; test setting var with ENV_PREFIX = GTEST #variable with undefined ENV_PREFIX - will be ignored GGGTEST_UCC_BOO = 30 openucx-ucc-ec0bc8a/test/gtest/common/0000775000175000017500000000000015133731560020277 5ustar alastairalastairopenucx-ucc-ec0bc8a/test/gtest/common/gtest-all.cc0000664000175000017500000152007715133731560022516 0ustar alastairalastair// Copyright 2008, Google Inc. // All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // // Google C++ Testing and Mocking Framework (Google Test) // // Sometimes it's desirable to build Google Test by compiling a single file. // This file serves this purpose. // This line ensures that gtest.h can be compiled on its own, even // when it's fused. #include "gtest.h" // The following lines pull in the real gtest *.cc files. // Copyright 2005, Google Inc. // All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // // The Google C++ Testing and Mocking Framework (Google Test) // Copyright 2007, Google Inc. // All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // // Utilities for testing Google Test itself and code that uses Google Test // (e.g. frameworks built on top of Google Test). // GOOGLETEST_CM0004 DO NOT DELETE #ifndef GTEST_INCLUDE_GTEST_GTEST_SPI_H_ #define GTEST_INCLUDE_GTEST_GTEST_SPI_H_ GTEST_DISABLE_MSC_WARNINGS_PUSH_(4251 \ /* class A needs to have dll-interface to be used by clients of class B */) namespace testing { // This helper class can be used to mock out Google Test failure reporting // so that we can test Google Test or code that builds on Google Test. // // An object of this class appends a TestPartResult object to the // TestPartResultArray object given in the constructor whenever a Google Test // failure is reported. It can either intercept only failures that are // generated in the same thread that created this object or it can intercept // all generated failures. The scope of this mock object can be controlled with // the second argument to the two arguments constructor. class GTEST_API_ ScopedFakeTestPartResultReporter : public TestPartResultReporterInterface { public: // The two possible mocking modes of this object. enum InterceptMode { INTERCEPT_ONLY_CURRENT_THREAD, // Intercepts only thread local failures. INTERCEPT_ALL_THREADS // Intercepts all failures. }; // The c'tor sets this object as the test part result reporter used // by Google Test. The 'result' parameter specifies where to report the // results. 
This reporter will only catch failures generated in the current // thread. DEPRECATED explicit ScopedFakeTestPartResultReporter(TestPartResultArray* result); // Same as above, but you can choose the interception scope of this object. ScopedFakeTestPartResultReporter(InterceptMode intercept_mode, TestPartResultArray* result); // The d'tor restores the previous test part result reporter. ~ScopedFakeTestPartResultReporter() override; // Appends the TestPartResult object to the TestPartResultArray // received in the constructor. // // This method is from the TestPartResultReporterInterface // interface. void ReportTestPartResult(const TestPartResult& result) override; private: void Init(); const InterceptMode intercept_mode_; TestPartResultReporterInterface* old_reporter_; TestPartResultArray* const result_; GTEST_DISALLOW_COPY_AND_ASSIGN_(ScopedFakeTestPartResultReporter); }; namespace internal { // A helper class for implementing EXPECT_FATAL_FAILURE() and // EXPECT_NONFATAL_FAILURE(). Its destructor verifies that the given // TestPartResultArray contains exactly one failure that has the given // type and contains the given substring. If that's not the case, a // non-fatal failure will be generated. class GTEST_API_ SingleFailureChecker { public: // The constructor remembers the arguments. SingleFailureChecker(const TestPartResultArray* results, TestPartResult::Type type, const std::string& substr); ~SingleFailureChecker(); private: const TestPartResultArray* const results_; const TestPartResult::Type type_; const std::string substr_; GTEST_DISALLOW_COPY_AND_ASSIGN_(SingleFailureChecker); }; } // namespace internal } // namespace testing GTEST_DISABLE_MSC_WARNINGS_POP_() // 4251 // A set of macros for testing Google Test assertions or code that's expected // to generate Google Test fatal failures. It verifies that the given // statement will cause exactly one fatal Google Test failure with 'substr' // being part of the failure message. 
// // There are two different versions of this macro. EXPECT_FATAL_FAILURE only // affects and considers failures generated in the current thread and // EXPECT_FATAL_FAILURE_ON_ALL_THREADS does the same but for all threads. // // The verification of the assertion is done correctly even when the statement // throws an exception or aborts the current function. // // Known restrictions: // - 'statement' cannot reference local non-static variables or // non-static members of the current object. // - 'statement' cannot return a value. // - You cannot stream a failure message to this macro. // // Note that even though the implementations of the following two // macros are much alike, we cannot refactor them to use a common // helper macro, due to some peculiarity in how the preprocessor // works. The AcceptsMacroThatExpandsToUnprotectedComma test in // gtest_unittest.cc will fail to compile if we do that. #define EXPECT_FATAL_FAILURE(statement, substr) \ do { \ class GTestExpectFatalFailureHelper {\ public:\ static void Execute() { statement; }\ };\ ::testing::TestPartResultArray gtest_failures;\ ::testing::internal::SingleFailureChecker gtest_checker(\ >est_failures, ::testing::TestPartResult::kFatalFailure, (substr));\ {\ ::testing::ScopedFakeTestPartResultReporter gtest_reporter(\ ::testing::ScopedFakeTestPartResultReporter:: \ INTERCEPT_ONLY_CURRENT_THREAD, >est_failures);\ GTestExpectFatalFailureHelper::Execute();\ }\ } while (::testing::internal::AlwaysFalse()) #define EXPECT_FATAL_FAILURE_ON_ALL_THREADS(statement, substr) \ do { \ class GTestExpectFatalFailureHelper {\ public:\ static void Execute() { statement; }\ };\ ::testing::TestPartResultArray gtest_failures;\ ::testing::internal::SingleFailureChecker gtest_checker(\ >est_failures, ::testing::TestPartResult::kFatalFailure, (substr));\ {\ ::testing::ScopedFakeTestPartResultReporter gtest_reporter(\ ::testing::ScopedFakeTestPartResultReporter:: \ INTERCEPT_ALL_THREADS, >est_failures);\ 
GTestExpectFatalFailureHelper::Execute();\ }\ } while (::testing::internal::AlwaysFalse()) // A macro for testing Google Test assertions or code that's expected to // generate Google Test non-fatal failures. It asserts that the given // statement will cause exactly one non-fatal Google Test failure with 'substr' // being part of the failure message. // // There are two different versions of this macro. EXPECT_NONFATAL_FAILURE only // affects and considers failures generated in the current thread and // EXPECT_NONFATAL_FAILURE_ON_ALL_THREADS does the same but for all threads. // // 'statement' is allowed to reference local variables and members of // the current object. // // The verification of the assertion is done correctly even when the statement // throws an exception or aborts the current function. // // Known restrictions: // - You cannot stream a failure message to this macro. // // Note that even though the implementations of the following two // macros are much alike, we cannot refactor them to use a common // helper macro, due to some peculiarity in how the preprocessor // works. If we do that, the code won't compile when the user gives // EXPECT_NONFATAL_FAILURE() a statement that contains a macro that // expands to code containing an unprotected comma. The // AcceptsMacroThatExpandsToUnprotectedComma test in gtest_unittest.cc // catches that. // // For the same reason, we have to write // if (::testing::internal::AlwaysTrue()) { statement; } // instead of // GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement) // to avoid an MSVC warning on unreachable code. 
#define EXPECT_NONFATAL_FAILURE(statement, substr) \ do {\ ::testing::TestPartResultArray gtest_failures;\ ::testing::internal::SingleFailureChecker gtest_checker(\ >est_failures, ::testing::TestPartResult::kNonFatalFailure, \ (substr));\ {\ ::testing::ScopedFakeTestPartResultReporter gtest_reporter(\ ::testing::ScopedFakeTestPartResultReporter:: \ INTERCEPT_ONLY_CURRENT_THREAD, >est_failures);\ if (::testing::internal::AlwaysTrue()) { statement; }\ }\ } while (::testing::internal::AlwaysFalse()) #define EXPECT_NONFATAL_FAILURE_ON_ALL_THREADS(statement, substr) \ do {\ ::testing::TestPartResultArray gtest_failures;\ ::testing::internal::SingleFailureChecker gtest_checker(\ >est_failures, ::testing::TestPartResult::kNonFatalFailure, \ (substr));\ {\ ::testing::ScopedFakeTestPartResultReporter gtest_reporter(\ ::testing::ScopedFakeTestPartResultReporter::INTERCEPT_ALL_THREADS, \ >est_failures);\ if (::testing::internal::AlwaysTrue()) { statement; }\ }\ } while (::testing::internal::AlwaysFalse()) #endif // GTEST_INCLUDE_GTEST_GTEST_SPI_H_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include // NOLINT #include #include #if GTEST_OS_LINUX # define GTEST_HAS_GETTIMEOFDAY_ 1 # include // NOLINT # include // NOLINT # include // NOLINT // Declares vsnprintf(). This header is not available on Windows. # include // NOLINT # include // NOLINT # include // NOLINT # include // NOLINT # include #elif GTEST_OS_ZOS # define GTEST_HAS_GETTIMEOFDAY_ 1 # include // NOLINT // On z/OS we additionally need strings.h for strcasecmp. # include // NOLINT #elif GTEST_OS_WINDOWS_MOBILE // We are on Windows CE. # include // NOLINT # undef min #elif GTEST_OS_WINDOWS // We are on Windows proper. # include // NOLINT # undef min # include // NOLINT # include // NOLINT # include // NOLINT # include // NOLINT # include // NOLINT # include // NOLINT # if GTEST_OS_WINDOWS_MINGW // MinGW has gettimeofday() but not _ftime64(). 
# define GTEST_HAS_GETTIMEOFDAY_ 1 # include // NOLINT # endif // GTEST_OS_WINDOWS_MINGW #else // Assume other platforms have gettimeofday(). # define GTEST_HAS_GETTIMEOFDAY_ 1 // cpplint thinks that the header is already included, so we want to // silence it. # include // NOLINT # include // NOLINT #endif // GTEST_OS_LINUX #if GTEST_HAS_EXCEPTIONS # include #endif #if GTEST_CAN_STREAM_RESULTS_ # include // NOLINT # include // NOLINT # include // NOLINT # include // NOLINT #endif // Copyright 2005, Google Inc. // All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // Utility functions and classes used by the Google C++ testing framework.// // This file contains purely Google Test's internal implementation. Please // DO NOT #INCLUDE IT IN A USER PROGRAM. #ifndef GTEST_SRC_GTEST_INTERNAL_INL_H_ #define GTEST_SRC_GTEST_INTERNAL_INL_H_ #ifndef _WIN32_WCE # include #endif // !_WIN32_WCE #include #include // For strtoll/_strtoul64/malloc/free. #include // For memmove. #include #include #include #include #if GTEST_CAN_STREAM_RESULTS_ # include // NOLINT # include // NOLINT #endif #if GTEST_OS_WINDOWS # include // NOLINT #endif // GTEST_OS_WINDOWS GTEST_DISABLE_MSC_WARNINGS_PUSH_(4251 \ /* class A needs to have dll-interface to be used by clients of class B */) namespace testing { // Declares the flags. // // We don't want the users to modify this flag in the code, but want // Google Test's own unit tests to be able to access it. Therefore we // declare it here as opposed to in gtest.h. GTEST_DECLARE_bool_(death_test_use_fork); namespace internal { // The value of GetTestTypeId() as seen from within the Google Test // library. This is solely for testing GetTestTypeId(). GTEST_API_ extern const TypeId kTestTypeIdInGoogleTest; // Names of the flags (needed for parsing Google Test flags). 
const char kAlsoRunDisabledTestsFlag[] = "also_run_disabled_tests"; const char kBreakOnFailureFlag[] = "break_on_failure"; const char kCatchExceptionsFlag[] = "catch_exceptions"; const char kColorFlag[] = "color"; const char kFilterFlag[] = "filter"; const char kListTestsFlag[] = "list_tests"; const char kOutputFlag[] = "output"; const char kPrintTimeFlag[] = "print_time"; const char kPrintUTF8Flag[] = "print_utf8"; const char kRandomSeedFlag[] = "random_seed"; const char kRepeatFlag[] = "repeat"; const char kShuffleFlag[] = "shuffle"; const char kStackTraceDepthFlag[] = "stack_trace_depth"; const char kStreamResultToFlag[] = "stream_result_to"; const char kThrowOnFailureFlag[] = "throw_on_failure"; const char kFlagfileFlag[] = "flagfile"; const char kPrintSkippedFlag[] = "print_skipped"; // A valid random seed must be in [1, kMaxRandomSeed]. const int kMaxRandomSeed = 99999; // g_help_flag is true if and only if the --help flag or an equivalent form // is specified on the command line. GTEST_API_ extern bool g_help_flag; // Returns the current time in milliseconds. GTEST_API_ TimeInMillis GetTimeInMillis(); // Returns true if and only if Google Test should use colors in the output. GTEST_API_ bool ShouldUseColor(bool stdout_is_tty); // Formats the given time in milliseconds as seconds. GTEST_API_ std::string FormatTimeInMillisAsSeconds(TimeInMillis ms); // Converts the given time in milliseconds to a date string in the ISO 8601 // format, without the timezone information. N.B.: due to the use the // non-reentrant localtime() function, this function is not thread safe. Do // not use it in any code that can be called from multiple threads. GTEST_API_ std::string FormatEpochTimeInMillisAsIso8601(TimeInMillis ms); // Parses a string for an Int32 flag, in the form of "--flag=value". // // On success, stores the value of the flag in *value, and returns // true. On failure, returns false without changing *value. 
GTEST_API_ bool ParseInt32Flag( const char* str, const char* flag, Int32* value); // Returns a random seed in range [1, kMaxRandomSeed] based on the // given --gtest_random_seed flag value. inline int GetRandomSeedFromFlag(Int32 random_seed_flag) { const unsigned int raw_seed = (random_seed_flag == 0) ? static_cast(GetTimeInMillis()) : static_cast(random_seed_flag); // Normalizes the actual seed to range [1, kMaxRandomSeed] such that // it's easy to type. const int normalized_seed = static_cast((raw_seed - 1U) % static_cast(kMaxRandomSeed)) + 1; return normalized_seed; } // Returns the first valid random seed after 'seed'. The behavior is // undefined if 'seed' is invalid. The seed after kMaxRandomSeed is // considered to be 1. inline int GetNextRandomSeed(int seed) { GTEST_CHECK_(1 <= seed && seed <= kMaxRandomSeed) << "Invalid random seed " << seed << " - must be in [1, " << kMaxRandomSeed << "]."; const int next_seed = seed + 1; return (next_seed > kMaxRandomSeed) ? 1 : next_seed; } // This class saves the values of all Google Test flags in its c'tor, and // restores them in its d'tor. class GTestFlagSaver { public: // The c'tor. 
GTestFlagSaver() { also_run_disabled_tests_ = GTEST_FLAG(also_run_disabled_tests); break_on_failure_ = GTEST_FLAG(break_on_failure); catch_exceptions_ = GTEST_FLAG(catch_exceptions); color_ = GTEST_FLAG(color); death_test_style_ = GTEST_FLAG(death_test_style); death_test_use_fork_ = GTEST_FLAG(death_test_use_fork); filter_ = GTEST_FLAG(filter); internal_run_death_test_ = GTEST_FLAG(internal_run_death_test); list_tests_ = GTEST_FLAG(list_tests); output_ = GTEST_FLAG(output); print_time_ = GTEST_FLAG(print_time); print_utf8_ = GTEST_FLAG(print_utf8); random_seed_ = GTEST_FLAG(random_seed); repeat_ = GTEST_FLAG(repeat); shuffle_ = GTEST_FLAG(shuffle); stack_trace_depth_ = GTEST_FLAG(stack_trace_depth); stream_result_to_ = GTEST_FLAG(stream_result_to); throw_on_failure_ = GTEST_FLAG(throw_on_failure); print_skipped_ = GTEST_FLAG(print_skipped); } // The d'tor is not virtual. DO NOT INHERIT FROM THIS CLASS. ~GTestFlagSaver() { GTEST_FLAG(also_run_disabled_tests) = also_run_disabled_tests_; GTEST_FLAG(break_on_failure) = break_on_failure_; GTEST_FLAG(catch_exceptions) = catch_exceptions_; GTEST_FLAG(color) = color_; GTEST_FLAG(death_test_style) = death_test_style_; GTEST_FLAG(death_test_use_fork) = death_test_use_fork_; GTEST_FLAG(filter) = filter_; GTEST_FLAG(internal_run_death_test) = internal_run_death_test_; GTEST_FLAG(list_tests) = list_tests_; GTEST_FLAG(output) = output_; GTEST_FLAG(print_time) = print_time_; GTEST_FLAG(print_utf8) = print_utf8_; GTEST_FLAG(random_seed) = random_seed_; GTEST_FLAG(repeat) = repeat_; GTEST_FLAG(shuffle) = shuffle_; GTEST_FLAG(stack_trace_depth) = stack_trace_depth_; GTEST_FLAG(stream_result_to) = stream_result_to_; GTEST_FLAG(throw_on_failure) = throw_on_failure_; GTEST_FLAG(print_skipped) = print_skipped_; } private: // Fields for saving the original values of flags. 
bool also_run_disabled_tests_; bool break_on_failure_; bool catch_exceptions_; std::string color_; std::string death_test_style_; bool death_test_use_fork_; std::string filter_; std::string internal_run_death_test_; bool list_tests_; std::string output_; bool print_time_; bool print_utf8_; internal::Int32 random_seed_; internal::Int32 repeat_; bool shuffle_; internal::Int32 stack_trace_depth_; std::string stream_result_to_; bool throw_on_failure_; bool print_skipped_; } GTEST_ATTRIBUTE_UNUSED_; // Converts a Unicode code point to a narrow string in UTF-8 encoding. // code_point parameter is of type UInt32 because wchar_t may not be // wide enough to contain a code point. // If the code_point is not a valid Unicode code point // (i.e. outside of Unicode range U+0 to U+10FFFF) it will be converted // to "(Invalid Unicode 0xXXXXXXXX)". GTEST_API_ std::string CodePointToUtf8(UInt32 code_point); // Converts a wide string to a narrow string in UTF-8 encoding. // The wide string is assumed to have the following encoding: // UTF-16 if sizeof(wchar_t) == 2 (on Windows, Cygwin) // UTF-32 if sizeof(wchar_t) == 4 (on Linux) // Parameter str points to a null-terminated wide string. // Parameter num_chars may additionally limit the number // of wchar_t characters processed. -1 is used when the entire string // should be processed. // If the string contains code points that are not valid Unicode code points // (i.e. outside of Unicode range U+0 to U+10FFFF) they will be output // as '(Invalid Unicode 0xXXXXXXXX)'. If the string is in UTF16 encoding // and contains invalid UTF-16 surrogate pairs, values in those pairs // will be encoded as individual Unicode characters from Basic Normal Plane. GTEST_API_ std::string WideStringToUtf8(const wchar_t* str, int num_chars); // Reads the GTEST_SHARD_STATUS_FILE environment variable, and creates the file // if the variable is present. If a file already exists at this location, this // function will write over it. 
If the variable is present, but the file cannot // be created, prints an error and exits. void WriteToShardStatusFileIfNeeded(); // Checks whether sharding is enabled by examining the relevant // environment variable values. If the variables are present, // but inconsistent (e.g., shard_index >= total_shards), prints // an error and exits. If in_subprocess_for_death_test, sharding is // disabled because it must only be applied to the original test // process. Otherwise, we could filter out death tests we intended to execute. GTEST_API_ bool ShouldShard(const char* total_shards_str, const char* shard_index_str, bool in_subprocess_for_death_test); // Parses the environment variable var as an Int32. If it is unset, // returns default_val. If it is not an Int32, prints an error and // and aborts. GTEST_API_ Int32 Int32FromEnvOrDie(const char* env_var, Int32 default_val); // Given the total number of shards, the shard index, and the test id, // returns true if and only if the test should be run on this shard. The test id // is some arbitrary but unique non-negative integer assigned to each test // method. Assumes that 0 <= shard_index < total_shards. GTEST_API_ bool ShouldRunTestOnShard( int total_shards, int shard_index, int test_id); // STL container utilities. // Returns the number of elements in the given container that satisfy // the given predicate. template inline int CountIf(const Container& c, Predicate predicate) { // Implemented as an explicit loop since std::count_if() in libCstd on // Solaris has a non-standard signature. int count = 0; for (typename Container::const_iterator it = c.begin(); it != c.end(); ++it) { if (predicate(*it)) ++count; } return count; } // Applies a function/functor to each element in the container. template void ForEach(const Container& c, Functor functor) { std::for_each(c.begin(), c.end(), functor); } // Returns the i-th element of the vector, or default_value if i is not // in range [0, v.size()). 
template inline E GetElementOr(const std::vector& v, int i, E default_value) { return (i < 0 || i >= static_cast(v.size())) ? default_value : v[static_cast(i)]; } // Performs an in-place shuffle of a range of the vector's elements. // 'begin' and 'end' are element indices as an STL-style range; // i.e. [begin, end) are shuffled, where 'end' == size() means to // shuffle to the end of the vector. template void ShuffleRange(internal::Random* random, int begin, int end, std::vector* v) { const int size = static_cast(v->size()); GTEST_CHECK_(0 <= begin && begin <= size) << "Invalid shuffle range start " << begin << ": must be in range [0, " << size << "]."; GTEST_CHECK_(begin <= end && end <= size) << "Invalid shuffle range finish " << end << ": must be in range [" << begin << ", " << size << "]."; // Fisher-Yates shuffle, from // http://en.wikipedia.org/wiki/Fisher-Yates_shuffle for (int range_width = end - begin; range_width >= 2; range_width--) { const int last_in_range = begin + range_width - 1; const int selected = begin + static_cast(random->Generate(static_cast(range_width))); std::swap((*v)[static_cast(selected)], (*v)[static_cast(last_in_range)]); } } // Performs an in-place shuffle of the vector's elements. template inline void Shuffle(internal::Random* random, std::vector* v) { ShuffleRange(random, 0, static_cast(v->size()), v); } // A function for deleting an object. Handy for being used as a // functor. template static void Delete(T* x) { delete x; } // A predicate that checks the key of a TestProperty against a known key. // // TestPropertyKeyIs is copyable. class TestPropertyKeyIs { public: // Constructor. // // TestPropertyKeyIs has NO default constructor. explicit TestPropertyKeyIs(const std::string& key) : key_(key) {} // Returns true if and only if the test name of test property matches on key_. bool operator()(const TestProperty& test_property) const { return test_property.key() == key_; } private: std::string key_; }; // Class UnitTestOptions. 
//
// This class contains functions for processing options the user
// specifies when running the tests.  It has only static members.
//
// In most cases, the user can specify an option using either an
// environment variable or a command line flag.  E.g. you can set the
// test filter using either GTEST_FILTER or --gtest_filter.  If both
// the variable and the flag are present, the latter overrides the
// former.
class GTEST_API_ UnitTestOptions {
 public:
  // Functions for processing the gtest_output flag.

  // Returns the output format, or "" for normal printed output.
  static std::string GetOutputFormat();

  // Returns the absolute path of the requested output file, or the
  // default (test_detail.xml in the original working directory) if
  // none was explicitly specified.
  static std::string GetAbsolutePathToOutputFile();

  // Functions for processing the gtest_filter flag.

  // Returns true if and only if the wildcard pattern matches the string.
  // The first ':' or '\0' character in pattern marks the end of it.
  //
  // This recursive algorithm isn't very efficient, but is clear and
  // works well enough for matching test names, which are short.
  static bool PatternMatchesString(const char *pattern, const char *str);

  // Returns true if and only if the user-specified filter matches the test
  // suite name and the test name.
  static bool FilterMatchesTest(const std::string& test_suite_name,
                                const std::string& test_name);

#if GTEST_OS_WINDOWS
  // Function for supporting the gtest_catch_exception flag.

  // Returns EXCEPTION_EXECUTE_HANDLER if Google Test should handle the
  // given SEH exception, or EXCEPTION_CONTINUE_SEARCH otherwise.
  // This function is useful as an __except condition.
  static int GTestShouldProcessSEH(DWORD exception_code);
#endif  // GTEST_OS_WINDOWS

  // Returns true if "name" matches the ':' separated list of glob-style
  // filters in "filter".
  static bool MatchesFilter(const std::string& name, const char* filter);
};

// Returns the current application's name, removing directory path if that
// is present.  Used by UnitTestOptions::GetOutputFile.
GTEST_API_ FilePath GetCurrentExecutableName();

// The role interface for getting the OS stack trace as a string.
class OsStackTraceGetterInterface {
 public:
  OsStackTraceGetterInterface() {}
  virtual ~OsStackTraceGetterInterface() {}

  // Returns the current OS stack trace as an std::string.  Parameters:
  //
  //   max_depth  - the maximum number of stack frames to be included
  //                in the trace.
  //   skip_count - the number of top frames to be skipped; doesn't count
  //                against max_depth.
  virtual std::string CurrentStackTrace(int max_depth, int skip_count) = 0;

  // UponLeavingGTest() should be called immediately before Google Test calls
  // user code. It saves some information about the current stack that
  // CurrentStackTrace() will use to find and hide Google Test stack frames.
  virtual void UponLeavingGTest() = 0;

  // This string is inserted in place of stack frames that are part of
  // Google Test's implementation.
  static const char* const kElidedFramesMarker;

 private:
  GTEST_DISALLOW_COPY_AND_ASSIGN_(OsStackTraceGetterInterface);
};

// A working implementation of the OsStackTraceGetterInterface interface.
class OsStackTraceGetter : public OsStackTraceGetterInterface {
 public:
  OsStackTraceGetter() {}

  std::string CurrentStackTrace(int max_depth, int skip_count) override;
  void UponLeavingGTest() override;

 private:
  // The state below exists only when Google Test is built with Abseil.
#if GTEST_HAS_ABSL
  Mutex mutex_;  // Protects all internal state.

  // We save the stack frame below the frame that calls user code.
  // We do this because the address of the frame immediately below
  // the user code changes between the call to UponLeavingGTest()
  // and any calls to the stack trace code from within the user code.
  void* caller_frame_ = nullptr;
#endif  // GTEST_HAS_ABSL

  GTEST_DISALLOW_COPY_AND_ASSIGN_(OsStackTraceGetter);
};

// Information about a Google Test trace point.
struct TraceInfo {
  const char* file;
  int line;
  std::string message;
};

// This is the default global test part result reporter used in UnitTestImpl.
// This class should only be used by UnitTestImpl.
class DefaultGlobalTestPartResultReporter
  : public TestPartResultReporterInterface {
 public:
  explicit DefaultGlobalTestPartResultReporter(UnitTestImpl* unit_test);
  // Implements the TestPartResultReporterInterface. Reports the test part
  // result in the current test.
  void ReportTestPartResult(const TestPartResult& result) override;

 private:
  UnitTestImpl* const unit_test_;

  GTEST_DISALLOW_COPY_AND_ASSIGN_(DefaultGlobalTestPartResultReporter);
};

// This is the default per thread test part result reporter used in
// UnitTestImpl. This class should only be used by UnitTestImpl.
class DefaultPerThreadTestPartResultReporter
    : public TestPartResultReporterInterface {
 public:
  explicit DefaultPerThreadTestPartResultReporter(UnitTestImpl* unit_test);
  // Implements the TestPartResultReporterInterface. The implementation just
  // delegates to the current global test part result reporter of *unit_test_.
  void ReportTestPartResult(const TestPartResult& result) override;

 private:
  UnitTestImpl* const unit_test_;

  GTEST_DISALLOW_COPY_AND_ASSIGN_(DefaultPerThreadTestPartResultReporter);
};

// The private implementation of the UnitTest class.  We don't protect
// the methods under a mutex, as this class is not accessible by a
// user and the UnitTest class that delegates work to this class does
// proper locking.
class GTEST_API_ UnitTestImpl {
 public:
  explicit UnitTestImpl(UnitTest* parent);
  virtual ~UnitTestImpl();

  // There are two different ways to register your own TestPartResultReporter.
  // You can register your own reporter to listen either only for test results
  // from the current thread or for results from all threads.
// By default, each per-thread test result repoter just passes a new // TestPartResult to the global test result reporter, which registers the // test part result for the currently running test. // Returns the global test part result reporter. TestPartResultReporterInterface* GetGlobalTestPartResultReporter(); // Sets the global test part result reporter. void SetGlobalTestPartResultReporter( TestPartResultReporterInterface* reporter); // Returns the test part result reporter for the current thread. TestPartResultReporterInterface* GetTestPartResultReporterForCurrentThread(); // Sets the test part result reporter for the current thread. void SetTestPartResultReporterForCurrentThread( TestPartResultReporterInterface* reporter); // Gets the number of successful test suites. int successful_test_suite_count() const; // Gets the number of failed test suites. int failed_test_suite_count() const; // Gets the number of all test suites. int total_test_suite_count() const; // Gets the number of all test suites that contain at least one test // that should run. int test_suite_to_run_count() const; // Gets the number of successful tests. int successful_test_count() const; // Gets the number of skipped tests. int skipped_test_count() const; // Gets the number of failed tests. int failed_test_count() const; // Gets the number of disabled tests that will be reported in the XML report. int reportable_disabled_test_count() const; // Gets the number of disabled tests. int disabled_test_count() const; // Gets the number of tests to be printed in the XML report. int reportable_test_count() const; // Gets the number of all tests. int total_test_count() const; // Gets the number of tests that should run. int test_to_run_count() const; // Gets the time of the test program start, in ms from the start of the // UNIX epoch. TimeInMillis start_timestamp() const { return start_timestamp_; } // Gets the elapsed time, in milliseconds. 
TimeInMillis elapsed_time() const { return elapsed_time_; } // Returns true if and only if the unit test passed (i.e. all test suites // passed). bool Passed() const { return !Failed(); } // Returns true if and only if the unit test failed (i.e. some test suite // failed or something outside of all tests failed). bool Failed() const { return failed_test_suite_count() > 0 || ad_hoc_test_result()->Failed(); } // Gets the i-th test suite among all the test suites. i can range from 0 to // total_test_suite_count() - 1. If i is not in that range, returns NULL. const TestSuite* GetTestSuite(int i) const { const int index = GetElementOr(test_suite_indices_, i, -1); return index < 0 ? nullptr : test_suites_[static_cast(i)]; } // Legacy API is deprecated but still available #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ const TestCase* GetTestCase(int i) const { return GetTestSuite(i); } #endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ // Gets the i-th test suite among all the test suites. i can range from 0 to // total_test_suite_count() - 1. If i is not in that range, returns NULL. TestSuite* GetMutableSuiteCase(int i) { const int index = GetElementOr(test_suite_indices_, i, -1); return index < 0 ? nullptr : test_suites_[static_cast(index)]; } // Provides access to the event listener list. TestEventListeners* listeners() { return &listeners_; } // Returns the TestResult for the test that's currently running, or // the TestResult for the ad hoc test if no test is running. TestResult* current_test_result(); // Returns the TestResult for the ad hoc test. const TestResult* ad_hoc_test_result() const { return &ad_hoc_test_result_; } // Sets the OS stack trace getter. // // Does nothing if the input and the current OS stack trace getter // are the same; otherwise, deletes the old getter and makes the // input the current getter. 
void set_os_stack_trace_getter(OsStackTraceGetterInterface* getter);

  // Returns the current OS stack trace getter if it is not NULL;
  // otherwise, creates an OsStackTraceGetter, makes it the current
  // getter, and returns it.
  OsStackTraceGetterInterface* os_stack_trace_getter();

  // Returns the current OS stack trace as an std::string.
  //
  // The maximum number of stack frames to be included is specified by
  // the gtest_stack_trace_depth flag.  The skip_count parameter
  // specifies the number of top frames to be skipped, which doesn't
  // count against the number of frames to be included.
  //
  // For example, if Foo() calls Bar(), which in turn calls
  // CurrentOsStackTraceExceptTop(1), Foo() will be included in the
  // trace but Bar() and CurrentOsStackTraceExceptTop() won't.
  std::string CurrentOsStackTraceExceptTop(int skip_count) GTEST_NO_INLINE_;

  // Finds and returns a TestSuite with the given name.  If one doesn't
  // exist, creates one and returns it.
  //
  // Arguments:
  //
  //   test_suite_name: name of the test suite
  //   type_param:     the name of the test's type parameter, or NULL if
  //                   this is not a typed or a type-parameterized test.
  //   set_up_tc:      pointer to the function that sets up the test suite
  //   tear_down_tc:   pointer to the function that tears down the test suite
  TestSuite* GetTestSuite(const char* test_suite_name, const char* type_param,
                          internal::SetUpTestSuiteFunc set_up_tc,
                          internal::TearDownTestSuiteFunc tear_down_tc);

  //  Legacy API is deprecated but still available.  Forwards to
  //  GetTestSuite() with the same arguments.
#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_
  TestCase* GetTestCase(const char* test_case_name, const char* type_param,
                        internal::SetUpTestSuiteFunc set_up_tc,
                        internal::TearDownTestSuiteFunc tear_down_tc) {
    return GetTestSuite(test_case_name, type_param, set_up_tc, tear_down_tc);
  }
#endif  //  GTEST_REMOVE_LEGACY_TEST_CASEAPI_

  // Adds a TestInfo to the unit test.
// // Arguments: // // set_up_tc: pointer to the function that sets up the test suite // tear_down_tc: pointer to the function that tears down the test suite // test_info: the TestInfo object void AddTestInfo(internal::SetUpTestSuiteFunc set_up_tc, internal::TearDownTestSuiteFunc tear_down_tc, TestInfo* test_info) { // In order to support thread-safe death tests, we need to // remember the original working directory when the test program // was first invoked. We cannot do this in RUN_ALL_TESTS(), as // the user may have changed the current directory before calling // RUN_ALL_TESTS(). Therefore we capture the current directory in // AddTestInfo(), which is called to register a TEST or TEST_F // before main() is reached. if (original_working_dir_.IsEmpty()) { original_working_dir_.Set(FilePath::GetCurrentDir()); GTEST_CHECK_(!original_working_dir_.IsEmpty()) << "Failed to get the current working directory."; } GetTestSuite(test_info->test_suite_name(), test_info->type_param(), set_up_tc, tear_down_tc) ->AddTestInfo(test_info); } // Returns ParameterizedTestSuiteRegistry object used to keep track of // value-parameterized tests and instantiate and register them. internal::ParameterizedTestSuiteRegistry& parameterized_test_registry() { return parameterized_test_registry_; } // Sets the TestSuite object for the test that's currently running. void set_current_test_suite(TestSuite* a_current_test_suite) { current_test_suite_ = a_current_test_suite; } // Sets the TestInfo object for the test that's currently running. If // current_test_info is NULL, the assertion results will be stored in // ad_hoc_test_result_. void set_current_test_info(TestInfo* a_current_test_info) { current_test_info_ = a_current_test_info; } // Registers all parameterized tests defined using TEST_P and // INSTANTIATE_TEST_SUITE_P, creating regular tests for each test/parameter // combination. 
// This method can be called more than once; it has guards
  // protecting from registering the tests more than once.  If
  // value-parameterized tests are disabled, RegisterParameterizedTests is
  // present but does nothing.
  void RegisterParameterizedTests();

  // Runs all tests in this UnitTest object, prints the result, and
  // returns true if all tests are successful.  If any exception is
  // thrown during a test, this test is considered to be failed, but
  // the rest of the tests will still be run.
  bool RunAllTests();

  // Clears the results of all tests, except the ad hoc tests.
  void ClearNonAdHocTestResult() {
    ForEach(test_suites_, TestSuite::ClearTestSuiteResult);
  }

  // Clears the results of ad-hoc test assertions.
  void ClearAdHocTestResult() {
    ad_hoc_test_result_.Clear();
  }

  // Adds a TestProperty to the current TestResult object when invoked in a
  // context of a test or a test suite, or to the global property set. If the
  // result already contains a property with the same key, the value will be
  // updated.
  void RecordProperty(const TestProperty& test_property);

  // Tells FilterTests() whether the sharding environment variables
  // should be taken into account.
  enum ReactionToSharding {
    HONOR_SHARDING_PROTOCOL,
    IGNORE_SHARDING_PROTOCOL
  };

  // Matches the full name of each test against the user-specified
  // filter to decide whether the test should run, then records the
  // result in each TestSuite and TestInfo object.
  // If shard_tests == HONOR_SHARDING_PROTOCOL, further filters tests
  // based on sharding variables in the environment.
  // Returns the number of tests that should run.
  int FilterTests(ReactionToSharding shard_tests);

  // Prints the names of the tests matching the user-specified filter flag.
  void ListTestsMatchingFilter();

  // Accessors for the suite / test that is currently running.
  const TestSuite* current_test_suite() const { return current_test_suite_; }
  TestInfo* current_test_info() { return current_test_info_; }
  const TestInfo* current_test_info() const { return current_test_info_; }

  // Returns the vector of environments that need to be set-up/torn-down
  // before/after the tests are run.
std::vector& environments() { return environments_; } // Getters for the per-thread Google Test trace stack. std::vector& gtest_trace_stack() { return *(gtest_trace_stack_.pointer()); } const std::vector& gtest_trace_stack() const { return gtest_trace_stack_.get(); } #if GTEST_HAS_DEATH_TEST void InitDeathTestSubprocessControlInfo() { internal_run_death_test_flag_.reset(ParseInternalRunDeathTestFlag()); } // Returns a pointer to the parsed --gtest_internal_run_death_test // flag, or NULL if that flag was not specified. // This information is useful only in a death test child process. // Must not be called before a call to InitGoogleTest. const InternalRunDeathTestFlag* internal_run_death_test_flag() const { return internal_run_death_test_flag_.get(); } // Returns a pointer to the current death test factory. internal::DeathTestFactory* death_test_factory() { return death_test_factory_.get(); } void SuppressTestEventsIfInSubprocess(); friend class ReplaceDeathTestFactory; #endif // GTEST_HAS_DEATH_TEST // Initializes the event listener performing XML output as specified by // UnitTestOptions. Must not be called before InitGoogleTest. void ConfigureXmlOutput(); #if GTEST_CAN_STREAM_RESULTS_ // Initializes the event listener for streaming test results to a socket. // Must not be called before InitGoogleTest. void ConfigureStreamingOutput(); #endif // Performs initialization dependent upon flag values obtained in // ParseGoogleTestFlagsOnly. Is called from InitGoogleTest after the call to // ParseGoogleTestFlagsOnly. In case a user neglects to call InitGoogleTest // this function is also called from RunAllTests. Since this function can be // called more than once, it has to be idempotent. void PostFlagParsingInit(); // Gets the random seed used at the start of the current test iteration. int random_seed() const { return random_seed_; } // Gets the random number generator. 
internal::Random* random() { return &random_; } // Shuffles all test suites, and the tests within each test suite, // making sure that death tests are still run first. void ShuffleTests(); // Restores the test suites and tests to their order before the first shuffle. void UnshuffleTests(); // Returns the value of GTEST_FLAG(catch_exceptions) at the moment // UnitTest::Run() starts. bool catch_exceptions() const { return catch_exceptions_; } private: friend class ::testing::UnitTest; // Used by UnitTest::Run() to capture the state of // GTEST_FLAG(catch_exceptions) at the moment it starts. void set_catch_exceptions(bool value) { catch_exceptions_ = value; } // The UnitTest object that owns this implementation object. UnitTest* const parent_; // The working directory when the first TEST() or TEST_F() was // executed. internal::FilePath original_working_dir_; // The default test part result reporters. DefaultGlobalTestPartResultReporter default_global_test_part_result_reporter_; DefaultPerThreadTestPartResultReporter default_per_thread_test_part_result_reporter_; // Points to (but doesn't own) the global test part result reporter. TestPartResultReporterInterface* global_test_part_result_repoter_; // Protects read and write access to global_test_part_result_reporter_. internal::Mutex global_test_part_result_reporter_mutex_; // Points to (but doesn't own) the per-thread test part result reporter. internal::ThreadLocal per_thread_test_part_result_reporter_; // The vector of environments that need to be set-up/torn-down // before/after the tests are run. std::vector environments_; // The vector of TestSuites in their original order. It owns the // elements in the vector. std::vector test_suites_; // Provides a level of indirection for the test suite list to allow // easy shuffling and restoring the test suite order. The i-th // element of this vector is the index of the i-th test suite in the // shuffled order. 
std::vector test_suite_indices_; // ParameterizedTestRegistry object used to register value-parameterized // tests. internal::ParameterizedTestSuiteRegistry parameterized_test_registry_; // Indicates whether RegisterParameterizedTests() has been called already. bool parameterized_tests_registered_; // Index of the last death test suite registered. Initially -1. int last_death_test_suite_; // This points to the TestSuite for the currently running test. It // changes as Google Test goes through one test suite after another. // When no test is running, this is set to NULL and Google Test // stores assertion results in ad_hoc_test_result_. Initially NULL. TestSuite* current_test_suite_; // This points to the TestInfo for the currently running test. It // changes as Google Test goes through one test after another. When // no test is running, this is set to NULL and Google Test stores // assertion results in ad_hoc_test_result_. Initially NULL. TestInfo* current_test_info_; // Normally, a user only writes assertions inside a TEST or TEST_F, // or inside a function called by a TEST or TEST_F. Since Google // Test keeps track of which test is current running, it can // associate such an assertion with the test it belongs to. // // If an assertion is encountered when no TEST or TEST_F is running, // Google Test attributes the assertion result to an imaginary "ad hoc" // test, and records the result in ad_hoc_test_result_. TestResult ad_hoc_test_result_; // The list of event listeners that can be used to track events inside // Google Test. TestEventListeners listeners_; // The OS stack trace getter. Will be deleted when the UnitTest // object is destructed. By default, an OsStackTraceGetter is used, // but the user can set this field to use a custom getter if that is // desired. OsStackTraceGetterInterface* os_stack_trace_getter_; // True if and only if PostFlagParsingInit() has been called. 
bool post_flag_parse_init_performed_; // The random number seed used at the beginning of the test run. int random_seed_; // Our random number generator. internal::Random random_; // The time of the test program start, in ms from the start of the // UNIX epoch. TimeInMillis start_timestamp_; // How long the test took to run, in milliseconds. TimeInMillis elapsed_time_; #if GTEST_HAS_DEATH_TEST // The decomposed components of the gtest_internal_run_death_test flag, // parsed when RUN_ALL_TESTS is called. std::unique_ptr internal_run_death_test_flag_; std::unique_ptr death_test_factory_; #endif // GTEST_HAS_DEATH_TEST // A per-thread stack of traces created by the SCOPED_TRACE() macro. internal::ThreadLocal > gtest_trace_stack_; // The value of GTEST_FLAG(catch_exceptions) at the moment RunAllTests() // starts. bool catch_exceptions_; GTEST_DISALLOW_COPY_AND_ASSIGN_(UnitTestImpl); }; // class UnitTestImpl // Convenience function for accessing the global UnitTest // implementation object. inline UnitTestImpl* GetUnitTestImpl() { return UnitTest::GetInstance()->impl(); } #if GTEST_USES_SIMPLE_RE // Internal helper functions for implementing the simple regular // expression matcher. 
GTEST_API_ bool IsInSet(char ch, const char* str); GTEST_API_ bool IsAsciiDigit(char ch); GTEST_API_ bool IsAsciiPunct(char ch); GTEST_API_ bool IsRepeat(char ch); GTEST_API_ bool IsAsciiWhiteSpace(char ch); GTEST_API_ bool IsAsciiWordChar(char ch); GTEST_API_ bool IsValidEscape(char ch); GTEST_API_ bool AtomMatchesChar(bool escaped, char pattern, char ch); GTEST_API_ bool ValidateRegex(const char* regex); GTEST_API_ bool MatchRegexAtHead(const char* regex, const char* str); GTEST_API_ bool MatchRepetitionAndRegexAtHead( bool escaped, char ch, char repeat, const char* regex, const char* str); GTEST_API_ bool MatchRegexAnywhere(const char* regex, const char* str); #endif // GTEST_USES_SIMPLE_RE // Parses the command line for Google Test flags, without initializing // other parts of Google Test. GTEST_API_ void ParseGoogleTestFlagsOnly(int* argc, char** argv); GTEST_API_ void ParseGoogleTestFlagsOnly(int* argc, wchar_t** argv); #if GTEST_HAS_DEATH_TEST // Returns the message describing the last system error, regardless of the // platform. GTEST_API_ std::string GetLastErrnoDescription(); // Attempts to parse a string into a positive integer pointed to by the // number parameter. Returns true if that is possible. // GTEST_HAS_DEATH_TEST implies that we have ::std::string, so we can use // it here. template bool ParseNaturalNumber(const ::std::string& str, Integer* number) { // Fail fast if the given string does not begin with a digit; // this bypasses strtoXXX's "optional leading whitespace and plus // or minus sign" semantics, which are undesirable here. if (str.empty() || !IsDigit(str[0])) { return false; } errno = 0; char* end; // BiggestConvertible is the largest integer type that system-provided // string-to-number conversion routines can return. # if GTEST_OS_WINDOWS && !defined(__GNUC__) // MSVC and C++ Builder define __int64 instead of the standard long long. 
typedef unsigned __int64 BiggestConvertible; const BiggestConvertible parsed = _strtoui64(str.c_str(), &end, 10); # else typedef unsigned long long BiggestConvertible; // NOLINT const BiggestConvertible parsed = strtoull(str.c_str(), &end, 10); # endif // GTEST_OS_WINDOWS && !defined(__GNUC__) const bool parse_success = *end == '\0' && errno == 0; GTEST_CHECK_(sizeof(Integer) <= sizeof(parsed)); const Integer result = static_cast(parsed); if (parse_success && static_cast(result) == parsed) { *number = result; return true; } return false; } #endif // GTEST_HAS_DEATH_TEST // TestResult contains some private methods that should be hidden from // Google Test user but are required for testing. This class allow our tests // to access them. // // This class is supplied only for the purpose of testing Google Test's own // constructs. Do not use it in user tests, either directly or indirectly. class TestResultAccessor { public: static void RecordProperty(TestResult* test_result, const std::string& xml_element, const TestProperty& property) { test_result->RecordProperty(xml_element, property); } static void ClearTestPartResults(TestResult* test_result) { test_result->ClearTestPartResults(); } static const std::vector& test_part_results( const TestResult& test_result) { return test_result.test_part_results(); } }; #if GTEST_CAN_STREAM_RESULTS_ // Streams test results to the given port on the given host machine. class StreamingListener : public EmptyTestEventListener { public: // Abstract base class for writing strings to a socket. class AbstractSocketWriter { public: virtual ~AbstractSocketWriter() {} // Sends a string to the socket. virtual void Send(const std::string& message) = 0; // Closes the socket. virtual void CloseConnection() {} // Sends a string and a newline to the socket. void SendLn(const std::string& message) { Send(message + "\n"); } }; // Concrete class for actually writing strings to a socket. 
class SocketWriter : public AbstractSocketWriter { public: SocketWriter(const std::string& host, const std::string& port) : sockfd_(-1), host_name_(host), port_num_(port) { MakeConnection(); } ~SocketWriter() override { if (sockfd_ != -1) CloseConnection(); } // Sends a string to the socket. void Send(const std::string& message) override { GTEST_CHECK_(sockfd_ != -1) << "Send() can be called only when there is a connection."; const auto len = static_cast(message.length()); if (write(sockfd_, message.c_str(), len) != static_cast(len)) { GTEST_LOG_(WARNING) << "stream_result_to: failed to stream to " << host_name_ << ":" << port_num_; } } private: // Creates a client socket and connects to the server. void MakeConnection(); // Closes the socket. void CloseConnection() override { GTEST_CHECK_(sockfd_ != -1) << "CloseConnection() can be called only when there is a connection."; close(sockfd_); sockfd_ = -1; } int sockfd_; // socket file descriptor const std::string host_name_; const std::string port_num_; GTEST_DISALLOW_COPY_AND_ASSIGN_(SocketWriter); }; // class SocketWriter // Escapes '=', '&', '%', and '\n' characters in str as "%xx". static std::string UrlEncode(const char* str); StreamingListener(const std::string& host, const std::string& port) : socket_writer_(new SocketWriter(host, port)) { Start(); } explicit StreamingListener(AbstractSocketWriter* socket_writer) : socket_writer_(socket_writer) { Start(); } void OnTestProgramStart(const UnitTest& /* unit_test */) override { SendLn("event=TestProgramStart"); } void OnTestProgramEnd(const UnitTest& unit_test) override { // Note that Google Test current only report elapsed time for each // test iteration, not for the entire test program. SendLn("event=TestProgramEnd&passed=" + FormatBool(unit_test.Passed())); // Notify the streaming server to stop. 
socket_writer_->CloseConnection(); } void OnTestIterationStart(const UnitTest& /* unit_test */, int iteration) override { SendLn("event=TestIterationStart&iteration=" + StreamableToString(iteration)); } void OnTestIterationEnd(const UnitTest& unit_test, int /* iteration */) override { SendLn("event=TestIterationEnd&passed=" + FormatBool(unit_test.Passed()) + "&elapsed_time=" + StreamableToString(unit_test.elapsed_time()) + "ms"); } // Note that "event=TestCaseStart" is a wire format and has to remain // "case" for compatibilty void OnTestCaseStart(const TestCase& test_case) override { SendLn(std::string("event=TestCaseStart&name=") + test_case.name()); } // Note that "event=TestCaseEnd" is a wire format and has to remain // "case" for compatibilty void OnTestCaseEnd(const TestCase& test_case) override { SendLn("event=TestCaseEnd&passed=" + FormatBool(test_case.Passed()) + "&elapsed_time=" + StreamableToString(test_case.elapsed_time()) + "ms"); } void OnTestStart(const TestInfo& test_info) override { SendLn(std::string("event=TestStart&name=") + test_info.name()); } void OnTestEnd(const TestInfo& test_info) override { SendLn("event=TestEnd&passed=" + FormatBool((test_info.result())->Passed()) + "&elapsed_time=" + StreamableToString((test_info.result())->elapsed_time()) + "ms"); } void OnTestPartResult(const TestPartResult& test_part_result) override { const char* file_name = test_part_result.file_name(); if (file_name == nullptr) file_name = ""; SendLn("event=TestPartResult&file=" + UrlEncode(file_name) + "&line=" + StreamableToString(test_part_result.line_number()) + "&message=" + UrlEncode(test_part_result.message())); } private: // Sends the given message and a newline to the socket. void SendLn(const std::string& message) { socket_writer_->SendLn(message); } // Called at the start of streaming to notify the receiver what // protocol we are using. void Start() { SendLn("gtest_streaming_protocol_version=1.0"); } std::string FormatBool(bool value) { return value ? 
"1" : "0"; } const std::unique_ptr socket_writer_; GTEST_DISALLOW_COPY_AND_ASSIGN_(StreamingListener); }; // class StreamingListener #endif // GTEST_CAN_STREAM_RESULTS_ } // namespace internal } // namespace testing GTEST_DISABLE_MSC_WARNINGS_POP_() // 4251 #endif // GTEST_SRC_GTEST_INTERNAL_INL_H_ #if GTEST_OS_WINDOWS # define vsnprintf _vsnprintf #endif // GTEST_OS_WINDOWS #if GTEST_OS_MAC #ifndef GTEST_OS_IOS #include #endif #endif #if GTEST_HAS_ABSL #include "absl/debugging/failure_signal_handler.h" #include "absl/debugging/stacktrace.h" #include "absl/debugging/symbolize.h" #include "absl/strings/str_cat.h" #endif // GTEST_HAS_ABSL namespace testing { using internal::CountIf; using internal::ForEach; using internal::GetElementOr; using internal::Shuffle; // Constants. // A test whose test suite name or test name matches this filter is // disabled and not run. static const char kDisableTestFilter[] = "DISABLED_*:*/DISABLED_*"; // A test suite whose name matches this filter is considered a death // test suite and will be run before test suites whose name doesn't // match this filter. static const char kDeathTestSuiteFilter[] = "*DeathTest:*DeathTest/*"; // A test filter that matches everything. static const char kUniversalFilter[] = "*"; // The default output format. static const char kDefaultOutputFormat[] = "xml"; // The default output file. static const char kDefaultOutputFile[] = "test_detail"; // The environment variable name for the test shard index. static const char kTestShardIndex[] = "GTEST_SHARD_INDEX"; // The environment variable name for the total number of test shards. static const char kTestTotalShards[] = "GTEST_TOTAL_SHARDS"; // The environment variable name for the test shard status file. static const char kTestShardStatusFile[] = "GTEST_SHARD_STATUS_FILE"; namespace internal { // The text used in failure messages to indicate the start of the // stack trace. 
const char kStackTraceMarker[] = "\nStack trace:\n"; // g_help_flag is true if and only if the --help flag or an equivalent form // is specified on the command line. bool g_help_flag = false; // Utilty function to Open File for Writing static FILE* OpenFileForWriting(const std::string& output_file) { FILE* fileout = nullptr; FilePath output_file_path(output_file); FilePath output_dir(output_file_path.RemoveFileName()); if (output_dir.CreateDirectoriesRecursively()) { fileout = posix::FOpen(output_file.c_str(), "w"); } if (fileout == nullptr) { GTEST_LOG_(FATAL) << "Unable to open file \"" << output_file << "\""; } return fileout; } } // namespace internal // Bazel passes in the argument to '--test_filter' via the TESTBRIDGE_TEST_ONLY // environment variable. static const char* GetDefaultFilter() { const char* const testbridge_test_only = internal::posix::GetEnv("TESTBRIDGE_TEST_ONLY"); if (testbridge_test_only != nullptr) { return testbridge_test_only; } return kUniversalFilter; } GTEST_DEFINE_bool_( also_run_disabled_tests, internal::BoolFromGTestEnv("also_run_disabled_tests", false), "Run disabled tests too, in addition to the tests normally being run."); GTEST_DEFINE_bool_( break_on_failure, internal::BoolFromGTestEnv("break_on_failure", false), "True if and only if a failed assertion should be a debugger " "break-point."); GTEST_DEFINE_bool_(catch_exceptions, internal::BoolFromGTestEnv("catch_exceptions", true), "True if and only if " GTEST_NAME_ " should catch exceptions and treat them as test failures."); GTEST_DEFINE_string_( color, internal::StringFromGTestEnv("color", "auto"), "Whether to use colors in the output. Valid values: yes, no, " "and auto. 
'auto' means to use colors if the output is " "being sent to a terminal and the TERM environment variable " "is set to a terminal type that supports colors."); GTEST_DEFINE_string_( filter, internal::StringFromGTestEnv("filter", GetDefaultFilter()), "A colon-separated list of glob (not regex) patterns " "for filtering the tests to run, optionally followed by a " "'-' and a : separated list of negative patterns (tests to " "exclude). A test is run if it matches one of the positive " "patterns and does not match any of the negative patterns."); GTEST_DEFINE_bool_( install_failure_signal_handler, internal::BoolFromGTestEnv("install_failure_signal_handler", false), "If true and supported on the current platform, " GTEST_NAME_ " should " "install a signal handler that dumps debugging information when fatal " "signals are raised."); GTEST_DEFINE_bool_(list_tests, false, "List all tests without running them."); // The net priority order after flag processing is thus: // --gtest_output command line flag // GTEST_OUTPUT environment variable // XML_OUTPUT_FILE environment variable // '' GTEST_DEFINE_string_( output, internal::StringFromGTestEnv("output", internal::OutputFlagAlsoCheckEnvVar().c_str()), "A format (defaults to \"xml\" but can be specified to be \"json\"), " "optionally followed by a colon and an output file name or directory. " "A directory is indicated by a trailing pathname separator. " "Examples: \"xml:filename.xml\", \"xml::directoryname/\". 
" "If a directory is specified, output files will be created " "within that directory, with file-names based on the test " "executable's name and, if necessary, made unique by adding " "digits."); GTEST_DEFINE_bool_(print_time, internal::BoolFromGTestEnv("print_time", true), "True if and only if " GTEST_NAME_ " should display elapsed time in text output."); GTEST_DEFINE_bool_(print_utf8, internal::BoolFromGTestEnv("print_utf8", true), "True if and only if " GTEST_NAME_ " prints UTF8 characters as text."); GTEST_DEFINE_int32_( random_seed, internal::Int32FromGTestEnv("random_seed", 0), "Random number seed to use when shuffling test orders. Must be in range " "[1, 99999], or 0 to use a seed based on the current time."); GTEST_DEFINE_int32_( repeat, internal::Int32FromGTestEnv("repeat", 1), "How many times to repeat each test. Specify a negative number " "for repeating forever. Useful for shaking out flaky tests."); GTEST_DEFINE_bool_(show_internal_stack_frames, false, "True if and only if " GTEST_NAME_ " should include internal stack frames when " "printing test failure stack traces."); GTEST_DEFINE_bool_(shuffle, internal::BoolFromGTestEnv("shuffle", false), "True if and only if " GTEST_NAME_ " should randomize tests' order on every run."); GTEST_DEFINE_int32_( stack_trace_depth, internal::Int32FromGTestEnv("stack_trace_depth", kMaxStackTraceDepth), "The maximum number of stack frames to print when an " "assertion fails. The valid range is 0 through 100, inclusive."); GTEST_DEFINE_string_( stream_result_to, internal::StringFromGTestEnv("stream_result_to", ""), "This flag specifies the host name and the port number on which to stream " "test results. Example: \"localhost:555\". The flag is effective only on " "Linux."); GTEST_DEFINE_bool_( throw_on_failure, internal::BoolFromGTestEnv("throw_on_failure", false), "When this flag is specified, a failed assertion will throw an exception " "if exceptions are enabled or exit the program with a non-zero code " "otherwise. 
For use with an external test framework."); GTEST_DEFINE_bool_( print_skipped, internal::BoolFromGTestEnv("print_skipped", false), "When this flag is specified, list of skipped test names is printed in " "summary"); #if GTEST_USE_OWN_FLAGFILE_FLAG_ GTEST_DEFINE_string_( flagfile, internal::StringFromGTestEnv("flagfile", ""), "This flag specifies the flagfile to read command-line flags from."); #endif // GTEST_USE_OWN_FLAGFILE_FLAG_ namespace internal { // Generates a random number from [0, range), using a Linear // Congruential Generator (LCG). Crashes if 'range' is 0 or greater // than kMaxRange. UInt32 Random::Generate(UInt32 range) { // These constants are the same as are used in glibc's rand(3). // Use wider types than necessary to prevent unsigned overflow diagnostics. state_ = static_cast(1103515245ULL*state_ + 12345U) % kMaxRange; GTEST_CHECK_(range > 0) << "Cannot generate a number in the range [0, 0)."; GTEST_CHECK_(range <= kMaxRange) << "Generation of a number in [0, " << range << ") was requested, " << "but this can only generate numbers in [0, " << kMaxRange << ")."; // Converting via modulus introduces a bit of downward bias, but // it's simple, and a linear congruential generator isn't too good // to begin with. return state_ % range; } // GTestIsInitialized() returns true if and only if the user has initialized // Google Test. Useful for catching the user mistake of not initializing // Google Test before calling RUN_ALL_TESTS(). static bool GTestIsInitialized() { return GetArgvs().size() > 0; } // Iterates over a vector of TestSuites, keeping a running sum of the // results of calling a given int-returning method on each. // Returns the sum. static int SumOverTestSuiteList(const std::vector& case_list, int (TestSuite::*method)() const) { int sum = 0; for (size_t i = 0; i < case_list.size(); i++) { sum += (case_list[i]->*method)(); } return sum; } // Returns true if and only if the test suite passed. 
static bool TestSuitePassed(const TestSuite* test_suite) { return test_suite->should_run() && test_suite->Passed(); } // Returns true if and only if the test suite failed. static bool TestSuiteFailed(const TestSuite* test_suite) { return test_suite->should_run() && test_suite->Failed(); } // Returns true if and only if test_suite contains at least one test that // should run. static bool ShouldRunTestSuite(const TestSuite* test_suite) { return test_suite->should_run(); } // AssertHelper constructor. AssertHelper::AssertHelper(TestPartResult::Type type, const char* file, int line, const char* message) : data_(new AssertHelperData(type, file, line, message)) { } AssertHelper::~AssertHelper() { delete data_; } // Message assignment, for assertion streaming support. void AssertHelper::operator=(const Message& message) const { UnitTest::GetInstance()-> AddTestPartResult(data_->type, data_->file, data_->line, AppendUserMessage(data_->message, message), UnitTest::GetInstance()->impl() ->CurrentOsStackTraceExceptTop(1) // Skips the stack frame for this function itself. ); // NOLINT } // A copy of all command line arguments. Set by InitGoogleTest(). static ::std::vector g_argvs; ::std::vector GetArgvs() { #if defined(GTEST_CUSTOM_GET_ARGVS_) // GTEST_CUSTOM_GET_ARGVS_() may return a container of std::string or // ::string. This code converts it to the appropriate type. const auto& custom = GTEST_CUSTOM_GET_ARGVS_(); return ::std::vector(custom.begin(), custom.end()); #else // defined(GTEST_CUSTOM_GET_ARGVS_) return g_argvs; #endif // defined(GTEST_CUSTOM_GET_ARGVS_) } // Returns the current application's name, removing directory path if that // is present. FilePath GetCurrentExecutableName() { FilePath result; #if GTEST_OS_WINDOWS || GTEST_OS_OS2 result.Set(FilePath(GetArgvs()[0]).RemoveExtension("exe")); #else result.Set(FilePath(GetArgvs()[0])); #endif // GTEST_OS_WINDOWS return result.RemoveDirectoryName(); } // Functions for processing the gtest_output flag. 
// Returns the output format, or "" for normal printed output. std::string UnitTestOptions::GetOutputFormat() { const char* const gtest_output_flag = GTEST_FLAG(output).c_str(); const char* const colon = strchr(gtest_output_flag, ':'); return (colon == nullptr) ? std::string(gtest_output_flag) : std::string(gtest_output_flag, static_cast(colon - gtest_output_flag)); } // Returns the name of the requested output file, or the default if none // was explicitly specified. std::string UnitTestOptions::GetAbsolutePathToOutputFile() { const char* const gtest_output_flag = GTEST_FLAG(output).c_str(); std::string format = GetOutputFormat(); if (format.empty()) format = std::string(kDefaultOutputFormat); const char* const colon = strchr(gtest_output_flag, ':'); if (colon == nullptr) return internal::FilePath::MakeFileName( internal::FilePath( UnitTest::GetInstance()->original_working_dir()), internal::FilePath(kDefaultOutputFile), 0, format.c_str()).string(); internal::FilePath output_name(colon + 1); if (!output_name.IsAbsolutePath()) output_name = internal::FilePath::ConcatPaths( internal::FilePath(UnitTest::GetInstance()->original_working_dir()), internal::FilePath(colon + 1)); if (!output_name.IsDirectory()) return output_name.string(); internal::FilePath result(internal::FilePath::GenerateUniqueFileName( output_name, internal::GetCurrentExecutableName(), GetOutputFormat().c_str())); return result.string(); } // Returns true if and only if the wildcard pattern matches the string. // The first ':' or '\0' character in pattern marks the end of it. // // This recursive algorithm isn't very efficient, but is clear and // works well enough for matching test names, which are short. bool UnitTestOptions::PatternMatchesString(const char *pattern, const char *str) { switch (*pattern) { case '\0': case ':': // Either ':' or '\0' marks the end of the pattern. return *str == '\0'; case '?': // Matches any single character. 
return *str != '\0' && PatternMatchesString(pattern + 1, str + 1); case '*': // Matches any string (possibly empty) of characters. return (*str != '\0' && PatternMatchesString(pattern, str + 1)) || PatternMatchesString(pattern + 1, str); default: // Non-special character. Matches itself. return *pattern == *str && PatternMatchesString(pattern + 1, str + 1); } } bool UnitTestOptions::MatchesFilter( const std::string& name, const char* filter) { const char *cur_pattern = filter; for (;;) { if (PatternMatchesString(cur_pattern, name.c_str())) { return true; } // Finds the next pattern in the filter. cur_pattern = strchr(cur_pattern, ':'); // Returns if no more pattern can be found. if (cur_pattern == nullptr) { return false; } // Skips the pattern separater (the ':' character). cur_pattern++; } } // Returns true if and only if the user-specified filter matches the test // suite name and the test name. bool UnitTestOptions::FilterMatchesTest(const std::string& test_suite_name, const std::string& test_name) { const std::string& full_name = test_suite_name + "." + test_name.c_str(); // Split --gtest_filter at '-', if there is one, to separate into // positive filter and negative filter portions const char* const p = GTEST_FLAG(filter).c_str(); const char* const dash = strchr(p, '-'); std::string positive; std::string negative; if (dash == nullptr) { positive = GTEST_FLAG(filter).c_str(); // Whole string is a positive filter negative = ""; } else { positive = std::string(p, dash); // Everything up to the dash negative = std::string(dash + 1); // Everything after the dash if (positive.empty()) { // Treat '-test1' as the same as '*-test1' positive = kUniversalFilter; } } // A filter is a colon-separated list of patterns. It matches a // test if any pattern in it matches the test. 
return (MatchesFilter(full_name, positive.c_str()) && !MatchesFilter(full_name, negative.c_str())); } #if GTEST_HAS_SEH // Returns EXCEPTION_EXECUTE_HANDLER if Google Test should handle the // given SEH exception, or EXCEPTION_CONTINUE_SEARCH otherwise. // This function is useful as an __except condition. int UnitTestOptions::GTestShouldProcessSEH(DWORD exception_code) { // Google Test should handle a SEH exception if: // 1. the user wants it to, AND // 2. this is not a breakpoint exception, AND // 3. this is not a C++ exception (VC++ implements them via SEH, // apparently). // // SEH exception code for C++ exceptions. // (see http://support.microsoft.com/kb/185294 for more information). const DWORD kCxxExceptionCode = 0xe06d7363; bool should_handle = true; if (!GTEST_FLAG(catch_exceptions)) should_handle = false; else if (exception_code == EXCEPTION_BREAKPOINT) should_handle = false; else if (exception_code == kCxxExceptionCode) should_handle = false; return should_handle ? EXCEPTION_EXECUTE_HANDLER : EXCEPTION_CONTINUE_SEARCH; } #endif // GTEST_HAS_SEH } // namespace internal // The c'tor sets this object as the test part result reporter used by // Google Test. The 'result' parameter specifies where to report the // results. Intercepts only failures from the current thread. ScopedFakeTestPartResultReporter::ScopedFakeTestPartResultReporter( TestPartResultArray* result) : intercept_mode_(INTERCEPT_ONLY_CURRENT_THREAD), result_(result) { Init(); } // The c'tor sets this object as the test part result reporter used by // Google Test. The 'result' parameter specifies where to report the // results. 
ScopedFakeTestPartResultReporter::ScopedFakeTestPartResultReporter( InterceptMode intercept_mode, TestPartResultArray* result) : intercept_mode_(intercept_mode), result_(result) { Init(); } void ScopedFakeTestPartResultReporter::Init() { internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); if (intercept_mode_ == INTERCEPT_ALL_THREADS) { old_reporter_ = impl->GetGlobalTestPartResultReporter(); impl->SetGlobalTestPartResultReporter(this); } else { old_reporter_ = impl->GetTestPartResultReporterForCurrentThread(); impl->SetTestPartResultReporterForCurrentThread(this); } } // The d'tor restores the test part result reporter used by Google Test // before. ScopedFakeTestPartResultReporter::~ScopedFakeTestPartResultReporter() { internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); if (intercept_mode_ == INTERCEPT_ALL_THREADS) { impl->SetGlobalTestPartResultReporter(old_reporter_); } else { impl->SetTestPartResultReporterForCurrentThread(old_reporter_); } } // Increments the test part result count and remembers the result. // This method is from the TestPartResultReporterInterface interface. void ScopedFakeTestPartResultReporter::ReportTestPartResult( const TestPartResult& result) { result_->Append(result); } namespace internal { // Returns the type ID of ::testing::Test. We should always call this // instead of GetTypeId< ::testing::Test>() to get the type ID of // testing::Test. This is to work around a suspected linker bug when // using Google Test as a framework on Mac OS X. The bug causes // GetTypeId< ::testing::Test>() to return different values depending // on whether the call is from the Google Test framework itself or // from user test code. GetTestTypeId() is guaranteed to always // return the same value, as it always calls GetTypeId<>() from the // gtest.cc, which is within the Google Test framework. TypeId GetTestTypeId() { return GetTypeId(); } // The value of GetTestTypeId() as seen from within the Google Test // library. 
This is solely for testing GetTestTypeId(). extern const TypeId kTestTypeIdInGoogleTest = GetTestTypeId(); // This predicate-formatter checks that 'results' contains a test part // failure of the given type and that the failure message contains the // given substring. static AssertionResult HasOneFailure(const char* /* results_expr */, const char* /* type_expr */, const char* /* substr_expr */, const TestPartResultArray& results, TestPartResult::Type type, const std::string& substr) { const std::string expected(type == TestPartResult::kFatalFailure ? "1 fatal failure" : "1 non-fatal failure"); Message msg; if (results.size() != 1) { msg << "Expected: " << expected << "\n" << " Actual: " << results.size() << " failures"; for (int i = 0; i < results.size(); i++) { msg << "\n" << results.GetTestPartResult(i); } return AssertionFailure() << msg; } const TestPartResult& r = results.GetTestPartResult(0); if (r.type() != type) { return AssertionFailure() << "Expected: " << expected << "\n" << " Actual:\n" << r; } if (strstr(r.message(), substr.c_str()) == nullptr) { return AssertionFailure() << "Expected: " << expected << " containing \"" << substr << "\"\n" << " Actual:\n" << r; } return AssertionSuccess(); } // The constructor of SingleFailureChecker remembers where to look up // test part results, what type of failure we expect, and what // substring the failure message should contain. SingleFailureChecker::SingleFailureChecker(const TestPartResultArray* results, TestPartResult::Type type, const std::string& substr) : results_(results), type_(type), substr_(substr) {} // The destructor of SingleFailureChecker verifies that the given // TestPartResultArray contains exactly one failure that has the given // type and contains the given substring. If that's not the case, a // non-fatal failure will be generated. 
SingleFailureChecker::~SingleFailureChecker() { EXPECT_PRED_FORMAT3(HasOneFailure, *results_, type_, substr_); } DefaultGlobalTestPartResultReporter::DefaultGlobalTestPartResultReporter( UnitTestImpl* unit_test) : unit_test_(unit_test) {} void DefaultGlobalTestPartResultReporter::ReportTestPartResult( const TestPartResult& result) { unit_test_->current_test_result()->AddTestPartResult(result); unit_test_->listeners()->repeater()->OnTestPartResult(result); } DefaultPerThreadTestPartResultReporter::DefaultPerThreadTestPartResultReporter( UnitTestImpl* unit_test) : unit_test_(unit_test) {} void DefaultPerThreadTestPartResultReporter::ReportTestPartResult( const TestPartResult& result) { unit_test_->GetGlobalTestPartResultReporter()->ReportTestPartResult(result); } // Returns the global test part result reporter. TestPartResultReporterInterface* UnitTestImpl::GetGlobalTestPartResultReporter() { internal::MutexLock lock(&global_test_part_result_reporter_mutex_); return global_test_part_result_repoter_; } // Sets the global test part result reporter. void UnitTestImpl::SetGlobalTestPartResultReporter( TestPartResultReporterInterface* reporter) { internal::MutexLock lock(&global_test_part_result_reporter_mutex_); global_test_part_result_repoter_ = reporter; } // Returns the test part result reporter for the current thread. TestPartResultReporterInterface* UnitTestImpl::GetTestPartResultReporterForCurrentThread() { return per_thread_test_part_result_reporter_.get(); } // Sets the test part result reporter for the current thread. void UnitTestImpl::SetTestPartResultReporterForCurrentThread( TestPartResultReporterInterface* reporter) { per_thread_test_part_result_reporter_.set(reporter); } // Gets the number of successful test suites. int UnitTestImpl::successful_test_suite_count() const { return CountIf(test_suites_, TestSuitePassed); } // Gets the number of failed test suites. 
int UnitTestImpl::failed_test_suite_count() const { return CountIf(test_suites_, TestSuiteFailed); } // Gets the number of all test suites. int UnitTestImpl::total_test_suite_count() const { return static_cast(test_suites_.size()); } // Gets the number of all test suites that contain at least one test // that should run. int UnitTestImpl::test_suite_to_run_count() const { return CountIf(test_suites_, ShouldRunTestSuite); } // Gets the number of successful tests. int UnitTestImpl::successful_test_count() const { return SumOverTestSuiteList(test_suites_, &TestSuite::successful_test_count); } // Gets the number of skipped tests. int UnitTestImpl::skipped_test_count() const { return SumOverTestSuiteList(test_suites_, &TestSuite::skipped_test_count); } // Gets the number of failed tests. int UnitTestImpl::failed_test_count() const { return SumOverTestSuiteList(test_suites_, &TestSuite::failed_test_count); } // Gets the number of disabled tests that will be reported in the XML report. int UnitTestImpl::reportable_disabled_test_count() const { return SumOverTestSuiteList(test_suites_, &TestSuite::reportable_disabled_test_count); } // Gets the number of disabled tests. int UnitTestImpl::disabled_test_count() const { return SumOverTestSuiteList(test_suites_, &TestSuite::disabled_test_count); } // Gets the number of tests to be printed in the XML report. int UnitTestImpl::reportable_test_count() const { return SumOverTestSuiteList(test_suites_, &TestSuite::reportable_test_count); } // Gets the number of all tests. int UnitTestImpl::total_test_count() const { return SumOverTestSuiteList(test_suites_, &TestSuite::total_test_count); } // Gets the number of tests that should run. int UnitTestImpl::test_to_run_count() const { return SumOverTestSuiteList(test_suites_, &TestSuite::test_to_run_count); } // Returns the current OS stack trace as an std::string. // // The maximum number of stack frames to be included is specified by // the gtest_stack_trace_depth flag. 
The skip_count parameter // specifies the number of top frames to be skipped, which doesn't // count against the number of frames to be included. // // For example, if Foo() calls Bar(), which in turn calls // CurrentOsStackTraceExceptTop(1), Foo() will be included in the // trace but Bar() and CurrentOsStackTraceExceptTop() won't. std::string UnitTestImpl::CurrentOsStackTraceExceptTop(int skip_count) { return os_stack_trace_getter()->CurrentStackTrace( static_cast(GTEST_FLAG(stack_trace_depth)), skip_count + 1 // Skips the user-specified number of frames plus this function // itself. ); // NOLINT } // Returns the current time in milliseconds. TimeInMillis GetTimeInMillis() { #if GTEST_OS_WINDOWS_MOBILE || defined(__BORLANDC__) // Difference between 1970-01-01 and 1601-01-01 in milliseconds. // http://analogous.blogspot.com/2005/04/epoch.html const TimeInMillis kJavaEpochToWinFileTimeDelta = static_cast(116444736UL) * 100000UL; const DWORD kTenthMicrosInMilliSecond = 10000; SYSTEMTIME now_systime; FILETIME now_filetime; ULARGE_INTEGER now_int64; GetSystemTime(&now_systime); if (SystemTimeToFileTime(&now_systime, &now_filetime)) { now_int64.LowPart = now_filetime.dwLowDateTime; now_int64.HighPart = now_filetime.dwHighDateTime; now_int64.QuadPart = (now_int64.QuadPart / kTenthMicrosInMilliSecond) - kJavaEpochToWinFileTimeDelta; return now_int64.QuadPart; } return 0; #elif GTEST_OS_WINDOWS && !GTEST_HAS_GETTIMEOFDAY_ __timeb64 now; // MSVC 8 deprecates _ftime64(), so we want to suppress warning 4996 // (deprecated function) there. GTEST_DISABLE_MSC_DEPRECATED_PUSH_() _ftime64(&now); GTEST_DISABLE_MSC_DEPRECATED_POP_() return static_cast(now.time) * 1000 + now.millitm; #elif GTEST_HAS_GETTIMEOFDAY_ struct timeval now; gettimeofday(&now, nullptr); return static_cast(now.tv_sec) * 1000 + now.tv_usec / 1000; #else # error "Don't know how to get the current time on your system." #endif } // Utilities // class String. 
#if GTEST_OS_WINDOWS_MOBILE // Creates a UTF-16 wide string from the given ANSI string, allocating // memory using new. The caller is responsible for deleting the return // value using delete[]. Returns the wide string, or NULL if the // input is NULL. LPCWSTR String::AnsiToUtf16(const char* ansi) { if (!ansi) return nullptr; const int length = strlen(ansi); const int unicode_length = MultiByteToWideChar(CP_ACP, 0, ansi, length, nullptr, 0); WCHAR* unicode = new WCHAR[unicode_length + 1]; MultiByteToWideChar(CP_ACP, 0, ansi, length, unicode, unicode_length); unicode[unicode_length] = 0; return unicode; } // Creates an ANSI string from the given wide string, allocating // memory using new. The caller is responsible for deleting the return // value using delete[]. Returns the ANSI string, or NULL if the // input is NULL. const char* String::Utf16ToAnsi(LPCWSTR utf16_str) { if (!utf16_str) return nullptr; const int ansi_length = WideCharToMultiByte(CP_ACP, 0, utf16_str, -1, nullptr, 0, nullptr, nullptr); char* ansi = new char[ansi_length + 1]; WideCharToMultiByte(CP_ACP, 0, utf16_str, -1, ansi, ansi_length, nullptr, nullptr); ansi[ansi_length] = 0; return ansi; } #endif // GTEST_OS_WINDOWS_MOBILE // Compares two C strings. Returns true if and only if they have the same // content. // // Unlike strcmp(), this function can handle NULL argument(s). A NULL // C string is considered different to any non-NULL C string, // including the empty string. bool String::CStringEquals(const char * lhs, const char * rhs) { if (lhs == nullptr) return rhs == nullptr; if (rhs == nullptr) return false; return strcmp(lhs, rhs) == 0; } #if GTEST_HAS_STD_WSTRING // Converts an array of wide chars to a narrow string using the UTF-8 // encoding, and streams the result to the given Message object. 
static void StreamWideCharsToMessage(const wchar_t* wstr, size_t length, Message* msg) { for (size_t i = 0; i != length; ) { // NOLINT if (wstr[i] != L'\0') { *msg << WideStringToUtf8(wstr + i, static_cast(length - i)); while (i != length && wstr[i] != L'\0') i++; } else { *msg << '\0'; i++; } } } #endif // GTEST_HAS_STD_WSTRING void SplitString(const ::std::string& str, char delimiter, ::std::vector< ::std::string>* dest) { ::std::vector< ::std::string> parsed; ::std::string::size_type pos = 0; while (::testing::internal::AlwaysTrue()) { const ::std::string::size_type colon = str.find(delimiter, pos); if (colon == ::std::string::npos) { parsed.push_back(str.substr(pos)); break; } else { parsed.push_back(str.substr(pos, colon - pos)); pos = colon + 1; } } dest->swap(parsed); } } // namespace internal // Constructs an empty Message. // We allocate the stringstream separately because otherwise each use of // ASSERT/EXPECT in a procedure adds over 200 bytes to the procedure's // stack frame leading to huge stack frames in some cases; gcc does not reuse // the stack space. Message::Message() : ss_(new ::std::stringstream) { // By default, we want there to be enough precision when printing // a double to a Message. *ss_ << std::setprecision(std::numeric_limits::digits10 + 2); } // These two overloads allow streaming a wide C string to a Message // using the UTF-8 encoding. Message& Message::operator <<(const wchar_t* wide_c_str) { return *this << internal::String::ShowWideCString(wide_c_str); } Message& Message::operator <<(wchar_t* wide_c_str) { return *this << internal::String::ShowWideCString(wide_c_str); } #if GTEST_HAS_STD_WSTRING // Converts the given wide string to a narrow string using the UTF-8 // encoding, and streams the result to this Message object. 
Message& Message::operator <<(const ::std::wstring& wstr) { internal::StreamWideCharsToMessage(wstr.c_str(), wstr.length(), this); return *this; } #endif // GTEST_HAS_STD_WSTRING // Gets the text streamed to this object so far as an std::string. // Each '\0' character in the buffer is replaced with "\\0". std::string Message::GetString() const { return internal::StringStreamToString(ss_.get()); } // AssertionResult constructors. // Used in EXPECT_TRUE/FALSE(assertion_result). AssertionResult::AssertionResult(const AssertionResult& other) : success_(other.success_), message_(other.message_.get() != nullptr ? new ::std::string(*other.message_) : static_cast< ::std::string*>(nullptr)) {} // Swaps two AssertionResults. void AssertionResult::swap(AssertionResult& other) { using std::swap; swap(success_, other.success_); swap(message_, other.message_); } // Returns the assertion's negation. Used with EXPECT/ASSERT_FALSE. AssertionResult AssertionResult::operator!() const { AssertionResult negation(!success_); if (message_.get() != nullptr) negation << *message_; return negation; } // Makes a successful assertion result. AssertionResult AssertionSuccess() { return AssertionResult(true); } // Makes a failed assertion result. AssertionResult AssertionFailure() { return AssertionResult(false); } // Makes a failed assertion result with the given failure message. // Deprecated; use AssertionFailure() << message. AssertionResult AssertionFailure(const Message& message) { return AssertionFailure() << message; } namespace internal { namespace edit_distance { std::vector CalculateOptimalEdits(const std::vector& left, const std::vector& right) { std::vector > costs( left.size() + 1, std::vector(right.size() + 1)); std::vector > best_move( left.size() + 1, std::vector(right.size() + 1)); // Populate for empty right. for (size_t l_i = 0; l_i < costs.size(); ++l_i) { costs[l_i][0] = static_cast(l_i); best_move[l_i][0] = kRemove; } // Populate for empty left. 
for (size_t r_i = 1; r_i < costs[0].size(); ++r_i) { costs[0][r_i] = static_cast(r_i); best_move[0][r_i] = kAdd; } for (size_t l_i = 0; l_i < left.size(); ++l_i) { for (size_t r_i = 0; r_i < right.size(); ++r_i) { if (left[l_i] == right[r_i]) { // Found a match. Consume it. costs[l_i + 1][r_i + 1] = costs[l_i][r_i]; best_move[l_i + 1][r_i + 1] = kMatch; continue; } const double add = costs[l_i + 1][r_i]; const double remove = costs[l_i][r_i + 1]; const double replace = costs[l_i][r_i]; if (add < remove && add < replace) { costs[l_i + 1][r_i + 1] = add + 1; best_move[l_i + 1][r_i + 1] = kAdd; } else if (remove < add && remove < replace) { costs[l_i + 1][r_i + 1] = remove + 1; best_move[l_i + 1][r_i + 1] = kRemove; } else { // We make replace a little more expensive than add/remove to lower // their priority. costs[l_i + 1][r_i + 1] = replace + 1.00001; best_move[l_i + 1][r_i + 1] = kReplace; } } } // Reconstruct the best path. We do it in reverse order. std::vector best_path; for (size_t l_i = left.size(), r_i = right.size(); l_i > 0 || r_i > 0;) { EditType move = best_move[l_i][r_i]; best_path.push_back(move); l_i -= move != kAdd; r_i -= move != kRemove; } std::reverse(best_path.begin(), best_path.end()); return best_path; } namespace { // Helper class to convert string into ids with deduplication. 
class InternalStrings { public: size_t GetId(const std::string& str) { IdMap::iterator it = ids_.find(str); if (it != ids_.end()) return it->second; size_t id = ids_.size(); return ids_[str] = id; } private: typedef std::map IdMap; IdMap ids_; }; } // namespace std::vector CalculateOptimalEdits( const std::vector& left, const std::vector& right) { std::vector left_ids, right_ids; { InternalStrings intern_table; for (size_t i = 0; i < left.size(); ++i) { left_ids.push_back(intern_table.GetId(left[i])); } for (size_t i = 0; i < right.size(); ++i) { right_ids.push_back(intern_table.GetId(right[i])); } } return CalculateOptimalEdits(left_ids, right_ids); } namespace { // Helper class that holds the state for one hunk and prints it out to the // stream. // It reorders adds/removes when possible to group all removes before all // adds. It also adds the hunk header before printint into the stream. class Hunk { public: Hunk(size_t left_start, size_t right_start) : left_start_(left_start), right_start_(right_start), adds_(), removes_(), common_() {} void PushLine(char edit, const char* line) { switch (edit) { case ' ': ++common_; FlushEdits(); hunk_.push_back(std::make_pair(' ', line)); break; case '-': ++removes_; hunk_removes_.push_back(std::make_pair('-', line)); break; case '+': ++adds_; hunk_adds_.push_back(std::make_pair('+', line)); break; } } void PrintTo(std::ostream* os) { PrintHeader(os); FlushEdits(); for (std::list >::const_iterator it = hunk_.begin(); it != hunk_.end(); ++it) { *os << it->first << it->second << "\n"; } } bool has_edits() const { return adds_ || removes_; } private: void FlushEdits() { hunk_.splice(hunk_.end(), hunk_removes_); hunk_.splice(hunk_.end(), hunk_adds_); } // Print a unified diff header for one hunk. // The format is // "@@ -, +, @@" // where the left/right parts are omitted if unnecessary. 
void PrintHeader(std::ostream* ss) const { *ss << "@@ "; if (removes_) { *ss << "-" << left_start_ << "," << (removes_ + common_); } if (removes_ && adds_) { *ss << " "; } if (adds_) { *ss << "+" << right_start_ << "," << (adds_ + common_); } *ss << " @@\n"; } size_t left_start_, right_start_; size_t adds_, removes_, common_; std::list > hunk_, hunk_adds_, hunk_removes_; }; } // namespace // Create a list of diff hunks in Unified diff format. // Each hunk has a header generated by PrintHeader above plus a body with // lines prefixed with ' ' for no change, '-' for deletion and '+' for // addition. // 'context' represents the desired unchanged prefix/suffix around the diff. // If two hunks are close enough that their contexts overlap, then they are // joined into one hunk. std::string CreateUnifiedDiff(const std::vector& left, const std::vector& right, size_t context) { const std::vector edits = CalculateOptimalEdits(left, right); size_t l_i = 0, r_i = 0, edit_i = 0; std::stringstream ss; while (edit_i < edits.size()) { // Find first edit. while (edit_i < edits.size() && edits[edit_i] == kMatch) { ++l_i; ++r_i; ++edit_i; } // Find the first line to include in the hunk. const size_t prefix_context = std::min(l_i, context); Hunk hunk(l_i - prefix_context + 1, r_i - prefix_context + 1); for (size_t i = prefix_context; i > 0; --i) { hunk.PushLine(' ', left[l_i - i].c_str()); } // Iterate the edits until we found enough suffix for the hunk or the input // is over. size_t n_suffix = 0; for (; edit_i < edits.size(); ++edit_i) { if (n_suffix >= context) { // Continue only if the next hunk is very close. auto it = edits.begin() + static_cast(edit_i); while (it != edits.end() && *it == kMatch) ++it; if (it == edits.end() || static_cast(it - edits.begin()) - edit_i >= context) { // There is no next edit or it is too far away. break; } } EditType edit = edits[edit_i]; // Reset count when a non match is found. n_suffix = edit == kMatch ? 
n_suffix + 1 : 0; if (edit == kMatch || edit == kRemove || edit == kReplace) { hunk.PushLine(edit == kMatch ? ' ' : '-', left[l_i].c_str()); } if (edit == kAdd || edit == kReplace) { hunk.PushLine('+', right[r_i].c_str()); } // Advance indices, depending on edit type. l_i += edit != kAdd; r_i += edit != kRemove; } if (!hunk.has_edits()) { // We are done. We don't want this hunk. break; } hunk.PrintTo(&ss); } return ss.str(); } } // namespace edit_distance namespace { // The string representation of the values received in EqFailure() are already // escaped. Split them on escaped '\n' boundaries. Leave all other escaped // characters the same. std::vector SplitEscapedString(const std::string& str) { std::vector lines; size_t start = 0, end = str.size(); if (end > 2 && str[0] == '"' && str[end - 1] == '"') { ++start; --end; } bool escaped = false; for (size_t i = start; i + 1 < end; ++i) { if (escaped) { escaped = false; if (str[i] == 'n') { lines.push_back(str.substr(start, i - start - 1)); start = i + 1; } } else { escaped = str[i] == '\\'; } } lines.push_back(str.substr(start, end - start)); return lines; } } // namespace // Constructs and returns the message for an equality assertion // (e.g. ASSERT_EQ, EXPECT_STREQ, etc) failure. // // The first four parameters are the expressions used in the assertion // and their values, as strings. For example, for ASSERT_EQ(foo, bar) // where foo is 5 and bar is 6, we have: // // lhs_expression: "foo" // rhs_expression: "bar" // lhs_value: "5" // rhs_value: "6" // // The ignoring_case parameter is true if and only if the assertion is a // *_STRCASEEQ*. When it's true, the string "Ignoring case" will // be inserted into the message. 
AssertionResult EqFailure(const char* lhs_expression, const char* rhs_expression, const std::string& lhs_value, const std::string& rhs_value, bool ignoring_case) { Message msg; msg << "Expected equality of these values:"; msg << "\n " << lhs_expression; if (lhs_value != lhs_expression) { msg << "\n Which is: " << lhs_value; } msg << "\n " << rhs_expression; if (rhs_value != rhs_expression) { msg << "\n Which is: " << rhs_value; } if (ignoring_case) { msg << "\nIgnoring case"; } if (!lhs_value.empty() && !rhs_value.empty()) { const std::vector lhs_lines = SplitEscapedString(lhs_value); const std::vector rhs_lines = SplitEscapedString(rhs_value); if (lhs_lines.size() > 1 || rhs_lines.size() > 1) { msg << "\nWith diff:\n" << edit_distance::CreateUnifiedDiff(lhs_lines, rhs_lines); } } return AssertionFailure() << msg; } // Constructs a failure message for Boolean assertions such as EXPECT_TRUE. std::string GetBoolAssertionFailureMessage( const AssertionResult& assertion_result, const char* expression_text, const char* actual_predicate_value, const char* expected_predicate_value) { const char* actual_message = assertion_result.message(); Message msg; msg << "Value of: " << expression_text << "\n Actual: " << actual_predicate_value; if (actual_message[0] != '\0') msg << " (" << actual_message << ")"; msg << "\nExpected: " << expected_predicate_value; return msg.GetString(); } // Helper function for implementing ASSERT_NEAR. 
AssertionResult DoubleNearPredFormat(const char* expr1, const char* expr2, const char* abs_error_expr, double val1, double val2, double abs_error) { const double diff = fabs(val1 - val2); if (diff <= abs_error) return AssertionSuccess(); return AssertionFailure() << "The difference between " << expr1 << " and " << expr2 << " is " << diff << ", which exceeds " << abs_error_expr << ", where\n" << expr1 << " evaluates to " << val1 << ",\n" << expr2 << " evaluates to " << val2 << ", and\n" << abs_error_expr << " evaluates to " << abs_error << "."; } // Helper template for implementing FloatLE() and DoubleLE(). template AssertionResult FloatingPointLE(const char* expr1, const char* expr2, RawType val1, RawType val2) { // Returns success if val1 is less than val2, if (val1 < val2) { return AssertionSuccess(); } // or if val1 is almost equal to val2. const FloatingPoint lhs(val1), rhs(val2); if (lhs.AlmostEquals(rhs)) { return AssertionSuccess(); } // Note that the above two checks will both fail if either val1 or // val2 is NaN, as the IEEE floating-point standard requires that // any predicate involving a NaN must return false. ::std::stringstream val1_ss; val1_ss << std::setprecision(std::numeric_limits::digits10 + 2) << val1; ::std::stringstream val2_ss; val2_ss << std::setprecision(std::numeric_limits::digits10 + 2) << val2; return AssertionFailure() << "Expected: (" << expr1 << ") <= (" << expr2 << ")\n" << " Actual: " << StringStreamToString(&val1_ss) << " vs " << StringStreamToString(&val2_ss); } } // namespace internal // Asserts that val1 is less than, or almost equal to, val2. Fails // otherwise. In particular, it fails if either val1 or val2 is NaN. AssertionResult FloatLE(const char* expr1, const char* expr2, float val1, float val2) { return internal::FloatingPointLE(expr1, expr2, val1, val2); } // Asserts that val1 is less than, or almost equal to, val2. Fails // otherwise. In particular, it fails if either val1 or val2 is NaN. 
AssertionResult DoubleLE(const char* expr1, const char* expr2,
                         double val1, double val2) {
  return internal::FloatingPointLE<double>(expr1, expr2, val1, val2);
}

namespace internal {

// The helper function for {ASSERT|EXPECT}_EQ with int or enum
// arguments.
AssertionResult CmpHelperEQ(const char* lhs_expression,
                            const char* rhs_expression,
                            BiggestInt lhs,
                            BiggestInt rhs) {
  if (lhs == rhs) {
    return AssertionSuccess();
  }

  return EqFailure(lhs_expression,
                   rhs_expression,
                   FormatForComparisonFailureMessage(lhs, rhs),
                   FormatForComparisonFailureMessage(rhs, lhs),
                   false);
}

// A macro for implementing the helper functions needed to implement
// ASSERT_?? and EXPECT_?? with integer or enum arguments.  It is here
// just to avoid copy-and-paste of similar code.
//
// Each expansion defines CmpHelper<op_name>(expr1, expr2, val1, val2),
// which succeeds when 'val1 op val2' holds and otherwise reports both
// expressions and both values.
#define GTEST_IMPL_CMP_HELPER_(op_name, op)\
AssertionResult CmpHelper##op_name(const char* expr1, const char* expr2, \
                                   BiggestInt val1, BiggestInt val2) {\
  if (val1 op val2) {\
    return AssertionSuccess();\
  } else {\
    return AssertionFailure() \
        << "Expected: (" << expr1 << ") " #op " (" << expr2\
        << "), actual: " << FormatForComparisonFailureMessage(val1, val2)\
        << " vs " << FormatForComparisonFailureMessage(val2, val1);\
  }\
}

// Implements the helper function for {ASSERT|EXPECT}_NE with int or
// enum arguments.
GTEST_IMPL_CMP_HELPER_(NE, !=)
// Implements the helper function for {ASSERT|EXPECT}_LE with int or
// enum arguments.
GTEST_IMPL_CMP_HELPER_(LE, <=)
// Implements the helper function for {ASSERT|EXPECT}_LT with int or
// enum arguments.
GTEST_IMPL_CMP_HELPER_(LT, < )
// Implements the helper function for {ASSERT|EXPECT}_GE with int or
// enum arguments.
GTEST_IMPL_CMP_HELPER_(GE, >=)
// Implements the helper function for {ASSERT|EXPECT}_GT with int or
// enum arguments.
GTEST_IMPL_CMP_HELPER_(GT, > )

#undef GTEST_IMPL_CMP_HELPER_

// The helper function for {ASSERT|EXPECT}_STREQ.
AssertionResult CmpHelperSTREQ(const char* lhs_expression, const char* rhs_expression, const char* lhs, const char* rhs) { if (String::CStringEquals(lhs, rhs)) { return AssertionSuccess(); } return EqFailure(lhs_expression, rhs_expression, PrintToString(lhs), PrintToString(rhs), false); } // The helper function for {ASSERT|EXPECT}_STRCASEEQ. AssertionResult CmpHelperSTRCASEEQ(const char* lhs_expression, const char* rhs_expression, const char* lhs, const char* rhs) { if (String::CaseInsensitiveCStringEquals(lhs, rhs)) { return AssertionSuccess(); } return EqFailure(lhs_expression, rhs_expression, PrintToString(lhs), PrintToString(rhs), true); } // The helper function for {ASSERT|EXPECT}_STRNE. AssertionResult CmpHelperSTRNE(const char* s1_expression, const char* s2_expression, const char* s1, const char* s2) { if (!String::CStringEquals(s1, s2)) { return AssertionSuccess(); } else { return AssertionFailure() << "Expected: (" << s1_expression << ") != (" << s2_expression << "), actual: \"" << s1 << "\" vs \"" << s2 << "\""; } } // The helper function for {ASSERT|EXPECT}_STRCASENE. AssertionResult CmpHelperSTRCASENE(const char* s1_expression, const char* s2_expression, const char* s1, const char* s2) { if (!String::CaseInsensitiveCStringEquals(s1, s2)) { return AssertionSuccess(); } else { return AssertionFailure() << "Expected: (" << s1_expression << ") != (" << s2_expression << ") (ignoring case), actual: \"" << s1 << "\" vs \"" << s2 << "\""; } } } // namespace internal namespace { // Helper functions for implementing IsSubString() and IsNotSubstring(). // This group of overloaded functions return true if and only if needle // is a substring of haystack. NULL is considered a substring of // itself only. 
bool IsSubstringPred(const char* needle, const char* haystack) { if (needle == nullptr || haystack == nullptr) return needle == haystack; return strstr(haystack, needle) != nullptr; } bool IsSubstringPred(const wchar_t* needle, const wchar_t* haystack) { if (needle == nullptr || haystack == nullptr) return needle == haystack; return wcsstr(haystack, needle) != nullptr; } // StringType here can be either ::std::string or ::std::wstring. template bool IsSubstringPred(const StringType& needle, const StringType& haystack) { return haystack.find(needle) != StringType::npos; } // This function implements either IsSubstring() or IsNotSubstring(), // depending on the value of the expected_to_be_substring parameter. // StringType here can be const char*, const wchar_t*, ::std::string, // or ::std::wstring. template AssertionResult IsSubstringImpl( bool expected_to_be_substring, const char* needle_expr, const char* haystack_expr, const StringType& needle, const StringType& haystack) { if (IsSubstringPred(needle, haystack) == expected_to_be_substring) return AssertionSuccess(); const bool is_wide_string = sizeof(needle[0]) > 1; const char* const begin_string_quote = is_wide_string ? "L\"" : "\""; return AssertionFailure() << "Value of: " << needle_expr << "\n" << " Actual: " << begin_string_quote << needle << "\"\n" << "Expected: " << (expected_to_be_substring ? "" : "not ") << "a substring of " << haystack_expr << "\n" << "Which is: " << begin_string_quote << haystack << "\""; } } // namespace // IsSubstring() and IsNotSubstring() check whether needle is a // substring of haystack (NULL is considered a substring of itself // only), and return an appropriate error message when they fail. 
AssertionResult IsSubstring( const char* needle_expr, const char* haystack_expr, const char* needle, const char* haystack) { return IsSubstringImpl(true, needle_expr, haystack_expr, needle, haystack); } AssertionResult IsSubstring( const char* needle_expr, const char* haystack_expr, const wchar_t* needle, const wchar_t* haystack) { return IsSubstringImpl(true, needle_expr, haystack_expr, needle, haystack); } AssertionResult IsNotSubstring( const char* needle_expr, const char* haystack_expr, const char* needle, const char* haystack) { return IsSubstringImpl(false, needle_expr, haystack_expr, needle, haystack); } AssertionResult IsNotSubstring( const char* needle_expr, const char* haystack_expr, const wchar_t* needle, const wchar_t* haystack) { return IsSubstringImpl(false, needle_expr, haystack_expr, needle, haystack); } AssertionResult IsSubstring( const char* needle_expr, const char* haystack_expr, const ::std::string& needle, const ::std::string& haystack) { return IsSubstringImpl(true, needle_expr, haystack_expr, needle, haystack); } AssertionResult IsNotSubstring( const char* needle_expr, const char* haystack_expr, const ::std::string& needle, const ::std::string& haystack) { return IsSubstringImpl(false, needle_expr, haystack_expr, needle, haystack); } #if GTEST_HAS_STD_WSTRING AssertionResult IsSubstring( const char* needle_expr, const char* haystack_expr, const ::std::wstring& needle, const ::std::wstring& haystack) { return IsSubstringImpl(true, needle_expr, haystack_expr, needle, haystack); } AssertionResult IsNotSubstring( const char* needle_expr, const char* haystack_expr, const ::std::wstring& needle, const ::std::wstring& haystack) { return IsSubstringImpl(false, needle_expr, haystack_expr, needle, haystack); } #endif // GTEST_HAS_STD_WSTRING namespace internal { #if GTEST_OS_WINDOWS namespace { // Helper function for IsHRESULT{SuccessFailure} predicates AssertionResult HRESULTFailureHelper(const char* expr, const char* expected, long hr) { // NOLINT # 
if GTEST_OS_WINDOWS_MOBILE || GTEST_OS_WINDOWS_TV_TITLE // Windows CE doesn't support FormatMessage. const char error_text[] = ""; # else // Looks up the human-readable system message for the HRESULT code // and since we're not passing any params to FormatMessage, we don't // want inserts expanded. const DWORD kFlags = FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS; const DWORD kBufSize = 4096; // Gets the system's human readable message string for this HRESULT. char error_text[kBufSize] = { '\0' }; DWORD message_length = ::FormatMessageA(kFlags, 0, // no source, we're asking system static_cast(hr), // the error 0, // no line width restrictions error_text, // output buffer kBufSize, // buf size nullptr); // no arguments for inserts // Trims tailing white space (FormatMessage leaves a trailing CR-LF) for (; message_length && IsSpace(error_text[message_length - 1]); --message_length) { error_text[message_length - 1] = '\0'; } # endif // GTEST_OS_WINDOWS_MOBILE const std::string error_hex("0x" + String::FormatHexInt(hr)); return ::testing::AssertionFailure() << "Expected: " << expr << " " << expected << ".\n" << " Actual: " << error_hex << " " << error_text << "\n"; } } // namespace AssertionResult IsHRESULTSuccess(const char* expr, long hr) { // NOLINT if (SUCCEEDED(hr)) { return AssertionSuccess(); } return HRESULTFailureHelper(expr, "succeeds", hr); } AssertionResult IsHRESULTFailure(const char* expr, long hr) { // NOLINT if (FAILED(hr)) { return AssertionSuccess(); } return HRESULTFailureHelper(expr, "fails", hr); } #endif // GTEST_OS_WINDOWS // Utility functions for encoding Unicode text (wide strings) in // UTF-8. // A Unicode code-point can have up to 21 bits, and is encoded in UTF-8 // like this: // // Code-point length Encoding // 0 - 7 bits 0xxxxxxx // 8 - 11 bits 110xxxxx 10xxxxxx // 12 - 16 bits 1110xxxx 10xxxxxx 10xxxxxx // 17 - 21 bits 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx // The maximum code-point a one-byte UTF-8 sequence can represent. 
const UInt32 kMaxCodePoint1 = (static_cast(1) << 7) - 1; // The maximum code-point a two-byte UTF-8 sequence can represent. const UInt32 kMaxCodePoint2 = (static_cast(1) << (5 + 6)) - 1; // The maximum code-point a three-byte UTF-8 sequence can represent. const UInt32 kMaxCodePoint3 = (static_cast(1) << (4 + 2*6)) - 1; // The maximum code-point a four-byte UTF-8 sequence can represent. const UInt32 kMaxCodePoint4 = (static_cast(1) << (3 + 3*6)) - 1; // Chops off the n lowest bits from a bit pattern. Returns the n // lowest bits. As a side effect, the original bit pattern will be // shifted to the right by n bits. inline UInt32 ChopLowBits(UInt32* bits, int n) { const UInt32 low_bits = *bits & ((static_cast(1) << n) - 1); *bits >>= n; return low_bits; } // Converts a Unicode code point to a narrow string in UTF-8 encoding. // code_point parameter is of type UInt32 because wchar_t may not be // wide enough to contain a code point. // If the code_point is not a valid Unicode code point // (i.e. outside of Unicode range U+0 to U+10FFFF) it will be converted // to "(Invalid Unicode 0xXXXXXXXX)". std::string CodePointToUtf8(UInt32 code_point) { if (code_point > kMaxCodePoint4) { return "(Invalid Unicode 0x" + String::FormatHexUInt32(code_point) + ")"; } char str[5]; // Big enough for the largest valid code point. 
if (code_point <= kMaxCodePoint1) { str[1] = '\0'; str[0] = static_cast(code_point); // 0xxxxxxx } else if (code_point <= kMaxCodePoint2) { str[2] = '\0'; str[1] = static_cast(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx str[0] = static_cast(0xC0 | code_point); // 110xxxxx } else if (code_point <= kMaxCodePoint3) { str[3] = '\0'; str[2] = static_cast(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx str[1] = static_cast(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx str[0] = static_cast(0xE0 | code_point); // 1110xxxx } else { // code_point <= kMaxCodePoint4 str[4] = '\0'; str[3] = static_cast(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx str[2] = static_cast(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx str[1] = static_cast(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx str[0] = static_cast(0xF0 | code_point); // 11110xxx } return str; } // The following two functions only make sense if the system // uses UTF-16 for wide string encoding. All supported systems // with 16 bit wchar_t (Windows, Cygwin) do use UTF-16. // Determines if the arguments constitute UTF-16 surrogate pair // and thus should be combined into a single Unicode code point // using CreateCodePointFromUtf16SurrogatePair. inline bool IsUtf16SurrogatePair(wchar_t first, wchar_t second) { return sizeof(wchar_t) == 2 && (first & 0xFC00) == 0xD800 && (second & 0xFC00) == 0xDC00; } // Creates a Unicode code point from UTF16 surrogate pair. inline UInt32 CreateCodePointFromUtf16SurrogatePair(wchar_t first, wchar_t second) { const auto first_u = static_cast(first); const auto second_u = static_cast(second); const UInt32 mask = (1 << 10) - 1; return (sizeof(wchar_t) == 2) ? (((first_u & mask) << 10) | (second_u & mask)) + 0x10000 : // This function should not be called when the condition is // false, but we provide a sensible default in case it is. first_u; } // Converts a wide string to a narrow string in UTF-8 encoding. 
// The wide string is assumed to have the following encoding: // UTF-16 if sizeof(wchar_t) == 2 (on Windows, Cygwin) // UTF-32 if sizeof(wchar_t) == 4 (on Linux) // Parameter str points to a null-terminated wide string. // Parameter num_chars may additionally limit the number // of wchar_t characters processed. -1 is used when the entire string // should be processed. // If the string contains code points that are not valid Unicode code points // (i.e. outside of Unicode range U+0 to U+10FFFF) they will be output // as '(Invalid Unicode 0xXXXXXXXX)'. If the string is in UTF16 encoding // and contains invalid UTF-16 surrogate pairs, values in those pairs // will be encoded as individual Unicode characters from Basic Normal Plane. std::string WideStringToUtf8(const wchar_t* str, int num_chars) { if (num_chars == -1) num_chars = static_cast(wcslen(str)); ::std::stringstream stream; for (int i = 0; i < num_chars; ++i) { UInt32 unicode_code_point; if (str[i] == L'\0') { break; } else if (i + 1 < num_chars && IsUtf16SurrogatePair(str[i], str[i + 1])) { unicode_code_point = CreateCodePointFromUtf16SurrogatePair(str[i], str[i + 1]); i++; } else { unicode_code_point = static_cast(str[i]); } stream << CodePointToUtf8(unicode_code_point); } return StringStreamToString(&stream); } // Converts a wide C string to an std::string using the UTF-8 encoding. // NULL will be converted to "(null)". std::string String::ShowWideCString(const wchar_t * wide_c_str) { if (wide_c_str == nullptr) return "(null)"; return internal::WideStringToUtf8(wide_c_str, -1); } // Compares two wide C strings. Returns true if and only if they have the // same content. // // Unlike wcscmp(), this function can handle NULL argument(s). A NULL // C string is considered different to any non-NULL C string, // including the empty string. 
bool String::WideCStringEquals(const wchar_t * lhs, const wchar_t * rhs) { if (lhs == nullptr) return rhs == nullptr; if (rhs == nullptr) return false; return wcscmp(lhs, rhs) == 0; } // Helper function for *_STREQ on wide strings. AssertionResult CmpHelperSTREQ(const char* lhs_expression, const char* rhs_expression, const wchar_t* lhs, const wchar_t* rhs) { if (String::WideCStringEquals(lhs, rhs)) { return AssertionSuccess(); } return EqFailure(lhs_expression, rhs_expression, PrintToString(lhs), PrintToString(rhs), false); } // Helper function for *_STRNE on wide strings. AssertionResult CmpHelperSTRNE(const char* s1_expression, const char* s2_expression, const wchar_t* s1, const wchar_t* s2) { if (!String::WideCStringEquals(s1, s2)) { return AssertionSuccess(); } return AssertionFailure() << "Expected: (" << s1_expression << ") != (" << s2_expression << "), actual: " << PrintToString(s1) << " vs " << PrintToString(s2); } // Compares two C strings, ignoring case. Returns true if and only if they have // the same content. // // Unlike strcasecmp(), this function can handle NULL argument(s). A // NULL C string is considered different to any non-NULL C string, // including the empty string. bool String::CaseInsensitiveCStringEquals(const char * lhs, const char * rhs) { if (lhs == nullptr) return rhs == nullptr; if (rhs == nullptr) return false; return posix::StrCaseCmp(lhs, rhs) == 0; } // Compares two wide C strings, ignoring case. Returns true if and only if they // have the same content. // // Unlike wcscasecmp(), this function can handle NULL argument(s). // A NULL C string is considered different to any non-NULL wide C string, // including the empty string. // NB: The implementations on different platforms slightly differ. // On windows, this method uses _wcsicmp which compares according to LC_CTYPE // environment variable. On GNU platform this method uses wcscasecmp // which compares according to LC_CTYPE category of the current locale. 
// On MacOS X, it uses towlower, which also uses LC_CTYPE category of the // current locale. bool String::CaseInsensitiveWideCStringEquals(const wchar_t* lhs, const wchar_t* rhs) { if (lhs == nullptr) return rhs == nullptr; if (rhs == nullptr) return false; #if GTEST_OS_WINDOWS return _wcsicmp(lhs, rhs) == 0; #elif GTEST_OS_LINUX && !GTEST_OS_LINUX_ANDROID return wcscasecmp(lhs, rhs) == 0; #else // Android, Mac OS X and Cygwin don't define wcscasecmp. // Other unknown OSes may not define it either. wint_t left, right; do { left = towlower(static_cast(*lhs++)); right = towlower(static_cast(*rhs++)); } while (left && left == right); return left == right; #endif // OS selector } // Returns true if and only if str ends with the given suffix, ignoring case. // Any string is considered to end with an empty suffix. bool String::EndsWithCaseInsensitive( const std::string& str, const std::string& suffix) { const size_t str_len = str.length(); const size_t suffix_len = suffix.length(); return (str_len >= suffix_len) && CaseInsensitiveCStringEquals(str.c_str() + str_len - suffix_len, suffix.c_str()); } // Formats an int value as "%02d". std::string String::FormatIntWidth2(int value) { std::stringstream ss; ss << std::setfill('0') << std::setw(2) << value; return ss.str(); } // Formats an int value as "%X". std::string String::FormatHexUInt32(UInt32 value) { std::stringstream ss; ss << std::hex << std::uppercase << value; return ss.str(); } // Formats an int value as "%X". std::string String::FormatHexInt(int value) { return FormatHexUInt32(static_cast(value)); } // Formats a byte as "%02X". std::string String::FormatByte(unsigned char value) { std::stringstream ss; ss << std::setfill('0') << std::setw(2) << std::hex << std::uppercase << static_cast(value); return ss.str(); } // Converts the buffer in a stringstream to an std::string, converting NUL // bytes to "\\0" along the way. 
std::string StringStreamToString(::std::stringstream* ss) { const ::std::string& str = ss->str(); const char* const start = str.c_str(); const char* const end = start + str.length(); std::string result; result.reserve(static_cast(2 * (end - start))); for (const char* ch = start; ch != end; ++ch) { if (*ch == '\0') { result += "\\0"; // Replaces NUL with "\\0"; } else { result += *ch; } } return result; } // Appends the user-supplied message to the Google-Test-generated message. std::string AppendUserMessage(const std::string& gtest_msg, const Message& user_msg) { // Appends the user message if it's non-empty. const std::string user_msg_string = user_msg.GetString(); if (user_msg_string.empty()) { return gtest_msg; } return gtest_msg + "\n" + user_msg_string; } } // namespace internal // class TestResult // Creates an empty TestResult. TestResult::TestResult() : death_test_count_(0), start_timestamp_(0), elapsed_time_(0) {} // D'tor. TestResult::~TestResult() { } // Returns the i-th test part result among all the results. i can // range from 0 to total_part_count() - 1. If i is not in that range, // aborts the program. const TestPartResult& TestResult::GetTestPartResult(int i) const { if (i < 0 || i >= total_part_count()) internal::posix::Abort(); return test_part_results_.at(static_cast(i)); } // Returns the i-th test property. i can range from 0 to // test_property_count() - 1. If i is not in that range, aborts the // program. const TestProperty& TestResult::GetTestProperty(int i) const { if (i < 0 || i >= test_property_count()) internal::posix::Abort(); return test_properties_.at(static_cast(i)); } // Clears the test part results. void TestResult::ClearTestPartResults() { test_part_results_.clear(); } // Adds a test part result to the list. void TestResult::AddTestPartResult(const TestPartResult& test_part_result) { test_part_results_.push_back(test_part_result); } // Adds a test property to the list. 
If a property with the same key as the // supplied property is already represented, the value of this test_property // replaces the old value for that key. void TestResult::RecordProperty(const std::string& xml_element, const TestProperty& test_property) { if (!ValidateTestProperty(xml_element, test_property)) { return; } internal::MutexLock lock(&test_properites_mutex_); const std::vector::iterator property_with_matching_key = std::find_if(test_properties_.begin(), test_properties_.end(), internal::TestPropertyKeyIs(test_property.key())); if (property_with_matching_key == test_properties_.end()) { test_properties_.push_back(test_property); return; } property_with_matching_key->SetValue(test_property.value()); } // The list of reserved attributes used in the element of XML // output. static const char* const kReservedTestSuitesAttributes[] = { "disabled", "errors", "failures", "name", "random_seed", "tests", "time", "timestamp" }; // The list of reserved attributes used in the element of XML // output. static const char* const kReservedTestSuiteAttributes[] = { "disabled", "errors", "failures", "name", "tests", "time", "timestamp"}; // The list of reserved attributes used in the element of XML output. 
static const char* const kReservedTestCaseAttributes[] = { "classname", "name", "status", "time", "type_param", "value_param", "file", "line"}; // Use a slightly different set for allowed output to ensure existing tests can // still RecordProperty("result") or "RecordProperty(timestamp") static const char* const kReservedOutputTestCaseAttributes[] = { "classname", "name", "status", "time", "type_param", "value_param", "file", "line", "result", "timestamp"}; template std::vector ArrayAsVector(const char* const (&array)[kSize]) { return std::vector(array, array + kSize); } static std::vector GetReservedAttributesForElement( const std::string& xml_element) { if (xml_element == "testsuites") { return ArrayAsVector(kReservedTestSuitesAttributes); } else if (xml_element == "testsuite") { return ArrayAsVector(kReservedTestSuiteAttributes); } else if (xml_element == "testcase") { return ArrayAsVector(kReservedTestCaseAttributes); } else { GTEST_CHECK_(false) << "Unrecognized xml_element provided: " << xml_element; } // This code is unreachable but some compilers may not realizes that. return std::vector(); } // TODO(jdesprez): Merge the two getReserved attributes once skip is improved static std::vector GetReservedOutputAttributesForElement( const std::string& xml_element) { if (xml_element == "testsuites") { return ArrayAsVector(kReservedTestSuitesAttributes); } else if (xml_element == "testsuite") { return ArrayAsVector(kReservedTestSuiteAttributes); } else if (xml_element == "testcase") { return ArrayAsVector(kReservedOutputTestCaseAttributes); } else { GTEST_CHECK_(false) << "Unrecognized xml_element provided: " << xml_element; } // This code is unreachable but some compilers may not realizes that. 
return std::vector(); } static std::string FormatWordList(const std::vector& words) { Message word_list; for (size_t i = 0; i < words.size(); ++i) { if (i > 0 && words.size() > 2) { word_list << ", "; } if (i == words.size() - 1) { word_list << "and "; } word_list << "'" << words[i] << "'"; } return word_list.GetString(); } static bool ValidateTestPropertyName( const std::string& property_name, const std::vector& reserved_names) { if (std::find(reserved_names.begin(), reserved_names.end(), property_name) != reserved_names.end()) { ADD_FAILURE() << "Reserved key used in RecordProperty(): " << property_name << " (" << FormatWordList(reserved_names) << " are reserved by " << GTEST_NAME_ << ")"; return false; } return true; } // Adds a failure if the key is a reserved attribute of the element named // xml_element. Returns true if the property is valid. bool TestResult::ValidateTestProperty(const std::string& xml_element, const TestProperty& test_property) { return ValidateTestPropertyName(test_property.key(), GetReservedAttributesForElement(xml_element)); } // Clears the object. void TestResult::Clear() { test_part_results_.clear(); test_properties_.clear(); death_test_count_ = 0; elapsed_time_ = 0; } // Returns true off the test part was skipped. static bool TestPartSkipped(const TestPartResult& result) { return result.skipped(); } // Returns true if and only if the test was skipped. bool TestResult::Skipped() const { return !Failed() && CountIf(test_part_results_, TestPartSkipped) > 0; } // Returns true if and only if the test failed. bool TestResult::Failed() const { for (int i = 0; i < total_part_count(); ++i) { if (GetTestPartResult(i).failed()) return true; } return false; } // Returns true if and only if the test part fatally failed. static bool TestPartFatallyFailed(const TestPartResult& result) { return result.fatally_failed(); } // Returns true if and only if the test fatally failed. 
bool TestResult::HasFatalFailure() const { return CountIf(test_part_results_, TestPartFatallyFailed) > 0; } // Returns true if and only if the test part non-fatally failed. static bool TestPartNonfatallyFailed(const TestPartResult& result) { return result.nonfatally_failed(); } // Returns true if and only if the test has a non-fatal failure. bool TestResult::HasNonfatalFailure() const { return CountIf(test_part_results_, TestPartNonfatallyFailed) > 0; } // Gets the number of all test parts. This is the sum of the number // of successful test parts and the number of failed test parts. int TestResult::total_part_count() const { return static_cast(test_part_results_.size()); } // Returns the number of the test properties. int TestResult::test_property_count() const { return static_cast(test_properties_.size()); } // class Test // Creates a Test object. // The c'tor saves the states of all flags. Test::Test() : gtest_flag_saver_(new GTEST_FLAG_SAVER_) { } // The d'tor restores the states of all flags. The actual work is // done by the d'tor of the gtest_flag_saver_ field, and thus not // visible here. Test::~Test() { } // Sets up the test fixture. // // A sub-class may override this. void Test::SetUp() { } // Tears down the test fixture. // // A sub-class may override this. void Test::TearDown() { } // Allows user supplied key value pairs to be recorded for later output. void Test::RecordProperty(const std::string& key, const std::string& value) { UnitTest::GetInstance()->RecordProperty(key, value); } // Allows user supplied key value pairs to be recorded for later output. void Test::RecordProperty(const std::string& key, int value) { Message value_message; value_message << value; RecordProperty(key, value_message.GetString().c_str()); } namespace internal { void ReportFailureInUnknownLocation(TestPartResult::Type result_type, const std::string& message) { // This function is a friend of UnitTest and as such has access to // AddTestPartResult. 
UnitTest::GetInstance()->AddTestPartResult( result_type, nullptr, // No info about the source file where the exception occurred. -1, // We have no info on which line caused the exception. message, ""); // No stack trace, either. } } // namespace internal // Google Test requires all tests in the same test suite to use the same test // fixture class. This function checks if the current test has the // same fixture class as the first test in the current test suite. If // yes, it returns true; otherwise it generates a Google Test failure and // returns false. bool Test::HasSameFixtureClass() { internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); const TestSuite* const test_suite = impl->current_test_suite(); // Info about the first test in the current test suite. const TestInfo* const first_test_info = test_suite->test_info_list()[0]; const internal::TypeId first_fixture_id = first_test_info->fixture_class_id_; const char* const first_test_name = first_test_info->name(); // Info about the current test. const TestInfo* const this_test_info = impl->current_test_info(); const internal::TypeId this_fixture_id = this_test_info->fixture_class_id_; const char* const this_test_name = this_test_info->name(); if (this_fixture_id != first_fixture_id) { // Is the first test defined using TEST? const bool first_is_TEST = first_fixture_id == internal::GetTestTypeId(); // Is this test defined using TEST? const bool this_is_TEST = this_fixture_id == internal::GetTestTypeId(); if (first_is_TEST || this_is_TEST) { // Both TEST and TEST_F appear in same test suite, which is incorrect. // Tell the user how to fix this. // Gets the name of the TEST and the name of the TEST_F. Note // that first_is_TEST and this_is_TEST cannot both be true, as // the fixture IDs are different for the two tests. const char* const TEST_name = first_is_TEST ? first_test_name : this_test_name; const char* const TEST_F_name = first_is_TEST ? 
this_test_name : first_test_name; ADD_FAILURE() << "All tests in the same test suite must use the same test fixture\n" << "class, so mixing TEST_F and TEST in the same test suite is\n" << "illegal. In test suite " << this_test_info->test_suite_name() << ",\n" << "test " << TEST_F_name << " is defined using TEST_F but\n" << "test " << TEST_name << " is defined using TEST. You probably\n" << "want to change the TEST to TEST_F or move it to another test\n" << "case."; } else { // Two fixture classes with the same name appear in two different // namespaces, which is not allowed. Tell the user how to fix this. ADD_FAILURE() << "All tests in the same test suite must use the same test fixture\n" << "class. However, in test suite " << this_test_info->test_suite_name() << ",\n" << "you defined test " << first_test_name << " and test " << this_test_name << "\n" << "using two different test fixture classes. This can happen if\n" << "the two classes are from different namespaces or translation\n" << "units and have the same name. You should probably rename one\n" << "of the classes to put the tests into different test suites."; } return false; } return true; } #if GTEST_HAS_SEH // Adds an "exception thrown" fatal failure to the current test. This // function returns its result via an output parameter pointer because VC++ // prohibits creation of objects with destructors on stack in functions // using __try (see error C2712). static std::string* FormatSehExceptionMessage(DWORD exception_code, const char* location) { Message message; message << "SEH exception with code 0x" << std::setbase(16) << exception_code << std::setbase(10) << " thrown in " << location << "."; return new std::string(message.GetString()); } #endif // GTEST_HAS_SEH namespace internal { #if GTEST_HAS_EXCEPTIONS // Adds an "exception thrown" fatal failure to the current test. 
static std::string FormatCxxExceptionMessage(const char* description, const char* location) { Message message; if (description != nullptr) { message << "C++ exception with description \"" << description << "\""; } else { message << "Unknown C++ exception"; } message << " thrown in " << location << "."; return message.GetString(); } static std::string PrintTestPartResultToString( const TestPartResult& test_part_result); GoogleTestFailureException::GoogleTestFailureException( const TestPartResult& failure) : ::std::runtime_error(PrintTestPartResultToString(failure).c_str()) {} #endif // GTEST_HAS_EXCEPTIONS // We put these helper functions in the internal namespace as IBM's xlC // compiler rejects the code if they were declared static. // Runs the given method and handles SEH exceptions it throws, when // SEH is supported; returns the 0-value for type Result in case of an // SEH exception. (Microsoft compilers cannot handle SEH and C++ // exceptions in the same function. Therefore, we provide a separate // wrapper function for handling SEH exceptions.) template Result HandleSehExceptionsInMethodIfSupported( T* object, Result (T::*method)(), const char* location) { #if GTEST_HAS_SEH __try { return (object->*method)(); } __except (internal::UnitTestOptions::GTestShouldProcessSEH( // NOLINT GetExceptionCode())) { // We create the exception message on the heap because VC++ prohibits // creation of objects with destructors on stack in functions using __try // (see error C2712). std::string* exception_message = FormatSehExceptionMessage( GetExceptionCode(), location); internal::ReportFailureInUnknownLocation(TestPartResult::kFatalFailure, *exception_message); delete exception_message; return static_cast(0); } #else (void)location; return (object->*method)(); #endif // GTEST_HAS_SEH } // Runs the given method and catches and reports C++ and/or SEH-style // exceptions, if they are supported; returns the 0-value for type // Result in case of an SEH exception. 
template Result HandleExceptionsInMethodIfSupported( T* object, Result (T::*method)(), const char* location) { // NOTE: The user code can affect the way in which Google Test handles // exceptions by setting GTEST_FLAG(catch_exceptions), but only before // RUN_ALL_TESTS() starts. It is technically possible to check the flag // after the exception is caught and either report or re-throw the // exception based on the flag's value: // // try { // // Perform the test method. // } catch (...) { // if (GTEST_FLAG(catch_exceptions)) // // Report the exception as failure. // else // throw; // Re-throws the original exception. // } // // However, the purpose of this flag is to allow the program to drop into // the debugger when the exception is thrown. On most platforms, once the // control enters the catch block, the exception origin information is // lost and the debugger will stop the program at the point of the // re-throw in this function -- instead of at the point of the original // throw statement in the code under test. For this reason, we perform // the check early, sacrificing the ability to affect Google Test's // exception handling in the method where the exception is thrown. if (internal::GetUnitTestImpl()->catch_exceptions()) { #if GTEST_HAS_EXCEPTIONS try { return HandleSehExceptionsInMethodIfSupported(object, method, location); } catch (const AssertionException&) { // NOLINT // This failure was reported already. } catch (const internal::GoogleTestFailureException&) { // NOLINT // This exception type can only be thrown by a failed Google // Test assertion with the intention of letting another testing // framework catch it. Therefore we just re-throw it. throw; } catch (const std::exception& e) { // NOLINT internal::ReportFailureInUnknownLocation( TestPartResult::kFatalFailure, FormatCxxExceptionMessage(e.what(), location)); } catch (...) 
{ // NOLINT internal::ReportFailureInUnknownLocation( TestPartResult::kFatalFailure, FormatCxxExceptionMessage(nullptr, location)); } return static_cast(0); #else return HandleSehExceptionsInMethodIfSupported(object, method, location); #endif // GTEST_HAS_EXCEPTIONS } else { return (object->*method)(); } } } // namespace internal // Runs the test and updates the test result. void Test::Run() { if (!HasSameFixtureClass()) return; internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); impl->os_stack_trace_getter()->UponLeavingGTest(); internal::HandleExceptionsInMethodIfSupported(this, &Test::SetUp, "SetUp()"); // We will run the test only if SetUp() was successful and didn't call // GTEST_SKIP(). if (!HasFatalFailure() && !IsSkipped()) { impl->os_stack_trace_getter()->UponLeavingGTest(); internal::HandleExceptionsInMethodIfSupported( this, &Test::TestBody, "the test body"); } // However, we want to clean up as much as possible. Hence we will // always call TearDown(), even if SetUp() or the test body has // failed. impl->os_stack_trace_getter()->UponLeavingGTest(); internal::HandleExceptionsInMethodIfSupported( this, &Test::TearDown, "TearDown()"); } // Returns true if and only if the current test has a fatal failure. bool Test::HasFatalFailure() { return internal::GetUnitTestImpl()->current_test_result()->HasFatalFailure(); } // Returns true if and only if the current test has a non-fatal failure. bool Test::HasNonfatalFailure() { return internal::GetUnitTestImpl()->current_test_result()-> HasNonfatalFailure(); } // Returns true if and only if the current test was skipped. bool Test::IsSkipped() { return internal::GetUnitTestImpl()->current_test_result()->Skipped(); } // class TestInfo // Constructs a TestInfo object. It assumes ownership of the test factory // object. 
// a_type_param and a_value_param may be NULL; in that case the
// corresponding member is initialized with nullptr instead of a string copy.
TestInfo::TestInfo(const std::string& a_test_suite_name,
                   const std::string& a_name, const char* a_type_param,
                   const char* a_value_param,
                   internal::CodeLocation a_code_location,
                   internal::TypeId fixture_class_id,
                   internal::TestFactoryBase* factory)
    : test_suite_name_(a_test_suite_name),
      name_(a_name),
      type_param_(a_type_param ? new std::string(a_type_param) : nullptr),
      value_param_(a_value_param ? new std::string(a_value_param) : nullptr),
      location_(a_code_location),
      fixture_class_id_(fixture_class_id),
      should_run_(false),
      is_disabled_(false),
      matches_filter_(false),
      factory_(factory),
      result_() {}

// Destructs a TestInfo object. Deletes the factory passed to the
// constructor.
TestInfo::~TestInfo() { delete factory_; }

namespace internal {

// Creates a new TestInfo object and registers it with Google Test;
// returns the created object.
//
// Arguments:
//
//   test_suite_name:  name of the test suite
//   name:             name of the test
//   type_param:       the name of the test's type parameter, or NULL if
//                     this is not a typed or a type-parameterized test.
//   value_param:      text representation of the test's value parameter,
//                     or NULL if this is not a value-parameterized test.
//   code_location:    code location where the test is defined
//   fixture_class_id: ID of the test fixture class
//   set_up_tc:        pointer to the function that sets up the test suite
//   tear_down_tc:     pointer to the function that tears down the test suite
//   factory:          pointer to the factory that creates a test object.
//                     The newly created TestInfo instance will assume
//                     ownership of the factory object.
TestInfo* MakeAndRegisterTestInfo( const char* test_suite_name, const char* name, const char* type_param, const char* value_param, CodeLocation code_location, TypeId fixture_class_id, SetUpTestSuiteFunc set_up_tc, TearDownTestSuiteFunc tear_down_tc, TestFactoryBase* factory) { TestInfo* const test_info = new TestInfo(test_suite_name, name, type_param, value_param, code_location, fixture_class_id, factory); GetUnitTestImpl()->AddTestInfo(set_up_tc, tear_down_tc, test_info); return test_info; } void ReportInvalidTestSuiteType(const char* test_suite_name, CodeLocation code_location) { Message errors; errors << "Attempted redefinition of test suite " << test_suite_name << ".\n" << "All tests in the same test suite must use the same test fixture\n" << "class. However, in test suite " << test_suite_name << ", you tried\n" << "to define a test using a fixture class different from the one\n" << "used earlier. This can happen if the two fixture classes are\n" << "from different namespaces and have the same name. You should\n" << "probably rename one of the classes to put the tests into different\n" << "test suites."; GTEST_LOG_(ERROR) << FormatFileLocation(code_location.file.c_str(), code_location.line) << " " << errors.GetString(); } } // namespace internal namespace { // A predicate that checks the test name of a TestInfo against a known // value. // // This is used for implementation of the TestSuite class only. We put // it in the anonymous namespace to prevent polluting the outer // namespace. // // TestNameIs is copyable. class TestNameIs { public: // Constructor. // // TestNameIs has NO default constructor. explicit TestNameIs(const char* name) : name_(name) {} // Returns true if and only if the test name of test_info matches name_. 
bool operator()(const TestInfo * test_info) const { return test_info && test_info->name() == name_; } private: std::string name_; }; } // namespace namespace internal { // This method expands all parameterized tests registered with macros TEST_P // and INSTANTIATE_TEST_SUITE_P into regular tests and registers those. // This will be done just once during the program runtime. void UnitTestImpl::RegisterParameterizedTests() { if (!parameterized_tests_registered_) { parameterized_test_registry_.RegisterTests(); parameterized_tests_registered_ = true; } } } // namespace internal // Creates the test object, runs it, records its result, and then // deletes it. void TestInfo::Run() { if (!should_run_) return; // Tells UnitTest where to store test result. internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); impl->set_current_test_info(this); TestEventListener* repeater = UnitTest::GetInstance()->listeners().repeater(); // Notifies the unit test event listeners that a test is about to start. repeater->OnTestStart(*this); const TimeInMillis start = internal::GetTimeInMillis(); impl->os_stack_trace_getter()->UponLeavingGTest(); // Creates the test object. Test* const test = internal::HandleExceptionsInMethodIfSupported( factory_, &internal::TestFactoryBase::CreateTest, "the test fixture's constructor"); // Runs the test if the constructor didn't generate a fatal failure or invoke // GTEST_SKIP(). // Note that the object will not be null if (!Test::HasFatalFailure() && !Test::IsSkipped()) { // This doesn't throw as all user code that can throw are wrapped into // exception handling code. test->Run(); } if (test != nullptr) { // Deletes the test object. 
impl->os_stack_trace_getter()->UponLeavingGTest(); internal::HandleExceptionsInMethodIfSupported( test, &Test::DeleteSelf_, "the test fixture's destructor"); } result_.set_start_timestamp(start); result_.set_elapsed_time(internal::GetTimeInMillis() - start); // Notifies the unit test event listener that a test has just finished. repeater->OnTestEnd(*this); // Tells UnitTest to stop associating assertion results to this // test. impl->set_current_test_info(nullptr); } // class TestSuite // Gets the number of successful tests in this test suite. int TestSuite::successful_test_count() const { return CountIf(test_info_list_, TestPassed); } // Gets the number of successful tests in this test suite. int TestSuite::skipped_test_count() const { return CountIf(test_info_list_, TestSkipped); } // Gets the number of failed tests in this test suite. int TestSuite::failed_test_count() const { return CountIf(test_info_list_, TestFailed); } // Gets the number of disabled tests that will be reported in the XML report. int TestSuite::reportable_disabled_test_count() const { return CountIf(test_info_list_, TestReportableDisabled); } // Gets the number of disabled tests in this test suite. int TestSuite::disabled_test_count() const { return CountIf(test_info_list_, TestDisabled); } // Gets the number of tests to be printed in the XML report. int TestSuite::reportable_test_count() const { return CountIf(test_info_list_, TestReportable); } // Get the number of tests in this test suite that should run. int TestSuite::test_to_run_count() const { return CountIf(test_info_list_, ShouldRunTest); } // Gets the number of all tests. int TestSuite::total_test_count() const { return static_cast(test_info_list_.size()); } // Creates a TestSuite with the given name. // // Arguments: // // name: name of the test suite // a_type_param: the name of the test suite's type parameter, or NULL if // this is not a typed or a type-parameterized test suite. 
// set_up_tc: pointer to the function that sets up the test suite // tear_down_tc: pointer to the function that tears down the test suite TestSuite::TestSuite(const char* a_name, const char* a_type_param, internal::SetUpTestSuiteFunc set_up_tc, internal::TearDownTestSuiteFunc tear_down_tc) : name_(a_name), type_param_(a_type_param ? new std::string(a_type_param) : nullptr), set_up_tc_(set_up_tc), tear_down_tc_(tear_down_tc), should_run_(false), start_timestamp_(0), elapsed_time_(0) {} // Destructor of TestSuite. TestSuite::~TestSuite() { // Deletes every Test in the collection. ForEach(test_info_list_, internal::Delete); } // Returns the i-th test among all the tests. i can range from 0 to // total_test_count() - 1. If i is not in that range, returns NULL. const TestInfo* TestSuite::GetTestInfo(int i) const { const int index = GetElementOr(test_indices_, i, -1); return index < 0 ? nullptr : test_info_list_[static_cast(index)]; } // Returns the i-th test among all the tests. i can range from 0 to // total_test_count() - 1. If i is not in that range, returns NULL. TestInfo* TestSuite::GetMutableTestInfo(int i) { const int index = GetElementOr(test_indices_, i, -1); return index < 0 ? nullptr : test_info_list_[static_cast(index)]; } // Adds a test to this test suite. Will delete the test upon // destruction of the TestSuite object. void TestSuite::AddTestInfo(TestInfo* test_info) { test_info_list_.push_back(test_info); test_indices_.push_back(static_cast(test_indices_.size())); } // Runs every test in this TestSuite. 
void TestSuite::Run() { if (!should_run_) return; internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); impl->set_current_test_suite(this); TestEventListener* repeater = UnitTest::GetInstance()->listeners().repeater(); // Call both legacy and the new API repeater->OnTestSuiteStart(*this); // Legacy API is deprecated but still available #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI repeater->OnTestCaseStart(*this); #endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI impl->os_stack_trace_getter()->UponLeavingGTest(); internal::HandleExceptionsInMethodIfSupported( this, &TestSuite::RunSetUpTestSuite, "SetUpTestSuite()"); start_timestamp_ = internal::GetTimeInMillis(); for (int i = 0; i < total_test_count(); i++) { GetMutableTestInfo(i)->Run(); } elapsed_time_ = internal::GetTimeInMillis() - start_timestamp_; impl->os_stack_trace_getter()->UponLeavingGTest(); internal::HandleExceptionsInMethodIfSupported( this, &TestSuite::RunTearDownTestSuite, "TearDownTestSuite()"); // Call both legacy and the new API repeater->OnTestSuiteEnd(*this); // Legacy API is deprecated but still available #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI repeater->OnTestCaseEnd(*this); #endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI impl->set_current_test_suite(nullptr); } // Clears the results of all tests in this test suite. void TestSuite::ClearResult() { ad_hoc_test_result_.Clear(); ForEach(test_info_list_, TestInfo::ClearTestResult); } // Shuffles the tests in this test suite. void TestSuite::ShuffleTests(internal::Random* random) { Shuffle(random, &test_indices_); } // Restores the test order to before the first shuffle. void TestSuite::UnshuffleTests() { for (size_t i = 0; i < test_indices_.size(); i++) { test_indices_[i] = static_cast(i); } } // Formats a countable noun. Depending on its quantity, either the // singular form or the plural form is used. e.g. // // FormatCountableNoun(1, "formula", "formuli") returns "1 formula". // FormatCountableNoun(5, "book", "books") returns "5 books". 
static std::string FormatCountableNoun(int count, const char * singular_form, const char * plural_form) { return internal::StreamableToString(count) + " " + (count == 1 ? singular_form : plural_form); } // Formats the count of tests. static std::string FormatTestCount(int test_count) { return FormatCountableNoun(test_count, "test", "tests"); } // Formats the count of test suites. static std::string FormatTestSuiteCount(int test_suite_count) { return FormatCountableNoun(test_suite_count, "test suite", "test suites"); } // Converts a TestPartResult::Type enum to human-friendly string // representation. Both kNonFatalFailure and kFatalFailure are translated // to "Failure", as the user usually doesn't care about the difference // between the two when viewing the test result. static const char * TestPartResultTypeToString(TestPartResult::Type type) { switch (type) { case TestPartResult::kSkip: return "Skipped"; case TestPartResult::kSuccess: return "Success"; case TestPartResult::kNonFatalFailure: case TestPartResult::kFatalFailure: #ifdef _MSC_VER return "error: "; #else return "Failure\n"; #endif default: return "Unknown result type"; } } namespace internal { // Prints a TestPartResult to an std::string. static std::string PrintTestPartResultToString( const TestPartResult& test_part_result) { return (Message() << internal::FormatFileLocation(test_part_result.file_name(), test_part_result.line_number()) << " " << TestPartResultTypeToString(test_part_result.type()) << test_part_result.message()).GetString(); } // Prints a TestPartResult. 
static void PrintTestPartResult(const TestPartResult& test_part_result) { const std::string& result = PrintTestPartResultToString(test_part_result); printf("%s\n", result.c_str()); fflush(stdout); // If the test program runs in Visual Studio or a debugger, the // following statements add the test part result message to the Output // window such that the user can double-click on it to jump to the // corresponding source code location; otherwise they do nothing. #if GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MOBILE // We don't call OutputDebugString*() on Windows Mobile, as printing // to stdout is done by OutputDebugString() there already - we don't // want the same message printed twice. ::OutputDebugStringA(result.c_str()); ::OutputDebugStringA("\n"); #endif } // class PrettyUnitTestResultPrinter #if GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MOBILE && \ !GTEST_OS_WINDOWS_PHONE && !GTEST_OS_WINDOWS_RT && !GTEST_OS_WINDOWS_MINGW // Returns the character attribute for the given color. static WORD GetColorAttribute(GTestColor color) { switch (color) { case COLOR_RED: return FOREGROUND_RED; case COLOR_GREEN: return FOREGROUND_GREEN; case COLOR_YELLOW: return FOREGROUND_RED | FOREGROUND_GREEN; default: return 0; } } static int GetBitOffset(WORD color_mask) { if (color_mask == 0) return 0; int bitOffset = 0; while ((color_mask & 1) == 0) { color_mask >>= 1; ++bitOffset; } return bitOffset; } static WORD GetNewColor(GTestColor color, WORD old_color_attrs) { // Let's reuse the BG static const WORD background_mask = BACKGROUND_BLUE | BACKGROUND_GREEN | BACKGROUND_RED | BACKGROUND_INTENSITY; static const WORD foreground_mask = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_INTENSITY; const WORD existing_bg = old_color_attrs & background_mask; WORD new_color = GetColorAttribute(color) | existing_bg | FOREGROUND_INTENSITY; static const int bg_bitOffset = GetBitOffset(background_mask); static const int fg_bitOffset = GetBitOffset(foreground_mask); if (((new_color & 
background_mask) >> bg_bitOffset) == ((new_color & foreground_mask) >> fg_bitOffset)) { new_color ^= FOREGROUND_INTENSITY; // invert intensity } return new_color; } #else // Returns the ANSI color code for the given color. COLOR_DEFAULT is // an invalid input. static const char* GetAnsiColorCode(GTestColor color) { switch (color) { case COLOR_RED: return "1"; case COLOR_GREEN: return "2"; case COLOR_YELLOW: return "3"; default: return nullptr; } } #endif // GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MOBILE // Returns true if and only if Google Test should use colors in the output. bool ShouldUseColor(bool stdout_is_tty) { const char* const gtest_color = GTEST_FLAG(color).c_str(); if (String::CaseInsensitiveCStringEquals(gtest_color, "auto")) { #if GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MINGW // On Windows the TERM variable is usually not set, but the // console there does support colors. return stdout_is_tty; #else // On non-Windows platforms, we rely on the TERM variable. const char* const term = posix::GetEnv("TERM"); const bool term_supports_color = String::CStringEquals(term, "xterm") || String::CStringEquals(term, "xterm-color") || String::CStringEquals(term, "xterm-256color") || String::CStringEquals(term, "screen") || String::CStringEquals(term, "screen-256color") || String::CStringEquals(term, "tmux") || String::CStringEquals(term, "tmux-256color") || String::CStringEquals(term, "rxvt-unicode") || String::CStringEquals(term, "rxvt-unicode-256color") || String::CStringEquals(term, "linux") || String::CStringEquals(term, "cygwin"); return stdout_is_tty && term_supports_color; #endif // GTEST_OS_WINDOWS } return String::CaseInsensitiveCStringEquals(gtest_color, "yes") || String::CaseInsensitiveCStringEquals(gtest_color, "true") || String::CaseInsensitiveCStringEquals(gtest_color, "t") || String::CStringEquals(gtest_color, "1"); // We take "yes", "true", "t", and "1" as meaning "yes". 
If the // value is neither one of these nor "auto", we treat it as "no" to // be conservative. } // Helpers for printing colored strings to stdout. Note that on Windows, we // cannot simply emit special characters and have the terminal change colors. // This routine must actually emit the characters rather than return a string // that would be colored when printed, as can be done on Linux. void ColoredPrintf(GTestColor color, const char* fmt, ...) { va_list args; va_start(args, fmt); #if GTEST_OS_WINDOWS_MOBILE || GTEST_OS_ZOS || GTEST_OS_IOS || \ GTEST_OS_WINDOWS_PHONE || GTEST_OS_WINDOWS_RT || defined(ESP_PLATFORM) const bool use_color = AlwaysFalse(); #else static const bool in_color_mode = ShouldUseColor(posix::IsATTY(posix::FileNo(stdout)) != 0); const bool use_color = in_color_mode && (color != COLOR_DEFAULT); #endif // GTEST_OS_WINDOWS_MOBILE || GTEST_OS_ZOS if (!use_color) { vprintf(fmt, args); va_end(args); return; } #if GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MOBILE && \ !GTEST_OS_WINDOWS_PHONE && !GTEST_OS_WINDOWS_RT && !GTEST_OS_WINDOWS_MINGW const HANDLE stdout_handle = GetStdHandle(STD_OUTPUT_HANDLE); // Gets the current text color. CONSOLE_SCREEN_BUFFER_INFO buffer_info; GetConsoleScreenBufferInfo(stdout_handle, &buffer_info); const WORD old_color_attrs = buffer_info.wAttributes; const WORD new_color = GetNewColor(color, old_color_attrs); // We need to flush the stream buffers into the console before each // SetConsoleTextAttribute call lest it affect the text that is already // printed but has not yet reached the console. fflush(stdout); SetConsoleTextAttribute(stdout_handle, new_color); vprintf(fmt, args); fflush(stdout); // Restores the text color. SetConsoleTextAttribute(stdout_handle, old_color_attrs); #else printf("\033[0;3%sm", GetAnsiColorCode(color)); vprintf(fmt, args); printf("\033[m"); // Resets the terminal to default. 
#endif // GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MOBILE va_end(args); } // Text printed in Google Test's text output and --gtest_list_tests // output to label the type parameter and value parameter for a test. static const char kTypeParamLabel[] = "TypeParam"; static const char kValueParamLabel[] = "GetParam()"; static void PrintFullTestCommentIfPresent(const TestInfo& test_info) { const char* const type_param = test_info.type_param(); const char* const value_param = test_info.value_param(); if (type_param != nullptr || value_param != nullptr) { printf(", where "); if (type_param != nullptr) { printf("%s = %s", kTypeParamLabel, type_param); if (value_param != nullptr) printf(" and "); } if (value_param != nullptr) { printf("%s = %s", kValueParamLabel, value_param); } } } // This class implements the TestEventListener interface. // // Class PrettyUnitTestResultPrinter is copyable. class PrettyUnitTestResultPrinter : public TestEventListener { public: PrettyUnitTestResultPrinter() {} static void PrintTestName(const char* test_suite, const char* test) { printf("%s.%s", test_suite, test); } // The following methods override what's in the TestEventListener class. 
void OnTestProgramStart(const UnitTest& /*unit_test*/) override {} void OnTestIterationStart(const UnitTest& unit_test, int iteration) override; void OnEnvironmentsSetUpStart(const UnitTest& unit_test) override; void OnEnvironmentsSetUpEnd(const UnitTest& /*unit_test*/) override {} #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ void OnTestCaseStart(const TestCase& test_case) override; #else void OnTestSuiteStart(const TestSuite& test_suite) override; #endif // OnTestCaseStart void OnTestStart(const TestInfo& test_info) override; void OnTestPartResult(const TestPartResult& result) override; void OnTestEnd(const TestInfo& test_info) override; #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ void OnTestCaseEnd(const TestCase& test_case) override; #else void OnTestSuiteEnd(const TestSuite& test_suite) override; #endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ void OnEnvironmentsTearDownStart(const UnitTest& unit_test) override; void OnEnvironmentsTearDownEnd(const UnitTest& /*unit_test*/) override {} void OnTestIterationEnd(const UnitTest& unit_test, int iteration) override; void OnTestProgramEnd(const UnitTest& /*unit_test*/) override {} private: static void PrintFailedTests(const UnitTest& unit_test); static void PrintSkippedTests(const UnitTest& unit_test); }; // Fired before each iteration of tests starts. void PrettyUnitTestResultPrinter::OnTestIterationStart( const UnitTest& unit_test, int iteration) { if (GTEST_FLAG(repeat) != 1) printf("\nRepeating all tests (iteration %d) . . .\n\n", iteration + 1); const char* const filter = GTEST_FLAG(filter).c_str(); // Prints the filter if it's not *. This reminds the user that some // tests may be skipped. 
if (!String::CStringEquals(filter, kUniversalFilter)) { ColoredPrintf(COLOR_YELLOW, "Note: %s filter = %s\n", GTEST_NAME_, filter); } if (internal::ShouldShard(kTestTotalShards, kTestShardIndex, false)) { const Int32 shard_index = Int32FromEnvOrDie(kTestShardIndex, -1); ColoredPrintf(COLOR_YELLOW, "Note: This is test shard %d of %s.\n", static_cast(shard_index) + 1, internal::posix::GetEnv(kTestTotalShards)); } if (GTEST_FLAG(shuffle)) { ColoredPrintf(COLOR_YELLOW, "Note: Randomizing tests' orders with a seed of %d .\n", unit_test.random_seed()); } ColoredPrintf(COLOR_GREEN, "[==========] "); printf("Running %s from %s.\n", FormatTestCount(unit_test.test_to_run_count()).c_str(), FormatTestSuiteCount(unit_test.test_suite_to_run_count()).c_str()); fflush(stdout); } void PrettyUnitTestResultPrinter::OnEnvironmentsSetUpStart( const UnitTest& /*unit_test*/) { ColoredPrintf(COLOR_GREEN, "[----------] "); printf("Global test environment set-up.\n"); fflush(stdout); } #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ void PrettyUnitTestResultPrinter::OnTestCaseStart(const TestCase& test_case) { const std::string counts = FormatCountableNoun(test_case.test_to_run_count(), "test", "tests"); ColoredPrintf(COLOR_GREEN, "[----------] "); printf("%s from %s", counts.c_str(), test_case.name()); if (test_case.type_param() == nullptr) { printf("\n"); } else { printf(", where %s = %s\n", kTypeParamLabel, test_case.type_param()); } fflush(stdout); } #else void PrettyUnitTestResultPrinter::OnTestSuiteStart( const TestSuite& test_suite) { const std::string counts = FormatCountableNoun(test_suite.test_to_run_count(), "test", "tests"); ColoredPrintf(COLOR_GREEN, "[----------] "); printf("%s from %s", counts.c_str(), test_suite.name()); if (test_suite.type_param() == nullptr) { printf("\n"); } else { printf(", where %s = %s\n", kTypeParamLabel, test_suite.type_param()); } fflush(stdout); } #endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ void PrettyUnitTestResultPrinter::OnTestStart(const TestInfo& 
test_info) { ColoredPrintf(COLOR_GREEN, "[ RUN ] "); PrintTestName(test_info.test_suite_name(), test_info.name()); printf("\n"); fflush(stdout); } // Called after an assertion failure. void PrettyUnitTestResultPrinter::OnTestPartResult( const TestPartResult& result) { switch (result.type()) { // If the test part succeeded, or was skipped, // we don't need to do anything. case TestPartResult::kSuccess: return; default: // Print failure message from the assertion // (e.g. expected this and got that). PrintTestPartResult(result); fflush(stdout); } } void PrettyUnitTestResultPrinter::OnTestEnd(const TestInfo& test_info) { if (test_info.result()->Passed()) { ColoredPrintf(COLOR_GREEN, "[ OK ] "); } else if (test_info.result()->Skipped()) { ColoredPrintf(COLOR_GREEN, "[ SKIPPED ] "); } else { ColoredPrintf(COLOR_RED, "[ FAILED ] "); } PrintTestName(test_info.test_suite_name(), test_info.name()); if (test_info.result()->Failed()) PrintFullTestCommentIfPresent(test_info); if (GTEST_FLAG(print_time)) { printf(" (%s ms)\n", internal::StreamableToString( test_info.result()->elapsed_time()).c_str()); } else { printf("\n"); } fflush(stdout); } #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ void PrettyUnitTestResultPrinter::OnTestCaseEnd(const TestCase& test_case) { if (!GTEST_FLAG(print_time)) return; const std::string counts = FormatCountableNoun(test_case.test_to_run_count(), "test", "tests"); ColoredPrintf(COLOR_GREEN, "[----------] "); printf("%s from %s (%s ms total)\n\n", counts.c_str(), test_case.name(), internal::StreamableToString(test_case.elapsed_time()).c_str()); fflush(stdout); } #else void PrettyUnitTestResultPrinter::OnTestSuiteEnd(const TestSuite& test_suite) { if (!GTEST_FLAG(print_time)) return; const std::string counts = FormatCountableNoun(test_suite.test_to_run_count(), "test", "tests"); ColoredPrintf(COLOR_GREEN, "[----------] "); printf("%s from %s (%s ms total)\n\n", counts.c_str(), test_suite.name(), 
internal::StreamableToString(test_suite.elapsed_time()).c_str()); fflush(stdout); } #endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ void PrettyUnitTestResultPrinter::OnEnvironmentsTearDownStart( const UnitTest& /*unit_test*/) { ColoredPrintf(COLOR_GREEN, "[----------] "); printf("Global test environment tear-down\n"); fflush(stdout); } // Internal helper for printing the list of failed tests. void PrettyUnitTestResultPrinter::PrintFailedTests(const UnitTest& unit_test) { const int failed_test_count = unit_test.failed_test_count(); if (failed_test_count == 0) { return; } for (int i = 0; i < unit_test.total_test_suite_count(); ++i) { const TestSuite& test_suite = *unit_test.GetTestSuite(i); if (!test_suite.should_run() || (test_suite.failed_test_count() == 0)) { continue; } for (int j = 0; j < test_suite.total_test_count(); ++j) { const TestInfo& test_info = *test_suite.GetTestInfo(j); if (!test_info.should_run() || !test_info.result()->Failed()) { continue; } ColoredPrintf(COLOR_RED, "[ FAILED ] "); printf("%s.%s", test_suite.name(), test_info.name()); PrintFullTestCommentIfPresent(test_info); printf("\n"); } } } // Internal helper for printing the list of skipped tests. 
void PrettyUnitTestResultPrinter::PrintSkippedTests(const UnitTest& unit_test) { const int skipped_test_count = unit_test.skipped_test_count(); if (skipped_test_count == 0) { return; } for (int i = 0; i < unit_test.total_test_suite_count(); ++i) { const TestSuite& test_suite = *unit_test.GetTestSuite(i); if (!test_suite.should_run() || (test_suite.skipped_test_count() == 0)) { continue; } for (int j = 0; j < test_suite.total_test_count(); ++j) { const TestInfo& test_info = *test_suite.GetTestInfo(j); if (!test_info.should_run() || !test_info.result()->Skipped()) { continue; } ColoredPrintf(COLOR_GREEN, "[ SKIPPED ] "); printf("%s.%s", test_suite.name(), test_info.name()); printf("\n"); } } } void PrettyUnitTestResultPrinter::OnTestIterationEnd(const UnitTest& unit_test, int /*iteration*/) { ColoredPrintf(COLOR_GREEN, "[==========] "); printf("%s from %s ran.", FormatTestCount(unit_test.test_to_run_count()).c_str(), FormatTestSuiteCount(unit_test.test_suite_to_run_count()).c_str()); if (GTEST_FLAG(print_time)) { printf(" (%s ms total)", internal::StreamableToString(unit_test.elapsed_time()).c_str()); } printf("\n"); ColoredPrintf(COLOR_GREEN, "[ PASSED ] "); printf("%s.\n", FormatTestCount(unit_test.successful_test_count()).c_str()); const int skipped_test_count = unit_test.skipped_test_count(); if (skipped_test_count > 0) { ColoredPrintf(COLOR_GREEN, "[ SKIPPED ] "); if (GTEST_FLAG(print_skipped)) { printf("%s, listed below:\n", FormatTestCount(skipped_test_count).c_str()); PrintSkippedTests(unit_test); } else { printf("%s.\n", FormatTestCount(skipped_test_count).c_str()); } } int num_failures = unit_test.failed_test_count(); if (!unit_test.Passed()) { const int failed_test_count = unit_test.failed_test_count(); ColoredPrintf(COLOR_RED, "[ FAILED ] "); printf("%s, listed below:\n", FormatTestCount(failed_test_count).c_str()); PrintFailedTests(unit_test); printf("\n%2d FAILED %s\n", num_failures, num_failures == 1 ? 
"TEST" : "TESTS"); } int num_disabled = unit_test.reportable_disabled_test_count(); if (num_disabled && !GTEST_FLAG(also_run_disabled_tests)) { if (!num_failures) { printf("\n"); // Add a spacer if no FAILURE banner is displayed. } ColoredPrintf(COLOR_YELLOW, " YOU HAVE %d DISABLED %s\n\n", num_disabled, num_disabled == 1 ? "TEST" : "TESTS"); } // Ensure that Google Test output is printed before, e.g., heapchecker output. fflush(stdout); } // End PrettyUnitTestResultPrinter // class TestEventRepeater // // This class forwards events to other event listeners. class TestEventRepeater : public TestEventListener { public: TestEventRepeater() : forwarding_enabled_(true) {} ~TestEventRepeater() override; void Append(TestEventListener *listener); TestEventListener* Release(TestEventListener* listener); // Controls whether events will be forwarded to listeners_. Set to false // in death test child processes. bool forwarding_enabled() const { return forwarding_enabled_; } void set_forwarding_enabled(bool enable) { forwarding_enabled_ = enable; } void OnTestProgramStart(const UnitTest& unit_test) override; void OnTestIterationStart(const UnitTest& unit_test, int iteration) override; void OnEnvironmentsSetUpStart(const UnitTest& unit_test) override; void OnEnvironmentsSetUpEnd(const UnitTest& unit_test) override; // Legacy API is deprecated but still available #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ void OnTestCaseStart(const TestSuite& parameter) override; #endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ void OnTestSuiteStart(const TestSuite& parameter) override; void OnTestStart(const TestInfo& test_info) override; void OnTestPartResult(const TestPartResult& result) override; void OnTestEnd(const TestInfo& test_info) override; // Legacy API is deprecated but still available #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ void OnTestCaseEnd(const TestCase& parameter) override; #endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ void OnTestSuiteEnd(const TestSuite& parameter) override; void 
OnEnvironmentsTearDownStart(const UnitTest& unit_test) override; void OnEnvironmentsTearDownEnd(const UnitTest& unit_test) override; void OnTestIterationEnd(const UnitTest& unit_test, int iteration) override; void OnTestProgramEnd(const UnitTest& unit_test) override; private: // Controls whether events will be forwarded to listeners_. Set to false // in death test child processes. bool forwarding_enabled_; // The list of listeners that receive events. std::vector listeners_; GTEST_DISALLOW_COPY_AND_ASSIGN_(TestEventRepeater); }; TestEventRepeater::~TestEventRepeater() { ForEach(listeners_, Delete); } void TestEventRepeater::Append(TestEventListener *listener) { listeners_.push_back(listener); } TestEventListener* TestEventRepeater::Release(TestEventListener *listener) { for (size_t i = 0; i < listeners_.size(); ++i) { if (listeners_[i] == listener) { listeners_.erase(listeners_.begin() + static_cast(i)); return listener; } } return nullptr; } // Since most methods are very similar, use macros to reduce boilerplate. // This defines a member that forwards the call to all listeners. #define GTEST_REPEATER_METHOD_(Name, Type) \ void TestEventRepeater::Name(const Type& parameter) { \ if (forwarding_enabled_) { \ for (size_t i = 0; i < listeners_.size(); i++) { \ listeners_[i]->Name(parameter); \ } \ } \ } // This defines a member that forwards the call to all listeners in reverse // order. 
#define GTEST_REVERSE_REPEATER_METHOD_(Name, Type) \ void TestEventRepeater::Name(const Type& parameter) { \ if (forwarding_enabled_) { \ for (size_t i = listeners_.size(); i != 0; i--) { \ listeners_[i - 1]->Name(parameter); \ } \ } \ } GTEST_REPEATER_METHOD_(OnTestProgramStart, UnitTest) GTEST_REPEATER_METHOD_(OnEnvironmentsSetUpStart, UnitTest) // Legacy API is deprecated but still available #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ GTEST_REPEATER_METHOD_(OnTestCaseStart, TestSuite) #endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ GTEST_REPEATER_METHOD_(OnTestSuiteStart, TestSuite) GTEST_REPEATER_METHOD_(OnTestStart, TestInfo) GTEST_REPEATER_METHOD_(OnTestPartResult, TestPartResult) GTEST_REPEATER_METHOD_(OnEnvironmentsTearDownStart, UnitTest) GTEST_REVERSE_REPEATER_METHOD_(OnEnvironmentsSetUpEnd, UnitTest) GTEST_REVERSE_REPEATER_METHOD_(OnEnvironmentsTearDownEnd, UnitTest) GTEST_REVERSE_REPEATER_METHOD_(OnTestEnd, TestInfo) // Legacy API is deprecated but still available #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ GTEST_REVERSE_REPEATER_METHOD_(OnTestCaseEnd, TestSuite) #endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ GTEST_REVERSE_REPEATER_METHOD_(OnTestSuiteEnd, TestSuite) GTEST_REVERSE_REPEATER_METHOD_(OnTestProgramEnd, UnitTest) #undef GTEST_REPEATER_METHOD_ #undef GTEST_REVERSE_REPEATER_METHOD_ void TestEventRepeater::OnTestIterationStart(const UnitTest& unit_test, int iteration) { if (forwarding_enabled_) { for (size_t i = 0; i < listeners_.size(); i++) { listeners_[i]->OnTestIterationStart(unit_test, iteration); } } } void TestEventRepeater::OnTestIterationEnd(const UnitTest& unit_test, int iteration) { if (forwarding_enabled_) { for (size_t i = listeners_.size(); i > 0; i--) { listeners_[i - 1]->OnTestIterationEnd(unit_test, iteration); } } } // End TestEventRepeater // This class generates an XML output file. 
class XmlUnitTestResultPrinter : public EmptyTestEventListener { public: explicit XmlUnitTestResultPrinter(const char* output_file); void OnTestIterationEnd(const UnitTest& unit_test, int iteration) override; void ListTestsMatchingFilter(const std::vector& test_suites); // Prints an XML summary of all unit tests. static void PrintXmlTestsList(std::ostream* stream, const std::vector& test_suites); private: // Is c a whitespace character that is normalized to a space character // when it appears in an XML attribute value? static bool IsNormalizableWhitespace(char c) { return c == 0x9 || c == 0xA || c == 0xD; } // May c appear in a well-formed XML document? static bool IsValidXmlCharacter(char c) { return IsNormalizableWhitespace(c) || c >= 0x20; } // Returns an XML-escaped copy of the input string str. If // is_attribute is true, the text is meant to appear as an attribute // value, and normalizable whitespace is preserved by replacing it // with character references. static std::string EscapeXml(const std::string& str, bool is_attribute); // Returns the given string with all characters invalid in XML removed. static std::string RemoveInvalidXmlCharacters(const std::string& str); // Convenience wrapper around EscapeXml when str is an attribute value. static std::string EscapeXmlAttribute(const std::string& str) { return EscapeXml(str, true); } // Convenience wrapper around EscapeXml when str is not an attribute value. static std::string EscapeXmlText(const char* str) { return EscapeXml(str, false); } // Verifies that the given attribute belongs to the given element and // streams the attribute as XML. static void OutputXmlAttribute(std::ostream* stream, const std::string& element_name, const std::string& name, const std::string& value); // Streams an XML CDATA section, escaping invalid CDATA sequences as needed. static void OutputXmlCDataSection(::std::ostream* stream, const char* data); // Streams an XML representation of a TestInfo object. 
static void OutputXmlTestInfo(::std::ostream* stream, const char* test_suite_name, const TestInfo& test_info); // Prints an XML representation of a TestSuite object static void PrintXmlTestSuite(::std::ostream* stream, const TestSuite& test_suite); // Prints an XML summary of unit_test to output stream out. static void PrintXmlUnitTest(::std::ostream* stream, const UnitTest& unit_test); // Produces a string representing the test properties in a result as space // delimited XML attributes based on the property key="value" pairs. // When the std::string is not empty, it includes a space at the beginning, // to delimit this attribute from prior attributes. static std::string TestPropertiesAsXmlAttributes(const TestResult& result); // Streams an XML representation of the test properties of a TestResult // object. static void OutputXmlTestProperties(std::ostream* stream, const TestResult& result); // The output file. const std::string output_file_; GTEST_DISALLOW_COPY_AND_ASSIGN_(XmlUnitTestResultPrinter); }; // Creates a new XmlUnitTestResultPrinter. XmlUnitTestResultPrinter::XmlUnitTestResultPrinter(const char* output_file) : output_file_(output_file) { if (output_file_.empty()) { GTEST_LOG_(FATAL) << "XML output file may not be null"; } } // Called after the unit test ends. void XmlUnitTestResultPrinter::OnTestIterationEnd(const UnitTest& unit_test, int /*iteration*/) { FILE* xmlout = OpenFileForWriting(output_file_); std::stringstream stream; PrintXmlUnitTest(&stream, unit_test); fprintf(xmlout, "%s", StringStreamToString(&stream).c_str()); fclose(xmlout); } void XmlUnitTestResultPrinter::ListTestsMatchingFilter( const std::vector& test_suites) { FILE* xmlout = OpenFileForWriting(output_file_); std::stringstream stream; PrintXmlTestsList(&stream, test_suites); fprintf(xmlout, "%s", StringStreamToString(&stream).c_str()); fclose(xmlout); } // Returns an XML-escaped copy of the input string str. 
If is_attribute // is true, the text is meant to appear as an attribute value, and // normalizable whitespace is preserved by replacing it with character // references. // // Invalid XML characters in str, if any, are stripped from the output. // It is expected that most, if not all, of the text processed by this // module will consist of ordinary English text. // If this module is ever modified to produce version 1.1 XML output, // most invalid characters can be retained using character references. std::string XmlUnitTestResultPrinter::EscapeXml( const std::string& str, bool is_attribute) { Message m; for (size_t i = 0; i < str.size(); ++i) { const char ch = str[i]; switch (ch) { case '<': m << "<"; break; case '>': m << ">"; break; case '&': m << "&"; break; case '\'': if (is_attribute) m << "'"; else m << '\''; break; case '"': if (is_attribute) m << """; else m << '"'; break; default: if (IsValidXmlCharacter(ch)) { if (is_attribute && IsNormalizableWhitespace(ch)) m << "&#x" << String::FormatByte(static_cast(ch)) << ";"; else m << ch; } break; } } return m.GetString(); } // Returns the given string with all characters invalid in XML removed. // Currently invalid characters are dropped from the string. An // alternative is to replace them with certain characters such as . or ?. std::string XmlUnitTestResultPrinter::RemoveInvalidXmlCharacters( const std::string& str) { std::string output; output.reserve(str.size()); for (std::string::const_iterator it = str.begin(); it != str.end(); ++it) if (IsValidXmlCharacter(*it)) output.push_back(*it); return output; } // The following routines generate an XML representation of a UnitTest // object. // GOOGLETEST_CM0009 DO NOT DELETE // // This is how Google Test concepts map to the DTD: // // <-- corresponds to a UnitTest object // <-- corresponds to a TestSuite object // <-- corresponds to a TestInfo object // ... // ... // ... 
// <-- individual assertion failures // // // // Formats the given time in milliseconds as seconds. std::string FormatTimeInMillisAsSeconds(TimeInMillis ms) { ::std::stringstream ss; ss << (static_cast(ms) * 1e-3); return ss.str(); } static bool PortableLocaltime(time_t seconds, struct tm* out) { #if defined(_MSC_VER) return localtime_s(out, &seconds) == 0; #elif defined(__MINGW32__) || defined(__MINGW64__) // MINGW provides neither localtime_r nor localtime_s, but uses // Windows' localtime(), which has a thread-local tm buffer. struct tm* tm_ptr = localtime(&seconds); // NOLINT if (tm_ptr == nullptr) return false; *out = *tm_ptr; return true; #else return localtime_r(&seconds, out) != nullptr; #endif } // Converts the given epoch time in milliseconds to a date string in the ISO // 8601 format, without the timezone information. std::string FormatEpochTimeInMillisAsIso8601(TimeInMillis ms) { struct tm time_struct; if (!PortableLocaltime(static_cast(ms / 1000), &time_struct)) return ""; // YYYY-MM-DDThh:mm:ss return StreamableToString(time_struct.tm_year + 1900) + "-" + String::FormatIntWidth2(time_struct.tm_mon + 1) + "-" + String::FormatIntWidth2(time_struct.tm_mday) + "T" + String::FormatIntWidth2(time_struct.tm_hour) + ":" + String::FormatIntWidth2(time_struct.tm_min) + ":" + String::FormatIntWidth2(time_struct.tm_sec); } // Streams an XML CDATA section, escaping invalid CDATA sequences as needed. 
void XmlUnitTestResultPrinter::OutputXmlCDataSection(::std::ostream* stream, const char* data) { const char* segment = data; *stream << ""); if (next_segment != nullptr) { stream->write( segment, static_cast(next_segment - segment)); *stream << "]]>]]>"); } else { *stream << segment; break; } } *stream << "]]>"; } void XmlUnitTestResultPrinter::OutputXmlAttribute( std::ostream* stream, const std::string& element_name, const std::string& name, const std::string& value) { const std::vector& allowed_names = GetReservedOutputAttributesForElement(element_name); GTEST_CHECK_(std::find(allowed_names.begin(), allowed_names.end(), name) != allowed_names.end()) << "Attribute " << name << " is not allowed for element <" << element_name << ">."; *stream << " " << name << "=\"" << EscapeXmlAttribute(value) << "\""; } // Prints an XML representation of a TestInfo object. void XmlUnitTestResultPrinter::OutputXmlTestInfo(::std::ostream* stream, const char* test_suite_name, const TestInfo& test_info) { const TestResult& result = *test_info.result(); const std::string kTestsuite = "testcase"; if (test_info.is_in_another_shard()) { return; } *stream << " \n"; return; } OutputXmlAttribute(stream, kTestsuite, "status", test_info.should_run() ? "run" : "notrun"); OutputXmlAttribute(stream, kTestsuite, "result", test_info.should_run() ? (result.Skipped() ? 
"skipped" : "completed") : "suppressed"); OutputXmlAttribute(stream, kTestsuite, "time", FormatTimeInMillisAsSeconds(result.elapsed_time())); OutputXmlAttribute( stream, kTestsuite, "timestamp", FormatEpochTimeInMillisAsIso8601(result.start_timestamp())); OutputXmlAttribute(stream, kTestsuite, "classname", test_suite_name); int failures = 0; for (int i = 0; i < result.total_part_count(); ++i) { const TestPartResult& part = result.GetTestPartResult(i); if (part.failed()) { if (++failures == 1) { *stream << ">\n"; } const std::string location = internal::FormatCompilerIndependentFileLocation(part.file_name(), part.line_number()); const std::string summary = location + "\n" + part.summary(); *stream << " "; const std::string detail = location + "\n" + part.message(); OutputXmlCDataSection(stream, RemoveInvalidXmlCharacters(detail).c_str()); *stream << "\n"; } } if (failures == 0 && result.test_property_count() == 0) { *stream << " />\n"; } else { if (failures == 0) { *stream << ">\n"; } OutputXmlTestProperties(stream, result); *stream << " \n"; } } // Prints an XML representation of a TestSuite object void XmlUnitTestResultPrinter::PrintXmlTestSuite(std::ostream* stream, const TestSuite& test_suite) { const std::string kTestsuite = "testsuite"; *stream << " <" << kTestsuite; OutputXmlAttribute(stream, kTestsuite, "name", test_suite.name()); OutputXmlAttribute(stream, kTestsuite, "tests", StreamableToString(test_suite.reportable_test_count())); if (!GTEST_FLAG(list_tests)) { OutputXmlAttribute(stream, kTestsuite, "failures", StreamableToString(test_suite.failed_test_count())); OutputXmlAttribute( stream, kTestsuite, "disabled", StreamableToString(test_suite.reportable_disabled_test_count())); OutputXmlAttribute(stream, kTestsuite, "errors", "0"); OutputXmlAttribute(stream, kTestsuite, "time", FormatTimeInMillisAsSeconds(test_suite.elapsed_time())); OutputXmlAttribute( stream, kTestsuite, "timestamp", FormatEpochTimeInMillisAsIso8601(test_suite.start_timestamp())); 
*stream << TestPropertiesAsXmlAttributes(test_suite.ad_hoc_test_result()); } *stream << ">\n"; for (int i = 0; i < test_suite.total_test_count(); ++i) { if (test_suite.GetTestInfo(i)->is_reportable()) OutputXmlTestInfo(stream, test_suite.name(), *test_suite.GetTestInfo(i)); } *stream << " \n"; } // Prints an XML summary of unit_test to output stream out. void XmlUnitTestResultPrinter::PrintXmlUnitTest(std::ostream* stream, const UnitTest& unit_test) { const std::string kTestsuites = "testsuites"; *stream << "\n"; *stream << "<" << kTestsuites; OutputXmlAttribute(stream, kTestsuites, "tests", StreamableToString(unit_test.reportable_test_count())); OutputXmlAttribute(stream, kTestsuites, "failures", StreamableToString(unit_test.failed_test_count())); OutputXmlAttribute( stream, kTestsuites, "disabled", StreamableToString(unit_test.reportable_disabled_test_count())); OutputXmlAttribute(stream, kTestsuites, "errors", "0"); OutputXmlAttribute(stream, kTestsuites, "time", FormatTimeInMillisAsSeconds(unit_test.elapsed_time())); OutputXmlAttribute( stream, kTestsuites, "timestamp", FormatEpochTimeInMillisAsIso8601(unit_test.start_timestamp())); if (GTEST_FLAG(shuffle)) { OutputXmlAttribute(stream, kTestsuites, "random_seed", StreamableToString(unit_test.random_seed())); } *stream << TestPropertiesAsXmlAttributes(unit_test.ad_hoc_test_result()); OutputXmlAttribute(stream, kTestsuites, "name", "AllTests"); *stream << ">\n"; for (int i = 0; i < unit_test.total_test_suite_count(); ++i) { if (unit_test.GetTestSuite(i)->reportable_test_count() > 0) PrintXmlTestSuite(stream, *unit_test.GetTestSuite(i)); } *stream << "\n"; } void XmlUnitTestResultPrinter::PrintXmlTestsList( std::ostream* stream, const std::vector& test_suites) { const std::string kTestsuites = "testsuites"; *stream << "\n"; *stream << "<" << kTestsuites; int total_tests = 0; for (auto test_suite : test_suites) { total_tests += test_suite->total_test_count(); } OutputXmlAttribute(stream, kTestsuites, "tests", 
StreamableToString(total_tests)); OutputXmlAttribute(stream, kTestsuites, "name", "AllTests"); *stream << ">\n"; for (auto test_suite : test_suites) { PrintXmlTestSuite(stream, *test_suite); } *stream << "\n"; } // Produces a string representing the test properties in a result as space // delimited XML attributes based on the property key="value" pairs. std::string XmlUnitTestResultPrinter::TestPropertiesAsXmlAttributes( const TestResult& result) { Message attributes; for (int i = 0; i < result.test_property_count(); ++i) { const TestProperty& property = result.GetTestProperty(i); attributes << " " << property.key() << "=" << "\"" << EscapeXmlAttribute(property.value()) << "\""; } return attributes.GetString(); } void XmlUnitTestResultPrinter::OutputXmlTestProperties( std::ostream* stream, const TestResult& result) { const std::string kProperties = "properties"; const std::string kProperty = "property"; if (result.test_property_count() <= 0) { return; } *stream << "<" << kProperties << ">\n"; for (int i = 0; i < result.test_property_count(); ++i) { const TestProperty& property = result.GetTestProperty(i); *stream << "<" << kProperty; *stream << " name=\"" << EscapeXmlAttribute(property.key()) << "\""; *stream << " value=\"" << EscapeXmlAttribute(property.value()) << "\""; *stream << "/>\n"; } *stream << "\n"; } // End XmlUnitTestResultPrinter // This class generates an JSON output file. class JsonUnitTestResultPrinter : public EmptyTestEventListener { public: explicit JsonUnitTestResultPrinter(const char* output_file); void OnTestIterationEnd(const UnitTest& unit_test, int iteration) override; // Prints an JSON summary of all unit tests. static void PrintJsonTestList(::std::ostream* stream, const std::vector& test_suites); private: // Returns an JSON-escaped copy of the input string str. static std::string EscapeJson(const std::string& str); //// Verifies that the given attribute belongs to the given element and //// streams the attribute as JSON. 
static void OutputJsonKey(std::ostream* stream, const std::string& element_name, const std::string& name, const std::string& value, const std::string& indent, bool comma = true); static void OutputJsonKey(std::ostream* stream, const std::string& element_name, const std::string& name, int value, const std::string& indent, bool comma = true); // Streams a JSON representation of a TestInfo object. static void OutputJsonTestInfo(::std::ostream* stream, const char* test_suite_name, const TestInfo& test_info); // Prints a JSON representation of a TestSuite object static void PrintJsonTestSuite(::std::ostream* stream, const TestSuite& test_suite); // Prints a JSON summary of unit_test to output stream out. static void PrintJsonUnitTest(::std::ostream* stream, const UnitTest& unit_test); // Produces a string representing the test properties in a result as // a JSON dictionary. static std::string TestPropertiesAsJson(const TestResult& result, const std::string& indent); // The output file. const std::string output_file_; GTEST_DISALLOW_COPY_AND_ASSIGN_(JsonUnitTestResultPrinter); }; // Creates a new JsonUnitTestResultPrinter. JsonUnitTestResultPrinter::JsonUnitTestResultPrinter(const char* output_file) : output_file_(output_file) { if (output_file_.empty()) { GTEST_LOG_(FATAL) << "JSON output file may not be null"; } } void JsonUnitTestResultPrinter::OnTestIterationEnd(const UnitTest& unit_test, int /*iteration*/) { FILE* jsonout = OpenFileForWriting(output_file_); std::stringstream stream; PrintJsonUnitTest(&stream, unit_test); fprintf(jsonout, "%s", StringStreamToString(&stream).c_str()); fclose(jsonout); } // Returns an JSON-escaped copy of the input string str. 
std::string JsonUnitTestResultPrinter::EscapeJson(const std::string& str) { Message m; for (size_t i = 0; i < str.size(); ++i) { const char ch = str[i]; switch (ch) { case '\\': case '"': case '/': m << '\\' << ch; break; case '\b': m << "\\b"; break; case '\t': m << "\\t"; break; case '\n': m << "\\n"; break; case '\f': m << "\\f"; break; case '\r': m << "\\r"; break; default: if (ch < ' ') { m << "\\u00" << String::FormatByte(static_cast(ch)); } else { m << ch; } break; } } return m.GetString(); } // The following routines generate an JSON representation of a UnitTest // object. // Formats the given time in milliseconds as seconds. static std::string FormatTimeInMillisAsDuration(TimeInMillis ms) { ::std::stringstream ss; ss << (static_cast(ms) * 1e-3) << "s"; return ss.str(); } // Converts the given epoch time in milliseconds to a date string in the // RFC3339 format, without the timezone information. static std::string FormatEpochTimeInMillisAsRFC3339(TimeInMillis ms) { struct tm time_struct; if (!PortableLocaltime(static_cast(ms / 1000), &time_struct)) return ""; // YYYY-MM-DDThh:mm:ss return StreamableToString(time_struct.tm_year + 1900) + "-" + String::FormatIntWidth2(time_struct.tm_mon + 1) + "-" + String::FormatIntWidth2(time_struct.tm_mday) + "T" + String::FormatIntWidth2(time_struct.tm_hour) + ":" + String::FormatIntWidth2(time_struct.tm_min) + ":" + String::FormatIntWidth2(time_struct.tm_sec) + "Z"; } static inline std::string Indent(size_t width) { return std::string(width, ' '); } void JsonUnitTestResultPrinter::OutputJsonKey( std::ostream* stream, const std::string& element_name, const std::string& name, const std::string& value, const std::string& indent, bool comma) { const std::vector& allowed_names = GetReservedOutputAttributesForElement(element_name); GTEST_CHECK_(std::find(allowed_names.begin(), allowed_names.end(), name) != allowed_names.end()) << "Key \"" << name << "\" is not allowed for value \"" << element_name << "\"."; *stream << indent 
<< "\"" << name << "\": \"" << EscapeJson(value) << "\""; if (comma) *stream << ",\n"; } void JsonUnitTestResultPrinter::OutputJsonKey( std::ostream* stream, const std::string& element_name, const std::string& name, int value, const std::string& indent, bool comma) { const std::vector& allowed_names = GetReservedOutputAttributesForElement(element_name); GTEST_CHECK_(std::find(allowed_names.begin(), allowed_names.end(), name) != allowed_names.end()) << "Key \"" << name << "\" is not allowed for value \"" << element_name << "\"."; *stream << indent << "\"" << name << "\": " << StreamableToString(value); if (comma) *stream << ",\n"; } // Prints a JSON representation of a TestInfo object. void JsonUnitTestResultPrinter::OutputJsonTestInfo(::std::ostream* stream, const char* test_suite_name, const TestInfo& test_info) { const TestResult& result = *test_info.result(); const std::string kTestsuite = "testcase"; const std::string kIndent = Indent(10); *stream << Indent(8) << "{\n"; OutputJsonKey(stream, kTestsuite, "name", test_info.name(), kIndent); if (test_info.value_param() != nullptr) { OutputJsonKey(stream, kTestsuite, "value_param", test_info.value_param(), kIndent); } if (test_info.type_param() != nullptr) { OutputJsonKey(stream, kTestsuite, "type_param", test_info.type_param(), kIndent); } if (GTEST_FLAG(list_tests)) { OutputJsonKey(stream, kTestsuite, "file", test_info.file(), kIndent); OutputJsonKey(stream, kTestsuite, "line", test_info.line(), kIndent, false); *stream << "\n" << Indent(8) << "}"; return; } OutputJsonKey(stream, kTestsuite, "status", test_info.should_run() ? "RUN" : "NOTRUN", kIndent); OutputJsonKey(stream, kTestsuite, "result", test_info.should_run() ? (result.Skipped() ? 
"SKIPPED" : "COMPLETED") : "SUPPRESSED", kIndent); OutputJsonKey(stream, kTestsuite, "timestamp", FormatEpochTimeInMillisAsRFC3339(result.start_timestamp()), kIndent); OutputJsonKey(stream, kTestsuite, "time", FormatTimeInMillisAsDuration(result.elapsed_time()), kIndent); OutputJsonKey(stream, kTestsuite, "classname", test_suite_name, kIndent, false); *stream << TestPropertiesAsJson(result, kIndent); int failures = 0; for (int i = 0; i < result.total_part_count(); ++i) { const TestPartResult& part = result.GetTestPartResult(i); if (part.failed()) { *stream << ",\n"; if (++failures == 1) { *stream << kIndent << "\"" << "failures" << "\": [\n"; } const std::string location = internal::FormatCompilerIndependentFileLocation(part.file_name(), part.line_number()); const std::string message = EscapeJson(location + "\n" + part.message()); *stream << kIndent << " {\n" << kIndent << " \"failure\": \"" << message << "\",\n" << kIndent << " \"type\": \"\"\n" << kIndent << " }"; } } if (failures > 0) *stream << "\n" << kIndent << "]"; *stream << "\n" << Indent(8) << "}"; } // Prints an JSON representation of a TestSuite object void JsonUnitTestResultPrinter::PrintJsonTestSuite( std::ostream* stream, const TestSuite& test_suite) { const std::string kTestsuite = "testsuite"; const std::string kIndent = Indent(6); *stream << Indent(4) << "{\n"; OutputJsonKey(stream, kTestsuite, "name", test_suite.name(), kIndent); OutputJsonKey(stream, kTestsuite, "tests", test_suite.reportable_test_count(), kIndent); if (!GTEST_FLAG(list_tests)) { OutputJsonKey(stream, kTestsuite, "failures", test_suite.failed_test_count(), kIndent); OutputJsonKey(stream, kTestsuite, "disabled", test_suite.reportable_disabled_test_count(), kIndent); OutputJsonKey(stream, kTestsuite, "errors", 0, kIndent); OutputJsonKey( stream, kTestsuite, "timestamp", FormatEpochTimeInMillisAsRFC3339(test_suite.start_timestamp()), kIndent); OutputJsonKey(stream, kTestsuite, "time", 
FormatTimeInMillisAsDuration(test_suite.elapsed_time()), kIndent, false); *stream << TestPropertiesAsJson(test_suite.ad_hoc_test_result(), kIndent) << ",\n"; } *stream << kIndent << "\"" << kTestsuite << "\": [\n"; bool comma = false; for (int i = 0; i < test_suite.total_test_count(); ++i) { if (test_suite.GetTestInfo(i)->is_reportable()) { if (comma) { *stream << ",\n"; } else { comma = true; } OutputJsonTestInfo(stream, test_suite.name(), *test_suite.GetTestInfo(i)); } } *stream << "\n" << kIndent << "]\n" << Indent(4) << "}"; } // Prints a JSON summary of unit_test to output stream out. void JsonUnitTestResultPrinter::PrintJsonUnitTest(std::ostream* stream, const UnitTest& unit_test) { const std::string kTestsuites = "testsuites"; const std::string kIndent = Indent(2); *stream << "{\n"; OutputJsonKey(stream, kTestsuites, "tests", unit_test.reportable_test_count(), kIndent); OutputJsonKey(stream, kTestsuites, "failures", unit_test.failed_test_count(), kIndent); OutputJsonKey(stream, kTestsuites, "disabled", unit_test.reportable_disabled_test_count(), kIndent); OutputJsonKey(stream, kTestsuites, "errors", 0, kIndent); if (GTEST_FLAG(shuffle)) { OutputJsonKey(stream, kTestsuites, "random_seed", unit_test.random_seed(), kIndent); } OutputJsonKey(stream, kTestsuites, "timestamp", FormatEpochTimeInMillisAsRFC3339(unit_test.start_timestamp()), kIndent); OutputJsonKey(stream, kTestsuites, "time", FormatTimeInMillisAsDuration(unit_test.elapsed_time()), kIndent, false); *stream << TestPropertiesAsJson(unit_test.ad_hoc_test_result(), kIndent) << ",\n"; OutputJsonKey(stream, kTestsuites, "name", "AllTests", kIndent); *stream << kIndent << "\"" << kTestsuites << "\": [\n"; bool comma = false; for (int i = 0; i < unit_test.total_test_suite_count(); ++i) { if (unit_test.GetTestSuite(i)->reportable_test_count() > 0) { if (comma) { *stream << ",\n"; } else { comma = true; } PrintJsonTestSuite(stream, *unit_test.GetTestSuite(i)); } } *stream << "\n" << kIndent << "]\n" << "}\n"; 
} void JsonUnitTestResultPrinter::PrintJsonTestList( std::ostream* stream, const std::vector& test_suites) { const std::string kTestsuites = "testsuites"; const std::string kIndent = Indent(2); *stream << "{\n"; int total_tests = 0; for (auto test_suite : test_suites) { total_tests += test_suite->total_test_count(); } OutputJsonKey(stream, kTestsuites, "tests", total_tests, kIndent); OutputJsonKey(stream, kTestsuites, "name", "AllTests", kIndent); *stream << kIndent << "\"" << kTestsuites << "\": [\n"; for (size_t i = 0; i < test_suites.size(); ++i) { if (i != 0) { *stream << ",\n"; } PrintJsonTestSuite(stream, *test_suites[i]); } *stream << "\n" << kIndent << "]\n" << "}\n"; } // Produces a string representing the test properties in a result as // a JSON dictionary. std::string JsonUnitTestResultPrinter::TestPropertiesAsJson( const TestResult& result, const std::string& indent) { Message attributes; for (int i = 0; i < result.test_property_count(); ++i) { const TestProperty& property = result.GetTestProperty(i); attributes << ",\n" << indent << "\"" << property.key() << "\": " << "\"" << EscapeJson(property.value()) << "\""; } return attributes.GetString(); } // End JsonUnitTestResultPrinter #if GTEST_CAN_STREAM_RESULTS_ // Checks if str contains '=', '&', '%' or '\n' characters. If yes, // replaces them by "%xx" where xx is their hexadecimal value. For // example, replaces "=" with "%3D". This algorithm is O(strlen(str)) // in both time and space -- important as the input str may contain an // arbitrarily long test failure message and stack trace. 
std::string StreamingListener::UrlEncode(const char* str) {
  std::string result;
  result.reserve(strlen(str) + 1);
  for (char ch = *str; ch != '\0'; ch = *++str) {
    switch (ch) {
      case '%':
      case '=':
      case '&':
      case '\n':
        result.append("%" + String::FormatByte(static_cast<unsigned char>(ch)));
        break;
      default:
        result.push_back(ch);
        break;
    }
  }
  return result;
}

// Resolves host_name_/port_num_ and connects sockfd_ to the first address
// that accepts the connection. Failures are logged as warnings and leave
// sockfd_ at -1.
void StreamingListener::SocketWriter::MakeConnection() {
  GTEST_CHECK_(sockfd_ == -1)
      << "MakeConnection() can't be called when there is already a connection.";

  addrinfo hints;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;    // To allow both IPv4 and IPv6 addresses.
  hints.ai_socktype = SOCK_STREAM;
  addrinfo* servinfo = nullptr;

  // Use the getaddrinfo() to get a linked list of IP addresses for
  // the given host name.
  const int error_num = getaddrinfo(
      host_name_.c_str(), port_num_.c_str(), &hints, &servinfo);
  if (error_num != 0) {
    GTEST_LOG_(WARNING) << "stream_result_to: getaddrinfo() failed: "
                        << gai_strerror(error_num);
  }

  // Loop through all the results and connect to the first we can.
  for (addrinfo* cur_addr = servinfo; sockfd_ == -1 && cur_addr != nullptr;
       cur_addr = cur_addr->ai_next) {
    sockfd_ = socket(
        cur_addr->ai_family, cur_addr->ai_socktype, cur_addr->ai_protocol);
    if (sockfd_ != -1) {
      // Connect the client socket to the server socket.
      if (connect(sockfd_, cur_addr->ai_addr, cur_addr->ai_addrlen) == -1) {
        close(sockfd_);
        sockfd_ = -1;
      }
    }
  }

  // servinfo stays null when getaddrinfo() failed; only release a list that
  // was actually returned.
  if (servinfo != nullptr) {
    freeaddrinfo(servinfo);  // all done with this structure
  }

  if (sockfd_ == -1) {
    GTEST_LOG_(WARNING) << "stream_result_to: failed to connect to "
                        << host_name_ << ":" << port_num_;
  }
}

// End of class Streaming Listener
#endif  // GTEST_CAN_STREAM_RESULTS__

// class OsStackTraceGetter

const char* const OsStackTraceGetterInterface::kElidedFramesMarker =
    "... " GTEST_NAME_ " internal frames ...";

// Returns the current stack trace as text, one frame per line, limited to
// max_depth frames and skipping the first skip_count frames. Returns an
// empty string when built without Abseil.
std::string OsStackTraceGetter::CurrentStackTrace(int max_depth, int skip_count)
    GTEST_LOCK_EXCLUDED_(mutex_) {
#if GTEST_HAS_ABSL
  std::string result;

  if (max_depth <= 0) {
    return result;
  }

  max_depth = std::min(max_depth, kMaxStackTraceDepth);

  std::vector<void*> raw_stack(max_depth);
  // Skips the frames requested by the caller, plus this function.
  const int raw_stack_size =
      absl::GetStackTrace(&raw_stack[0], max_depth, skip_count + 1);

  void* caller_frame = nullptr;
  {
    MutexLock lock(&mutex_);
    caller_frame = caller_frame_;
  }

  for (int i = 0; i < raw_stack_size; ++i) {
    if (raw_stack[i] == caller_frame &&
        !GTEST_FLAG(show_internal_stack_frames)) {
      // Add a marker to the trace and stop adding frames.
      absl::StrAppend(&result, kElidedFramesMarker, "\n");
      break;
    }

    char tmp[1024];
    const char* symbol = "(unknown)";
    if (absl::Symbolize(raw_stack[i], tmp, sizeof(tmp))) {
      symbol = tmp;
    }

    char line[1024];
    snprintf(line, sizeof(line), "  %p: %s\n", raw_stack[i], symbol);
    result += line;
  }

  return result;

#else  // !GTEST_HAS_ABSL
  static_cast<void>(max_depth);
  static_cast<void>(skip_count);
  return "";
#endif  // GTEST_HAS_ABSL
}

// Records the frame returned by absl::GetStackTrace() at this point as
// caller_frame_, which CurrentStackTrace() uses to decide where to elide
// frames. Does nothing when built without Abseil.
void OsStackTraceGetter::UponLeavingGTest() GTEST_LOCK_EXCLUDED_(mutex_) {
#if GTEST_HAS_ABSL
  void* caller_frame = nullptr;
  if (absl::GetStackTrace(&caller_frame, 1, 3) <= 0) {
    caller_frame = nullptr;
  }

  MutexLock lock(&mutex_);
  caller_frame_ = caller_frame;
#endif  // GTEST_HAS_ABSL
}

// A helper class that creates the premature-exit file in its
// constructor and deletes the file in its destructor.
class ScopedPrematureExitFile {
 public:
  explicit ScopedPrematureExitFile(const char* premature_exit_filepath)
      : premature_exit_filepath_(premature_exit_filepath ?
                                 premature_exit_filepath : "") {
    // If a path to the premature-exit file is specified...
    if (!premature_exit_filepath_.empty()) {
      // create the file with a single "0" character in it.  I/O
      // errors are ignored as there's nothing better we can do and we
      // don't want to fail the test because of this.
      FILE* pfile = posix::FOpen(premature_exit_filepath, "w");
      // The file may not be creatable (bad path, permissions); skip the
      // write instead of passing a null stream to fwrite()/fclose().
      if (pfile != nullptr) {
        fwrite("0", 1, 1, pfile);
        fclose(pfile);
      }
    }
  }

  ~ScopedPrematureExitFile() {
    if (!premature_exit_filepath_.empty()) {
      int retval = remove(premature_exit_filepath_.c_str());
      if (retval) {
        GTEST_LOG_(ERROR) << "Failed to remove premature exit filepath \""
                          << premature_exit_filepath_ << "\" with error "
                          << retval;
      }
    }
  }

 private:
  const std::string premature_exit_filepath_;

  GTEST_DISALLOW_COPY_AND_ASSIGN_(ScopedPrematureExitFile);
};

}  // namespace internal

// class TestEventListeners

TestEventListeners::TestEventListeners()
    : repeater_(new internal::TestEventRepeater()),
      default_result_printer_(nullptr),
      default_xml_generator_(nullptr) {}

TestEventListeners::~TestEventListeners() { delete repeater_; }

// Appends the given event listener to the repeater's list of listeners.
void TestEventListeners::Append(TestEventListener* listener) {
  repeater_->Append(listener);
}

// Removes the given event listener from the list and returns it.  It then
// becomes the caller's responsibility to delete the listener. Returns
// NULL if the listener is not found in the list.
TestEventListener* TestEventListeners::Release(TestEventListener* listener) {
  // Forget the default printer/generator if that is what is being released,
  // so it is not released or deleted a second time later.
  if (listener == default_result_printer_)
    default_result_printer_ = nullptr;
  else if (listener == default_xml_generator_)
    default_xml_generator_ = nullptr;
  return repeater_->Release(listener);
}

// Returns repeater that broadcasts the TestEventListener events to all
// subscribers.
TestEventListener* TestEventListeners::repeater() { return repeater_; }

// Sets the default_result_printer attribute to the provided listener.
// The listener is also added to the listener list and previous // default_result_printer is removed from it and deleted. The listener can // also be NULL in which case it will not be added to the list. Does // nothing if the previous and the current listener objects are the same. void TestEventListeners::SetDefaultResultPrinter(TestEventListener* listener) { if (default_result_printer_ != listener) { // It is an error to pass this method a listener that is already in the // list. delete Release(default_result_printer_); default_result_printer_ = listener; if (listener != nullptr) Append(listener); } } // Sets the default_xml_generator attribute to the provided listener. The // listener is also added to the listener list and previous // default_xml_generator is removed from it and deleted. The listener can // also be NULL in which case it will not be added to the list. Does // nothing if the previous and the current listener objects are the same. void TestEventListeners::SetDefaultXmlGenerator(TestEventListener* listener) { if (default_xml_generator_ != listener) { // It is an error to pass this method a listener that is already in the // list. delete Release(default_xml_generator_); default_xml_generator_ = listener; if (listener != nullptr) Append(listener); } } // Controls whether events will be forwarded by the repeater to the // listeners in the list. bool TestEventListeners::EventForwardingEnabled() const { return repeater_->forwarding_enabled(); } void TestEventListeners::SuppressEventForwarding() { repeater_->set_forwarding_enabled(false); } // class UnitTest // Gets the singleton UnitTest object. The first time this method is // called, a UnitTest object is constructed and returned. Consecutive // calls will return the same object. // // We don't protect this under mutex_ as a user is not supposed to // call this before main() starts, from which point on the return // value will never change. 
UnitTest* UnitTest::GetInstance() {
  // CodeGear C++Builder insists on a public destructor for the
  // default implementation.  Use this implementation to keep good OO
  // design with private destructor.

#if defined(__BORLANDC__)
  // Heap-allocated and never deleted, so the destructor is never invoked.
  static UnitTest* const instance = new UnitTest;
  return instance;
#else
  // Function-local static: constructed on the first call.
  static UnitTest instance;
  return &instance;
#endif  // defined(__BORLANDC__)
}

// Gets the number of successful test suites.
int UnitTest::successful_test_suite_count() const {
  return impl()->successful_test_suite_count();
}

// Gets the number of failed test suites.
int UnitTest::failed_test_suite_count() const {
  return impl()->failed_test_suite_count();
}

// Gets the number of all test suites.
int UnitTest::total_test_suite_count() const {
  return impl()->total_test_suite_count();
}

// Gets the number of all test suites that contain at least one test
// that should run.
int UnitTest::test_suite_to_run_count() const {
  return impl()->test_suite_to_run_count();
}

//  Legacy API is deprecated but still available
//  Each *_test_case_* accessor below forwards to the matching
//  *_test_suite_* accessor of the implementation object.
#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_
int UnitTest::successful_test_case_count() const {
  return impl()->successful_test_suite_count();
}
int UnitTest::failed_test_case_count() const {
  return impl()->failed_test_suite_count();
}
int UnitTest::total_test_case_count() const {
  return impl()->total_test_suite_count();
}
int UnitTest::test_case_to_run_count() const {
  return impl()->test_suite_to_run_count();
}
#endif  //  GTEST_REMOVE_LEGACY_TEST_CASEAPI_

// Gets the number of successful tests.
int UnitTest::successful_test_count() const {
  return impl()->successful_test_count();
}

// Gets the number of skipped tests.
int UnitTest::skipped_test_count() const {
  return impl()->skipped_test_count();
}

// Gets the number of failed tests.
int UnitTest::failed_test_count() const { return impl()->failed_test_count(); }

// Gets the number of disabled tests that will be reported in the XML report.
int UnitTest::reportable_disabled_test_count() const {
  return impl()->reportable_disabled_test_count();
}

// Gets the number of disabled tests.
int UnitTest::disabled_test_count() const {
  return impl()->disabled_test_count();
}

// Gets the number of tests to be printed in the XML report.
int UnitTest::reportable_test_count() const {
  return impl()->reportable_test_count();
}

// Gets the number of all tests.
int UnitTest::total_test_count() const { return impl()->total_test_count(); }

// Gets the number of tests that should run.
int UnitTest::test_to_run_count() const { return impl()->test_to_run_count(); }

// Gets the time of the test program start, in ms from the start of the
// UNIX epoch.
internal::TimeInMillis UnitTest::start_timestamp() const {
    return impl()->start_timestamp();
}

// Gets the elapsed time, in milliseconds.
internal::TimeInMillis UnitTest::elapsed_time() const {
  return impl()->elapsed_time();
}

// Returns true if and only if the unit test passed (i.e. all test suites
// passed).
bool UnitTest::Passed() const { return impl()->Passed(); }

// Returns true if and only if the unit test failed (i.e. some test suite
// failed or something outside of all tests failed).
bool UnitTest::Failed() const { return impl()->Failed(); }

// Gets the i-th test suite among all the test suites. i can range from 0 to
// total_test_suite_count() - 1. If i is not in that range, returns NULL.
const TestSuite* UnitTest::GetTestSuite(int i) const {
  return impl()->GetTestSuite(i);
}

//  Legacy API is deprecated but still available
//  Legacy-named accessor; forwards to the implementation's GetTestCase().
#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_
const TestCase* UnitTest::GetTestCase(int i) const {
  return impl()->GetTestCase(i);
}
#endif  //  GTEST_REMOVE_LEGACY_TEST_CASEAPI_

// Returns the TestResult containing information on test failures and
// properties logged outside of individual test suites.
const TestResult& UnitTest::ad_hoc_test_result() const {
  return *impl()->ad_hoc_test_result();
}

// Gets the i-th test suite among all the test suites.
i can range from 0 to // total_test_suite_count() - 1. If i is not in that range, returns NULL. TestSuite* UnitTest::GetMutableTestSuite(int i) { return impl()->GetMutableSuiteCase(i); } // Returns the list of event listeners that can be used to track events // inside Google Test. TestEventListeners& UnitTest::listeners() { return *impl()->listeners(); } // Registers and returns a global test environment. When a test // program is run, all global test environments will be set-up in the // order they were registered. After all tests in the program have // finished, all global test environments will be torn-down in the // *reverse* order they were registered. // // The UnitTest object takes ownership of the given environment. // // We don't protect this under mutex_, as we only support calling it // from the main thread. Environment* UnitTest::AddEnvironment(Environment* env) { if (env == nullptr) { return nullptr; } impl_->environments().push_back(env); return env; } // Adds a TestPartResult to the current TestResult object. All Google Test // assertion macros (e.g. ASSERT_TRUE, EXPECT_EQ, etc) eventually call // this to report their results. The user code should use the // assertion macros instead of calling this directly. 
void UnitTest::AddTestPartResult( TestPartResult::Type result_type, const char* file_name, int line_number, const std::string& message, const std::string& os_stack_trace) GTEST_LOCK_EXCLUDED_(mutex_) { Message msg; msg << message; internal::MutexLock lock(&mutex_); if (impl_->gtest_trace_stack().size() > 0) { msg << "\n" << GTEST_NAME_ << " trace:"; for (size_t i = impl_->gtest_trace_stack().size(); i > 0; --i) { const internal::TraceInfo& trace = impl_->gtest_trace_stack()[i - 1]; msg << "\n" << internal::FormatFileLocation(trace.file, trace.line) << " " << trace.message; } } if (os_stack_trace.c_str() != nullptr && !os_stack_trace.empty()) { msg << internal::kStackTraceMarker << os_stack_trace; } const TestPartResult result = TestPartResult( result_type, file_name, line_number, msg.GetString().c_str()); impl_->GetTestPartResultReporterForCurrentThread()-> ReportTestPartResult(result); if (result_type != TestPartResult::kSuccess && result_type != TestPartResult::kSkip) { // gtest_break_on_failure takes precedence over // gtest_throw_on_failure. This allows a user to set the latter // in the code (perhaps in order to use Google Test assertions // with another testing framework) and specify the former on the // command line for debugging. if (GTEST_FLAG(break_on_failure)) { #if GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_PHONE && !GTEST_OS_WINDOWS_RT // Using DebugBreak on Windows allows gtest to still break into a debugger // when a failure happens and both the --gtest_break_on_failure and // the --gtest_catch_exceptions flags are specified. DebugBreak(); #elif (!defined(__native_client__)) && \ ((defined(__clang__) || defined(__GNUC__)) && \ (defined(__x86_64__) || defined(__i386__))) // with clang/gcc we can achieve the same effect on x86 by invoking int3 asm("int3"); #else // Dereference nullptr through a volatile pointer to prevent the compiler // from removing. 
We use this rather than abort() or __builtin_trap() for // portability: some debuggers don't correctly trap abort(). *static_cast(nullptr) = 1; #endif // GTEST_OS_WINDOWS } else if (GTEST_FLAG(throw_on_failure)) { #if GTEST_HAS_EXCEPTIONS throw internal::GoogleTestFailureException(result); #else // We cannot call abort() as it generates a pop-up in debug mode // that cannot be suppressed in VC 7.1 or below. exit(1); #endif } } } // Adds a TestProperty to the current TestResult object when invoked from // inside a test, to current TestSuite's ad_hoc_test_result_ when invoked // from SetUpTestSuite or TearDownTestSuite, or to the global property set // when invoked elsewhere. If the result already contains a property with // the same key, the value will be updated. void UnitTest::RecordProperty(const std::string& key, const std::string& value) { impl_->RecordProperty(TestProperty(key, value)); } // Runs all tests in this UnitTest object and prints the result. // Returns 0 if successful, or 1 otherwise. // // We don't protect this under mutex_, as we only support calling it // from the main thread. int UnitTest::Run() { const bool in_death_test_child_process = internal::GTEST_FLAG(internal_run_death_test).length() > 0; // Google Test implements this protocol for catching that a test // program exits before returning control to Google Test: // // 1. Upon start, Google Test creates a file whose absolute path // is specified by the environment variable // TEST_PREMATURE_EXIT_FILE. // 2. When Google Test has finished its work, it deletes the file. // // This allows a test runner to set TEST_PREMATURE_EXIT_FILE before // running a Google-Test-based test program and check the existence // of the file at the end of the test execution to see if it has // exited prematurely. // If we are in the child process of a death test, don't // create/delete the premature exit file, as doing so is unnecessary // and will confuse the parent process. 
Otherwise, create/delete // the file upon entering/leaving this function. If the program // somehow exits before this function has a chance to return, the // premature-exit file will be left undeleted, causing a test runner // that understands the premature-exit-file protocol to report the // test as having failed. const internal::ScopedPrematureExitFile premature_exit_file( in_death_test_child_process ? nullptr : internal::posix::GetEnv("TEST_PREMATURE_EXIT_FILE")); // Captures the value of GTEST_FLAG(catch_exceptions). This value will be // used for the duration of the program. impl()->set_catch_exceptions(GTEST_FLAG(catch_exceptions)); #if GTEST_OS_WINDOWS // Either the user wants Google Test to catch exceptions thrown by the // tests or this is executing in the context of death test child // process. In either case the user does not want to see pop-up dialogs // about crashes - they are expected. if (impl()->catch_exceptions() || in_death_test_child_process) { # if !GTEST_OS_WINDOWS_MOBILE && !GTEST_OS_WINDOWS_PHONE && !GTEST_OS_WINDOWS_RT // SetErrorMode doesn't exist on CE. SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOALIGNMENTFAULTEXCEPT | SEM_NOGPFAULTERRORBOX | SEM_NOOPENFILEERRORBOX); # endif // !GTEST_OS_WINDOWS_MOBILE # if (defined(_MSC_VER) || GTEST_OS_WINDOWS_MINGW) && !GTEST_OS_WINDOWS_MOBILE // Death test children can be terminated with _abort(). On Windows, // _abort() can show a dialog with a warning message. This forces the // abort message to go to stderr instead. _set_error_mode(_OUT_TO_STDERR); # endif # if defined(_MSC_VER) && !GTEST_OS_WINDOWS_MOBILE // In the debug version, Visual Studio pops up a separate dialog // offering a choice to debug the aborted program. We need to suppress // this dialog or it will pop up for every EXPECT/ASSERT_DEATH statement // executed. Google Test will notify the user of any unexpected // failure via stderr. 
if (!GTEST_FLAG(break_on_failure)) _set_abort_behavior( 0x0, // Clear the following flags: _WRITE_ABORT_MSG | _CALL_REPORTFAULT); // pop-up window, core dump. # endif // In debug mode, the Windows CRT can crash with an assertion over invalid // input (e.g. passing an invalid file descriptor). The default handling // for these assertions is to pop up a dialog and wait for user input. // Instead ask the CRT to dump such assertions to stderr non-interactively. if (!IsDebuggerPresent()) { (void)_CrtSetReportMode(_CRT_ASSERT, _CRTDBG_MODE_FILE | _CRTDBG_MODE_DEBUG); (void)_CrtSetReportFile(_CRT_ASSERT, _CRTDBG_FILE_STDERR); } } #endif // GTEST_OS_WINDOWS return internal::HandleExceptionsInMethodIfSupported( impl(), &internal::UnitTestImpl::RunAllTests, "auxiliary test code (environments or event listeners)") ? 0 : 1; } // Returns the working directory when the first TEST() or TEST_F() was // executed. const char* UnitTest::original_working_dir() const { return impl_->original_working_dir_.c_str(); } // Returns the TestSuite object for the test that's currently running, // or NULL if no test is running. const TestSuite* UnitTest::current_test_suite() const GTEST_LOCK_EXCLUDED_(mutex_) { internal::MutexLock lock(&mutex_); return impl_->current_test_suite(); } // Legacy API is still available but deprecated #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ const TestCase* UnitTest::current_test_case() const GTEST_LOCK_EXCLUDED_(mutex_) { internal::MutexLock lock(&mutex_); return impl_->current_test_suite(); } #endif // Returns the TestInfo object for the test that's currently running, // or NULL if no test is running. const TestInfo* UnitTest::current_test_info() const GTEST_LOCK_EXCLUDED_(mutex_) { internal::MutexLock lock(&mutex_); return impl_->current_test_info(); } // Returns the random seed used at the start of the current test run. 
int UnitTest::random_seed() const { return impl_->random_seed(); } // Returns ParameterizedTestSuiteRegistry object used to keep track of // value-parameterized tests and instantiate and register them. internal::ParameterizedTestSuiteRegistry& UnitTest::parameterized_test_registry() GTEST_LOCK_EXCLUDED_(mutex_) { return impl_->parameterized_test_registry(); } // Creates an empty UnitTest. UnitTest::UnitTest() { impl_ = new internal::UnitTestImpl(this); } // Destructor of UnitTest. UnitTest::~UnitTest() { delete impl_; } // Pushes a trace defined by SCOPED_TRACE() on to the per-thread // Google Test trace stack. void UnitTest::PushGTestTrace(const internal::TraceInfo& trace) GTEST_LOCK_EXCLUDED_(mutex_) { internal::MutexLock lock(&mutex_); impl_->gtest_trace_stack().push_back(trace); } // Pops a trace from the per-thread Google Test trace stack. void UnitTest::PopGTestTrace() GTEST_LOCK_EXCLUDED_(mutex_) { internal::MutexLock lock(&mutex_); impl_->gtest_trace_stack().pop_back(); } namespace internal { UnitTestImpl::UnitTestImpl(UnitTest* parent) : parent_(parent), GTEST_DISABLE_MSC_WARNINGS_PUSH_(4355 /* using this in initializer */) default_global_test_part_result_reporter_(this), default_per_thread_test_part_result_reporter_(this), GTEST_DISABLE_MSC_WARNINGS_POP_() global_test_part_result_repoter_( &default_global_test_part_result_reporter_), per_thread_test_part_result_reporter_( &default_per_thread_test_part_result_reporter_), parameterized_test_registry_(), parameterized_tests_registered_(false), last_death_test_suite_(-1), current_test_suite_(nullptr), current_test_info_(nullptr), ad_hoc_test_result_(), os_stack_trace_getter_(nullptr), post_flag_parse_init_performed_(false), random_seed_(0), // Will be overridden by the flag before first use. random_(0), // Will be reseeded before first use. 
start_timestamp_(0), elapsed_time_(0), #if GTEST_HAS_DEATH_TEST death_test_factory_(new DefaultDeathTestFactory), #endif // Will be overridden by the flag before first use. catch_exceptions_(false) { listeners()->SetDefaultResultPrinter(new PrettyUnitTestResultPrinter); } UnitTestImpl::~UnitTestImpl() { // Deletes every TestSuite. ForEach(test_suites_, internal::Delete); // Deletes every Environment. ForEach(environments_, internal::Delete); delete os_stack_trace_getter_; } // Adds a TestProperty to the current TestResult object when invoked in a // context of a test, to current test suite's ad_hoc_test_result when invoke // from SetUpTestSuite/TearDownTestSuite, or to the global property set // otherwise. If the result already contains a property with the same key, // the value will be updated. void UnitTestImpl::RecordProperty(const TestProperty& test_property) { std::string xml_element; TestResult* test_result; // TestResult appropriate for property recording. if (current_test_info_ != nullptr) { xml_element = "testcase"; test_result = &(current_test_info_->result_); } else if (current_test_suite_ != nullptr) { xml_element = "testsuite"; test_result = &(current_test_suite_->ad_hoc_test_result_); } else { xml_element = "testsuites"; test_result = &ad_hoc_test_result_; } test_result->RecordProperty(xml_element, test_property); } #if GTEST_HAS_DEATH_TEST // Disables event forwarding if the control is currently in a death test // subprocess. Must not be called before InitGoogleTest. void UnitTestImpl::SuppressTestEventsIfInSubprocess() { if (internal_run_death_test_flag_.get() != nullptr) listeners()->SuppressEventForwarding(); } #endif // GTEST_HAS_DEATH_TEST // Initializes event listeners performing XML output as specified by // UnitTestOptions. Must not be called before InitGoogleTest. 
void UnitTestImpl::ConfigureXmlOutput() { const std::string& output_format = UnitTestOptions::GetOutputFormat(); if (output_format == "xml") { listeners()->SetDefaultXmlGenerator(new XmlUnitTestResultPrinter( UnitTestOptions::GetAbsolutePathToOutputFile().c_str())); } else if (output_format == "json") { listeners()->SetDefaultXmlGenerator(new JsonUnitTestResultPrinter( UnitTestOptions::GetAbsolutePathToOutputFile().c_str())); } else if (output_format != "") { GTEST_LOG_(WARNING) << "WARNING: unrecognized output format \"" << output_format << "\" ignored."; } } #if GTEST_CAN_STREAM_RESULTS_ // Initializes event listeners for streaming test results in string form. // Must not be called before InitGoogleTest. void UnitTestImpl::ConfigureStreamingOutput() { const std::string& target = GTEST_FLAG(stream_result_to); if (!target.empty()) { const size_t pos = target.find(':'); if (pos != std::string::npos) { listeners()->Append(new StreamingListener(target.substr(0, pos), target.substr(pos+1))); } else { GTEST_LOG_(WARNING) << "unrecognized streaming target \"" << target << "\" ignored."; } } } #endif // GTEST_CAN_STREAM_RESULTS_ // Performs initialization dependent upon flag values obtained in // ParseGoogleTestFlagsOnly. Is called from InitGoogleTest after the call to // ParseGoogleTestFlagsOnly. In case a user neglects to call InitGoogleTest // this function is also called from RunAllTests. Since this function can be // called more than once, it has to be idempotent. void UnitTestImpl::PostFlagParsingInit() { // Ensures that this function does not execute more than once. if (!post_flag_parse_init_performed_) { post_flag_parse_init_performed_ = true; #if defined(GTEST_CUSTOM_TEST_EVENT_LISTENER_) // Register to send notifications about key process state changes. 
listeners()->Append(new GTEST_CUSTOM_TEST_EVENT_LISTENER_()); #endif // defined(GTEST_CUSTOM_TEST_EVENT_LISTENER_) #if GTEST_HAS_DEATH_TEST InitDeathTestSubprocessControlInfo(); SuppressTestEventsIfInSubprocess(); #endif // GTEST_HAS_DEATH_TEST // Registers parameterized tests. This makes parameterized tests // available to the UnitTest reflection API without running // RUN_ALL_TESTS. RegisterParameterizedTests(); // Configures listeners for XML output. This makes it possible for users // to shut down the default XML output before invoking RUN_ALL_TESTS. ConfigureXmlOutput(); #if GTEST_CAN_STREAM_RESULTS_ // Configures listeners for streaming test results to the specified server. ConfigureStreamingOutput(); #endif // GTEST_CAN_STREAM_RESULTS_ #if GTEST_HAS_ABSL if (GTEST_FLAG(install_failure_signal_handler)) { absl::FailureSignalHandlerOptions options; absl::InstallFailureSignalHandler(options); } #endif // GTEST_HAS_ABSL } } // A predicate that checks the name of a TestSuite against a known // value. // // This is used for implementation of the UnitTest class only. We put // it in the anonymous namespace to prevent polluting the outer // namespace. // // TestSuiteNameIs is copyable. class TestSuiteNameIs { public: // Constructor. explicit TestSuiteNameIs(const std::string& name) : name_(name) {} // Returns true if and only if the name of test_suite matches name_. bool operator()(const TestSuite* test_suite) const { return test_suite != nullptr && strcmp(test_suite->name(), name_.c_str()) == 0; } private: std::string name_; }; // Finds and returns a TestSuite with the given name. If one doesn't // exist, creates one and returns it. It's the CALLER'S // RESPONSIBILITY to ensure that this function is only called WHEN THE // TESTS ARE NOT SHUFFLED. // // Arguments: // // test_suite_name: name of the test suite // type_param: the name of the test suite's type parameter, or NULL if // this is not a typed or a type-parameterized test suite. 
// set_up_tc: pointer to the function that sets up the test suite // tear_down_tc: pointer to the function that tears down the test suite TestSuite* UnitTestImpl::GetTestSuite( const char* test_suite_name, const char* type_param, internal::SetUpTestSuiteFunc set_up_tc, internal::TearDownTestSuiteFunc tear_down_tc) { // Can we find a TestSuite with the given name? const auto test_suite = std::find_if(test_suites_.rbegin(), test_suites_.rend(), TestSuiteNameIs(test_suite_name)); if (test_suite != test_suites_.rend()) return *test_suite; // No. Let's create one. auto* const new_test_suite = new TestSuite(test_suite_name, type_param, set_up_tc, tear_down_tc); // Is this a death test suite? if (internal::UnitTestOptions::MatchesFilter(test_suite_name, kDeathTestSuiteFilter)) { // Yes. Inserts the test suite after the last death test suite // defined so far. This only works when the test suites haven't // been shuffled. Otherwise we may end up running a death test // after a non-death test. ++last_death_test_suite_; test_suites_.insert(test_suites_.begin() + last_death_test_suite_, new_test_suite); } else { // No. Appends to the end of the list. test_suites_.push_back(new_test_suite); } test_suite_indices_.push_back(static_cast(test_suite_indices_.size())); return new_test_suite; } // Helpers for setting up / tearing down the given environment. They // are for use in the ForEach() function. static void SetUpEnvironment(Environment* env) { env->SetUp(); } static void TearDownEnvironment(Environment* env) { env->TearDown(); } // Runs all tests in this UnitTest object, prints the result, and // returns true if all tests are successful. If any exception is // thrown during a test, the test is considered to be failed, but the // rest of the tests will still be run. // // When parameterized tests are enabled, it expands and registers // parameterized tests first in RegisterParameterizedTests(). 
// All other functions called from RunAllTests() may safely assume that // parameterized tests are ready to be counted and run. bool UnitTestImpl::RunAllTests() { // True if and only if Google Test is initialized before RUN_ALL_TESTS() is // called. const bool gtest_is_initialized_before_run_all_tests = GTestIsInitialized(); // Do not run any test if the --help flag was specified. if (g_help_flag) return true; // Repeats the call to the post-flag parsing initialization in case the // user didn't call InitGoogleTest. PostFlagParsingInit(); // Even if sharding is not on, test runners may want to use the // GTEST_SHARD_STATUS_FILE to query whether the test supports the sharding // protocol. internal::WriteToShardStatusFileIfNeeded(); // True if and only if we are in a subprocess for running a thread-safe-style // death test. bool in_subprocess_for_death_test = false; #if GTEST_HAS_DEATH_TEST in_subprocess_for_death_test = (internal_run_death_test_flag_.get() != nullptr); # if defined(GTEST_EXTRA_DEATH_TEST_CHILD_SETUP_) if (in_subprocess_for_death_test) { GTEST_EXTRA_DEATH_TEST_CHILD_SETUP_(); } # endif // defined(GTEST_EXTRA_DEATH_TEST_CHILD_SETUP_) #endif // GTEST_HAS_DEATH_TEST const bool should_shard = ShouldShard(kTestTotalShards, kTestShardIndex, in_subprocess_for_death_test); // Compares the full test names with the filter to decide which // tests to run. const bool has_tests_to_run = FilterTests(should_shard ? HONOR_SHARDING_PROTOCOL : IGNORE_SHARDING_PROTOCOL) > 0; // Lists the tests and exits if the --gtest_list_tests flag was specified. if (GTEST_FLAG(list_tests)) { // This must be called *after* FilterTests() has been called. ListTestsMatchingFilter(); return true; } random_seed_ = GTEST_FLAG(shuffle) ? GetRandomSeedFromFlag(GTEST_FLAG(random_seed)) : 0; // True if and only if at least one test has failed. 
bool failed = false; TestEventListener* repeater = listeners()->repeater(); start_timestamp_ = GetTimeInMillis(); repeater->OnTestProgramStart(*parent_); // How many times to repeat the tests? We don't want to repeat them // when we are inside the subprocess of a death test. const int repeat = in_subprocess_for_death_test ? 1 : GTEST_FLAG(repeat); // Repeats forever if the repeat count is negative. const bool gtest_repeat_forever = repeat < 0; for (int i = 0; gtest_repeat_forever || i != repeat; i++) { // We want to preserve failures generated by ad-hoc test // assertions executed before RUN_ALL_TESTS(). ClearNonAdHocTestResult(); const TimeInMillis start = GetTimeInMillis(); // Shuffles test suites and tests if requested. if (has_tests_to_run && GTEST_FLAG(shuffle)) { random()->Reseed(static_cast(random_seed_)); // This should be done before calling OnTestIterationStart(), // such that a test event listener can see the actual test order // in the event. ShuffleTests(); } // Tells the unit test event listeners that the tests are about to start. repeater->OnTestIterationStart(*parent_, i); // Runs each test suite if there is at least one test to run. if (has_tests_to_run) { // Sets up all environments beforehand. repeater->OnEnvironmentsSetUpStart(*parent_); ForEach(environments_, SetUpEnvironment); repeater->OnEnvironmentsSetUpEnd(*parent_); // Runs the tests only if there was no fatal failure or skip triggered // during global set-up. if (Test::IsSkipped()) { // Emit diagnostics when global set-up calls skip, as it will not be // emitted by default. 
TestResult& test_result = *internal::GetUnitTestImpl()->current_test_result(); for (int j = 0; j < test_result.total_part_count(); ++j) { const TestPartResult& test_part_result = test_result.GetTestPartResult(j); if (test_part_result.type() == TestPartResult::kSkip) { const std::string& result = test_part_result.message(); printf("%s\n", result.c_str()); } } fflush(stdout); } else if (!Test::HasFatalFailure()) { for (int test_index = 0; test_index < total_test_suite_count(); test_index++) { GetMutableSuiteCase(test_index)->Run(); } } // Tears down all environments in reverse order afterwards. repeater->OnEnvironmentsTearDownStart(*parent_); std::for_each(environments_.rbegin(), environments_.rend(), TearDownEnvironment); repeater->OnEnvironmentsTearDownEnd(*parent_); } elapsed_time_ = GetTimeInMillis() - start; // Tells the unit test event listener that the tests have just finished. repeater->OnTestIterationEnd(*parent_, i); // Gets the result and clears it. if (!Passed()) { failed = true; } // Restores the original test order after the iteration. This // allows the user to quickly repro a failure that happens in the // N-th iteration without repeating the first (N - 1) iterations. // This is not enclosed in "if (GTEST_FLAG(shuffle)) { ... }", in // case the user somehow changes the value of the flag somewhere // (it's always safe to unshuffle the tests). UnshuffleTests(); if (GTEST_FLAG(shuffle)) { // Picks a new random seed for each iteration. random_seed_ = GetNextRandomSeed(random_seed_); } } repeater->OnTestProgramEnd(*parent_); if (!gtest_is_initialized_before_run_all_tests) { ColoredPrintf( COLOR_RED, "\nIMPORTANT NOTICE - DO NOT IGNORE:\n" "This test program did NOT call " GTEST_INIT_GOOGLE_TEST_NAME_ "() before calling RUN_ALL_TESTS(). This is INVALID. Soon " GTEST_NAME_ " will start to enforce the valid usage. 
" "Please fix it ASAP, or IT WILL START TO FAIL.\n"); // NOLINT #if GTEST_FOR_GOOGLE_ ColoredPrintf(COLOR_RED, "For more details, see http://wiki/Main/ValidGUnitMain.\n"); #endif // GTEST_FOR_GOOGLE_ } return !failed; } // Reads the GTEST_SHARD_STATUS_FILE environment variable, and creates the file // if the variable is present. If a file already exists at this location, this // function will write over it. If the variable is present, but the file cannot // be created, prints an error and exits. void WriteToShardStatusFileIfNeeded() { const char* const test_shard_file = posix::GetEnv(kTestShardStatusFile); if (test_shard_file != nullptr) { FILE* const file = posix::FOpen(test_shard_file, "w"); if (file == nullptr) { ColoredPrintf(COLOR_RED, "Could not write to the test shard status file \"%s\" " "specified by the %s environment variable.\n", test_shard_file, kTestShardStatusFile); fflush(stdout); exit(EXIT_FAILURE); } fclose(file); } } // Checks whether sharding is enabled by examining the relevant // environment variable values. If the variables are present, // but inconsistent (i.e., shard_index >= total_shards), prints // an error and exits. If in_subprocess_for_death_test, sharding is // disabled because it must only be applied to the original test // process. Otherwise, we could filter out death tests we intended to execute. 
bool ShouldShard(const char* total_shards_env, const char* shard_index_env, bool in_subprocess_for_death_test) { if (in_subprocess_for_death_test) { return false; } const Int32 total_shards = Int32FromEnvOrDie(total_shards_env, -1); const Int32 shard_index = Int32FromEnvOrDie(shard_index_env, -1); if (total_shards == -1 && shard_index == -1) { return false; } else if (total_shards == -1 && shard_index != -1) { const Message msg = Message() << "Invalid environment variables: you have " << kTestShardIndex << " = " << shard_index << ", but have left " << kTestTotalShards << " unset.\n"; ColoredPrintf(COLOR_RED, "%s", msg.GetString().c_str()); fflush(stdout); exit(EXIT_FAILURE); } else if (total_shards != -1 && shard_index == -1) { const Message msg = Message() << "Invalid environment variables: you have " << kTestTotalShards << " = " << total_shards << ", but have left " << kTestShardIndex << " unset.\n"; ColoredPrintf(COLOR_RED, "%s", msg.GetString().c_str()); fflush(stdout); exit(EXIT_FAILURE); } else if (shard_index < 0 || shard_index >= total_shards) { const Message msg = Message() << "Invalid environment variables: we require 0 <= " << kTestShardIndex << " < " << kTestTotalShards << ", but you have " << kTestShardIndex << "=" << shard_index << ", " << kTestTotalShards << "=" << total_shards << ".\n"; ColoredPrintf(COLOR_RED, "%s", msg.GetString().c_str()); fflush(stdout); exit(EXIT_FAILURE); } return total_shards > 1; } // Parses the environment variable var as an Int32. If it is unset, // returns default_val. If it is not an Int32, prints an error // and aborts. 
Int32 Int32FromEnvOrDie(const char* var, Int32 default_val) { const char* str_val = posix::GetEnv(var); if (str_val == nullptr) { return default_val; } Int32 result; if (!ParseInt32(Message() << "The value of environment variable " << var, str_val, &result)) { exit(EXIT_FAILURE); } return result; } // Given the total number of shards, the shard index, and the test id, // returns true if and only if the test should be run on this shard. The test id // is some arbitrary but unique non-negative integer assigned to each test // method. Assumes that 0 <= shard_index < total_shards. bool ShouldRunTestOnShard(int total_shards, int shard_index, int test_id) { return (test_id % total_shards) == shard_index; } // Compares the name of each test with the user-specified filter to // decide whether the test should be run, then records the result in // each TestSuite and TestInfo object. // If shard_tests == true, further filters tests based on sharding // variables in the environment - see // https://github.com/google/googletest/blob/master/googletest/docs/advanced.md // . Returns the number of tests that should run. int UnitTestImpl::FilterTests(ReactionToSharding shard_tests) { const Int32 total_shards = shard_tests == HONOR_SHARDING_PROTOCOL ? Int32FromEnvOrDie(kTestTotalShards, -1) : -1; const Int32 shard_index = shard_tests == HONOR_SHARDING_PROTOCOL ? Int32FromEnvOrDie(kTestShardIndex, -1) : -1; // num_runnable_tests are the number of tests that will // run across all shards (i.e., match filter and are not disabled). // num_selected_tests are the number of tests to be run on // this shard. 
int num_runnable_tests = 0; int num_selected_tests = 0; for (auto* test_suite : test_suites_) { const std::string& test_suite_name = test_suite->name(); test_suite->set_should_run(false); for (size_t j = 0; j < test_suite->test_info_list().size(); j++) { TestInfo* const test_info = test_suite->test_info_list()[j]; const std::string test_name(test_info->name()); // A test is disabled if test suite name or test name matches // kDisableTestFilter. const bool is_disabled = internal::UnitTestOptions::MatchesFilter( test_suite_name, kDisableTestFilter) || internal::UnitTestOptions::MatchesFilter( test_name, kDisableTestFilter); test_info->is_disabled_ = is_disabled; const bool matches_filter = internal::UnitTestOptions::FilterMatchesTest( test_suite_name, test_name); test_info->matches_filter_ = matches_filter; const bool is_runnable = (GTEST_FLAG(also_run_disabled_tests) || !is_disabled) && matches_filter; const bool is_in_another_shard = shard_tests != IGNORE_SHARDING_PROTOCOL && !ShouldRunTestOnShard(total_shards, shard_index, num_runnable_tests); test_info->is_in_another_shard_ = is_in_another_shard; const bool is_selected = is_runnable && !is_in_another_shard; num_runnable_tests += is_runnable; num_selected_tests += is_selected; test_info->should_run_ = is_selected; test_suite->set_should_run(test_suite->should_run() || is_selected); } } return num_selected_tests; } // Prints the given C-string on a single line by replacing all '\n' // characters with string "\\n". If the output takes more than // max_length characters, only prints the first max_length characters // and "...". static void PrintOnOneLine(const char* str, int max_length) { if (str != nullptr) { for (int i = 0; *str != '\0'; ++str) { if (i >= max_length) { printf("..."); break; } if (*str == '\n') { printf("\\n"); i += 2; } else { printf("%c", *str); ++i; } } } } // Prints the names of the tests matching the user-specified filter flag. 
void UnitTestImpl::ListTestsMatchingFilter() { // Print at most this many characters for each type/value parameter. const int kMaxParamLength = 250; for (auto* test_suite : test_suites_) { bool printed_test_suite_name = false; for (size_t j = 0; j < test_suite->test_info_list().size(); j++) { const TestInfo* const test_info = test_suite->test_info_list()[j]; if (test_info->matches_filter_) { if (!printed_test_suite_name) { printed_test_suite_name = true; printf("%s.", test_suite->name()); if (test_suite->type_param() != nullptr) { printf(" # %s = ", kTypeParamLabel); // We print the type parameter on a single line to make // the output easy to parse by a program. PrintOnOneLine(test_suite->type_param(), kMaxParamLength); } printf("\n"); } printf(" %s", test_info->name()); if (test_info->value_param() != nullptr) { printf(" # %s = ", kValueParamLabel); // We print the value parameter on a single line to make the // output easy to parse by a program. PrintOnOneLine(test_info->value_param(), kMaxParamLength); } printf("\n"); } } } fflush(stdout); const std::string& output_format = UnitTestOptions::GetOutputFormat(); if (output_format == "xml" || output_format == "json") { FILE* fileout = OpenFileForWriting( UnitTestOptions::GetAbsolutePathToOutputFile().c_str()); std::stringstream stream; if (output_format == "xml") { XmlUnitTestResultPrinter( UnitTestOptions::GetAbsolutePathToOutputFile().c_str()) .PrintXmlTestsList(&stream, test_suites_); } else if (output_format == "json") { JsonUnitTestResultPrinter( UnitTestOptions::GetAbsolutePathToOutputFile().c_str()) .PrintJsonTestList(&stream, test_suites_); } fprintf(fileout, "%s", StringStreamToString(&stream).c_str()); fclose(fileout); } } // Sets the OS stack trace getter. // // Does nothing if the input and the current OS stack trace getter are // the same; otherwise, deletes the old getter and makes the input the // current getter. 
void UnitTestImpl::set_os_stack_trace_getter( OsStackTraceGetterInterface* getter) { if (os_stack_trace_getter_ != getter) { delete os_stack_trace_getter_; os_stack_trace_getter_ = getter; } } // Returns the current OS stack trace getter if it is not NULL; // otherwise, creates an OsStackTraceGetter, makes it the current // getter, and returns it. OsStackTraceGetterInterface* UnitTestImpl::os_stack_trace_getter() { if (os_stack_trace_getter_ == nullptr) { #ifdef GTEST_OS_STACK_TRACE_GETTER_ os_stack_trace_getter_ = new GTEST_OS_STACK_TRACE_GETTER_; #else os_stack_trace_getter_ = new OsStackTraceGetter; #endif // GTEST_OS_STACK_TRACE_GETTER_ } return os_stack_trace_getter_; } // Returns the most specific TestResult currently running. TestResult* UnitTestImpl::current_test_result() { if (current_test_info_ != nullptr) { return ¤t_test_info_->result_; } if (current_test_suite_ != nullptr) { return ¤t_test_suite_->ad_hoc_test_result_; } return &ad_hoc_test_result_; } // Shuffles all test suites, and the tests within each test suite, // making sure that death tests are still run first. void UnitTestImpl::ShuffleTests() { // Shuffles the death test suites. ShuffleRange(random(), 0, last_death_test_suite_ + 1, &test_suite_indices_); // Shuffles the non-death test suites. ShuffleRange(random(), last_death_test_suite_ + 1, static_cast(test_suites_.size()), &test_suite_indices_); // Shuffles the tests inside each test suite. for (auto& test_suite : test_suites_) { test_suite->ShuffleTests(random()); } } // Restores the test suites and tests to their order before the first shuffle. void UnitTestImpl::UnshuffleTests() { for (size_t i = 0; i < test_suites_.size(); i++) { // Unshuffles the tests in each test suite. test_suites_[i]->UnshuffleTests(); // Resets the index of each test suite. test_suite_indices_[i] = static_cast(i); } } // Returns the current OS stack trace as an std::string. 
// // The maximum number of stack frames to be included is specified by // the gtest_stack_trace_depth flag. The skip_count parameter // specifies the number of top frames to be skipped, which doesn't // count against the number of frames to be included. // // For example, if Foo() calls Bar(), which in turn calls // GetCurrentOsStackTraceExceptTop(..., 1), Foo() will be included in // the trace but Bar() and GetCurrentOsStackTraceExceptTop() won't. std::string GetCurrentOsStackTraceExceptTop(UnitTest* /*unit_test*/, int skip_count) { // We pass skip_count + 1 to skip this wrapper function in addition // to what the user really wants to skip. return GetUnitTestImpl()->CurrentOsStackTraceExceptTop(skip_count + 1); } // Used by the GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_ macro to // suppress unreachable code warnings. namespace { class ClassUniqueToAlwaysTrue {}; } bool IsTrue(bool condition) { return condition; } bool AlwaysTrue() { #if GTEST_HAS_EXCEPTIONS // This condition is always false so AlwaysTrue() never actually throws, // but it makes the compiler think that it may throw. if (IsTrue(false)) throw ClassUniqueToAlwaysTrue(); #endif // GTEST_HAS_EXCEPTIONS return true; } // If *pstr starts with the given prefix, modifies *pstr to be right // past the prefix and returns true; otherwise leaves *pstr unchanged // and returns false. None of pstr, *pstr, and prefix can be NULL. bool SkipPrefix(const char* prefix, const char** pstr) { const size_t prefix_len = strlen(prefix); if (strncmp(*pstr, prefix, prefix_len) == 0) { *pstr += prefix_len; return true; } return false; } // Parses a string as a command line flag. The string should have // the format "--flag=value". When def_optional is true, the "=value" // part can be omitted. // // Returns the value of the flag, or NULL if the parsing failed. static const char* ParseFlagValue(const char* str, const char* flag, bool def_optional) { // str and flag must not be NULL. 
if (str == nullptr || flag == nullptr) return nullptr; // The flag must start with "--" followed by GTEST_FLAG_PREFIX_. const std::string flag_str = std::string("--") + GTEST_FLAG_PREFIX_ + flag; const size_t flag_len = flag_str.length(); if (strncmp(str, flag_str.c_str(), flag_len) != 0) return nullptr; // Skips the flag name. const char* flag_end = str + flag_len; // When def_optional is true, it's OK to not have a "=value" part. if (def_optional && (flag_end[0] == '\0')) { return flag_end; } // If def_optional is true and there are more characters after the // flag name, or if def_optional is false, there must be a '=' after // the flag name. if (flag_end[0] != '=') return nullptr; // Returns the string after "=". return flag_end + 1; } // Parses a string for a bool flag, in the form of either // "--flag=value" or "--flag". // // In the former case, the value is taken as true as long as it does // not start with '0', 'f', or 'F'. // // In the latter case, the value is taken as true. // // On success, stores the value of the flag in *value, and returns // true. On failure, returns false without changing *value. static bool ParseBoolFlag(const char* str, const char* flag, bool* value) { // Gets the value of the flag as a string. const char* const value_str = ParseFlagValue(str, flag, true); // Aborts if the parsing failed. if (value_str == nullptr) return false; // Converts the string value to a bool. *value = !(*value_str == '0' || *value_str == 'f' || *value_str == 'F'); return true; } // Parses a string for an Int32 flag, in the form of // "--flag=value". // // On success, stores the value of the flag in *value, and returns // true. On failure, returns false without changing *value. bool ParseInt32Flag(const char* str, const char* flag, Int32* value) { // Gets the value of the flag as a string. const char* const value_str = ParseFlagValue(str, flag, false); // Aborts if the parsing failed. 
if (value_str == nullptr) return false; // Sets *value to the value of the flag. return ParseInt32(Message() << "The value of flag --" << flag, value_str, value); } // Parses a string for a string flag, in the form of // "--flag=value". // // On success, stores the value of the flag in *value, and returns // true. On failure, returns false without changing *value. template static bool ParseStringFlag(const char* str, const char* flag, String* value) { // Gets the value of the flag as a string. const char* const value_str = ParseFlagValue(str, flag, false); // Aborts if the parsing failed. if (value_str == nullptr) return false; // Sets *value to the value of the flag. *value = value_str; return true; } // Determines whether a string has a prefix that Google Test uses for its // flags, i.e., starts with GTEST_FLAG_PREFIX_ or GTEST_FLAG_PREFIX_DASH_. // If Google Test detects that a command line flag has its prefix but is not // recognized, it will print its help message. Flags starting with // GTEST_INTERNAL_PREFIX_ followed by "internal_" are considered Google Test // internal flags and do not trigger the help message. static bool HasGoogleTestFlagPrefix(const char* str) { return (SkipPrefix("--", &str) || SkipPrefix("-", &str) || SkipPrefix("/", &str)) && !SkipPrefix(GTEST_FLAG_PREFIX_ "internal_", &str) && (SkipPrefix(GTEST_FLAG_PREFIX_, &str) || SkipPrefix(GTEST_FLAG_PREFIX_DASH_, &str)); } // Prints a string containing code-encoded text. The following escape // sequences can be used in the string to control the text color: // // @@ prints a single '@' character. // @R changes the color to red. // @G changes the color to green. // @Y changes the color to yellow. // @D changes to the default terminal text color. // static void PrintColorEncoded(const char* str) { GTestColor color = COLOR_DEFAULT; // The current color. // Conceptually, we split the string into segments divided by escape // sequences. Then we print one segment at a time. 
At the end of // each iteration, the str pointer advances to the beginning of the // next segment. for (;;) { const char* p = strchr(str, '@'); if (p == nullptr) { ColoredPrintf(color, "%s", str); return; } ColoredPrintf(color, "%s", std::string(str, p).c_str()); const char ch = p[1]; str = p + 2; if (ch == '@') { ColoredPrintf(color, "@"); } else if (ch == 'D') { color = COLOR_DEFAULT; } else if (ch == 'R') { color = COLOR_RED; } else if (ch == 'G') { color = COLOR_GREEN; } else if (ch == 'Y') { color = COLOR_YELLOW; } else { --str; } } } static const char kColorEncodedHelpMessage[] = "This program contains tests written using " GTEST_NAME_ ". You can use the\n" "following command line flags to control its behavior:\n" "\n" "Test Selection:\n" " @G--" GTEST_FLAG_PREFIX_ "list_tests@D\n" " List the names of all tests instead of running them. The name of\n" " TEST(Foo, Bar) is \"Foo.Bar\".\n" " @G--" GTEST_FLAG_PREFIX_ "filter=@YPOSTIVE_PATTERNS" "[@G-@YNEGATIVE_PATTERNS]@D\n" " Run only the tests whose name matches one of the positive patterns but\n" " none of the negative patterns. '?' matches any single character; '*'\n" " matches any substring; ':' separates two patterns.\n" " @G--" GTEST_FLAG_PREFIX_ "also_run_disabled_tests@D\n" " Run all disabled tests too.\n" "\n" "Test Execution:\n" " @G--" GTEST_FLAG_PREFIX_ "repeat=@Y[COUNT]@D\n" " Run the tests repeatedly; use a negative count to repeat forever.\n" " @G--" GTEST_FLAG_PREFIX_ "shuffle@D\n" " Randomize tests' orders on every iteration.\n" " @G--" GTEST_FLAG_PREFIX_ "random_seed=@Y[NUMBER]@D\n" " Random number seed to use for shuffling test orders (between 1 and\n" " 99999, or 0 to use a seed based on the current time).\n" "\n" "Test Output:\n" " @G--" GTEST_FLAG_PREFIX_ "color=@Y(@Gyes@Y|@Gno@Y|@Gauto@Y)@D\n" " Enable/disable colored output. 
The default is @Gauto@D.\n" " -@G-" GTEST_FLAG_PREFIX_ "print_time=0@D\n" " Don't print the elapsed time of each test.\n" " @G--" GTEST_FLAG_PREFIX_ "output=@Y(@Gjson@Y|@Gxml@Y)[@G:@YDIRECTORY_PATH@G" GTEST_PATH_SEP_ "@Y|@G:@YFILE_PATH]@D\n" " Generate a JSON or XML report in the given directory or with the given\n" " file name. @YFILE_PATH@D defaults to @Gtest_detail.xml@D.\n" # if GTEST_CAN_STREAM_RESULTS_ " @G--" GTEST_FLAG_PREFIX_ "stream_result_to=@YHOST@G:@YPORT@D\n" " Stream test results to the given server.\n" # endif // GTEST_CAN_STREAM_RESULTS_ " @G--" GTEST_FLAG_PREFIX_ "print_skipped@D\n" " List all the skipped tests names in the summary\n" "\n" "Assertion Behavior:\n" # if GTEST_HAS_DEATH_TEST && !GTEST_OS_WINDOWS " @G--" GTEST_FLAG_PREFIX_ "death_test_style=@Y(@Gfast@Y|@Gthreadsafe@Y)@D\n" " Set the default death test style.\n" # endif // GTEST_HAS_DEATH_TEST && !GTEST_OS_WINDOWS " @G--" GTEST_FLAG_PREFIX_ "break_on_failure@D\n" " Turn assertion failures into debugger break-points.\n" " @G--" GTEST_FLAG_PREFIX_ "throw_on_failure@D\n" " Turn assertion failures into C++ exceptions for use by an external\n" " test framework.\n" " @G--" GTEST_FLAG_PREFIX_ "catch_exceptions=0@D\n" " Do not report exceptions as test failures. Instead, allow them\n" " to crash the program or throw a pop-up (on Windows).\n" "\n" "Except for @G--" GTEST_FLAG_PREFIX_ "list_tests@D, you can alternatively set " "the corresponding\n" "environment variable of a flag (all letters in upper-case). For example, to\n" "disable colored text output, you can either specify @G--" GTEST_FLAG_PREFIX_ "color=no@D or set\n" "the @G" GTEST_FLAG_PREFIX_UPPER_ "COLOR@D environment variable to @Gno@D.\n" "\n" "For more information, please read the " GTEST_NAME_ " documentation at\n" "@G" GTEST_PROJECT_URL_ "@D. 
If you find a bug in " GTEST_NAME_ "\n" "(not one in your own code or tests), please report it to\n" "@G<" GTEST_DEV_EMAIL_ ">@D.\n"; static bool ParseGoogleTestFlag(const char* const arg) { return ParseBoolFlag(arg, kAlsoRunDisabledTestsFlag, >EST_FLAG(also_run_disabled_tests)) || ParseBoolFlag(arg, kBreakOnFailureFlag, >EST_FLAG(break_on_failure)) || ParseBoolFlag(arg, kCatchExceptionsFlag, >EST_FLAG(catch_exceptions)) || ParseStringFlag(arg, kColorFlag, >EST_FLAG(color)) || ParseStringFlag(arg, kDeathTestStyleFlag, >EST_FLAG(death_test_style)) || ParseBoolFlag(arg, kDeathTestUseFork, >EST_FLAG(death_test_use_fork)) || ParseStringFlag(arg, kFilterFlag, >EST_FLAG(filter)) || ParseStringFlag(arg, kInternalRunDeathTestFlag, >EST_FLAG(internal_run_death_test)) || ParseBoolFlag(arg, kListTestsFlag, >EST_FLAG(list_tests)) || ParseStringFlag(arg, kOutputFlag, >EST_FLAG(output)) || ParseBoolFlag(arg, kPrintTimeFlag, >EST_FLAG(print_time)) || ParseBoolFlag(arg, kPrintUTF8Flag, >EST_FLAG(print_utf8)) || ParseInt32Flag(arg, kRandomSeedFlag, >EST_FLAG(random_seed)) || ParseInt32Flag(arg, kRepeatFlag, >EST_FLAG(repeat)) || ParseBoolFlag(arg, kShuffleFlag, >EST_FLAG(shuffle)) || ParseInt32Flag(arg, kStackTraceDepthFlag, >EST_FLAG(stack_trace_depth)) || ParseStringFlag(arg, kStreamResultToFlag, >EST_FLAG(stream_result_to)) || ParseBoolFlag(arg, kThrowOnFailureFlag, >EST_FLAG(throw_on_failure)) || ParseBoolFlag(arg, kPrintSkippedFlag, >EST_FLAG(print_skipped)); } #if GTEST_USE_OWN_FLAGFILE_FLAG_ static void LoadFlagsFromFile(const std::string& path) { FILE* flagfile = posix::FOpen(path.c_str(), "r"); if (!flagfile) { GTEST_LOG_(FATAL) << "Unable to open file \"" << GTEST_FLAG(flagfile) << "\""; } std::string contents(ReadEntireFile(flagfile)); posix::FClose(flagfile); std::vector lines; SplitString(contents, '\n', &lines); for (size_t i = 0; i < lines.size(); ++i) { if (lines[i].empty()) continue; if (!ParseGoogleTestFlag(lines[i].c_str())) g_help_flag = true; } } #endif // 
GTEST_USE_OWN_FLAGFILE_FLAG_ // Parses the command line for Google Test flags, without initializing // other parts of Google Test. The type parameter CharType can be // instantiated to either char or wchar_t. template void ParseGoogleTestFlagsOnlyImpl(int* argc, CharType** argv) { for (int i = 1; i < *argc; i++) { const std::string arg_string = StreamableToString(argv[i]); const char* const arg = arg_string.c_str(); using internal::ParseBoolFlag; using internal::ParseInt32Flag; using internal::ParseStringFlag; bool remove_flag = false; if (ParseGoogleTestFlag(arg)) { remove_flag = true; #if GTEST_USE_OWN_FLAGFILE_FLAG_ } else if (ParseStringFlag(arg, kFlagfileFlag, >EST_FLAG(flagfile))) { LoadFlagsFromFile(GTEST_FLAG(flagfile)); remove_flag = true; #endif // GTEST_USE_OWN_FLAGFILE_FLAG_ } else if (arg_string == "--help" || arg_string == "-h" || arg_string == "-?" || arg_string == "/?" || HasGoogleTestFlagPrefix(arg)) { // Both help flag and unrecognized Google Test flags (excluding // internal ones) trigger help display. g_help_flag = true; } if (remove_flag) { // Shift the remainder of the argv list left by one. Note // that argv has (*argc + 1) elements, the last one always being // NULL. The following loop moves the trailing NULL element as // well. for (int j = i; j != *argc; j++) { argv[j] = argv[j + 1]; } // Decrements the argument count. (*argc)--; // We also need to decrement the iterator as we just removed // an element. i--; } } if (g_help_flag) { // We print the help here instead of in RUN_ALL_TESTS(), as the // latter may not be called at all if the user is using Google // Test with another testing framework. PrintColorEncoded(kColorEncodedHelpMessage); } } // Parses the command line for Google Test flags, without initializing // other parts of Google Test. 
void ParseGoogleTestFlagsOnly(int* argc, char** argv) { ParseGoogleTestFlagsOnlyImpl(argc, argv); // Fix the value of *_NSGetArgc() on macOS, but if and only if // *_NSGetArgv() == argv // Only applicable to char** version of argv #if GTEST_OS_MAC #ifndef GTEST_OS_IOS if (*_NSGetArgv() == argv) { *_NSGetArgc() = *argc; } #endif #endif } void ParseGoogleTestFlagsOnly(int* argc, wchar_t** argv) { ParseGoogleTestFlagsOnlyImpl(argc, argv); } // The internal implementation of InitGoogleTest(). // // The type parameter CharType can be instantiated to either char or // wchar_t. template void InitGoogleTestImpl(int* argc, CharType** argv) { // We don't want to run the initialization code twice. if (GTestIsInitialized()) return; if (*argc <= 0) return; g_argvs.clear(); for (int i = 0; i != *argc; i++) { g_argvs.push_back(StreamableToString(argv[i])); } #if GTEST_HAS_ABSL absl::InitializeSymbolizer(g_argvs[0].c_str()); #endif // GTEST_HAS_ABSL ParseGoogleTestFlagsOnly(argc, argv); GetUnitTestImpl()->PostFlagParsingInit(); } } // namespace internal // Initializes Google Test. This must be called before calling // RUN_ALL_TESTS(). In particular, it parses a command line for the // flags that Google Test recognizes. Whenever a Google Test flag is // seen, it is removed from argv, and *argc is decremented. // // No value is returned. Instead, the Google Test flag variables are // updated. // // Calling the function for the second time has no user-visible effect. void InitGoogleTest(int* argc, char** argv) { #if defined(GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_) GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_(argc, argv); #else // defined(GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_) internal::InitGoogleTestImpl(argc, argv); #endif // defined(GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_) } // This overloaded version can be used in Windows programs compiled in // UNICODE mode. 
void InitGoogleTest(int* argc, wchar_t** argv) { #if defined(GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_) GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_(argc, argv); #else // defined(GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_) internal::InitGoogleTestImpl(argc, argv); #endif // defined(GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_) } // This overloaded version can be used on Arduino/embedded platforms where // there is no argc/argv. void InitGoogleTest() { // Since Arduino doesn't have a command line, fake out the argc/argv arguments int argc = 1; const auto arg0 = "dummy"; char* argv0 = const_cast(arg0); char** argv = &argv0; #if defined(GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_) GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_(&argc, argv); #else // defined(GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_) internal::InitGoogleTestImpl(&argc, argv); #endif // defined(GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_) } std::string TempDir() { #if defined(GTEST_CUSTOM_TEMPDIR_FUNCTION_) return GTEST_CUSTOM_TEMPDIR_FUNCTION_(); #endif #if GTEST_OS_WINDOWS_MOBILE return "\\temp\\"; #elif GTEST_OS_WINDOWS const char* temp_dir = internal::posix::GetEnv("TEMP"); if (temp_dir == nullptr || temp_dir[0] == '\0') return "\\temp\\"; else if (temp_dir[strlen(temp_dir) - 1] == '\\') return temp_dir; else return std::string(temp_dir) + "\\"; #elif GTEST_OS_LINUX_ANDROID return "/sdcard/"; #else return "/tmp/"; #endif // GTEST_OS_WINDOWS_MOBILE } // Class ScopedTrace // Pushes the given source file location and message onto a per-thread // trace stack maintained by Google Test. void ScopedTrace::PushTrace(const char* file, int line, std::string message) { internal::TraceInfo trace; trace.file = file; trace.line = line; trace.message.swap(message); UnitTest::GetInstance()->PushGTestTrace(trace); } // Pops the info pushed by the c'tor. ScopedTrace::~ScopedTrace() GTEST_LOCK_EXCLUDED_(&UnitTest::mutex_) { UnitTest::GetInstance()->PopGTestTrace(); } } // namespace testing // Copyright 2005, Google Inc. // All rights reserved. 
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // // This file implements death tests. 
#include #if GTEST_HAS_DEATH_TEST # if GTEST_OS_MAC # include # endif // GTEST_OS_MAC # include # include # include # if GTEST_OS_LINUX # include # endif // GTEST_OS_LINUX # include # if GTEST_OS_WINDOWS # include # else # include # include # endif // GTEST_OS_WINDOWS # if GTEST_OS_QNX # include # endif // GTEST_OS_QNX # if GTEST_OS_FUCHSIA # include # include # include # include # include # include # include # include # include # include # include # endif // GTEST_OS_FUCHSIA #endif // GTEST_HAS_DEATH_TEST namespace testing { // Constants. // The default death test style. // // This is defined in internal/gtest-port.h as "fast", but can be overridden by // a definition in internal/custom/gtest-port.h. The recommended value, which is // used internally at Google, is "threadsafe". static const char kDefaultDeathTestStyle[] = GTEST_DEFAULT_DEATH_TEST_STYLE; GTEST_DEFINE_string_( death_test_style, internal::StringFromGTestEnv("death_test_style", kDefaultDeathTestStyle), "Indicates how to run a death test in a forked child process: " "\"threadsafe\" (child process re-executes the test binary " "from the beginning, running only the specific death test) or " "\"fast\" (child process runs the death test immediately " "after forking)."); GTEST_DEFINE_bool_( death_test_use_fork, internal::BoolFromGTestEnv("death_test_use_fork", false), "Instructs to use fork()/_exit() instead of clone() in death tests. " "Ignored and always uses fork() on POSIX systems where clone() is not " "implemented. Useful when running under valgrind or similar tools if " "those do not support clone(). Valgrind 3.3.1 will just fail if " "it sees an unsupported combination of clone() flags. " "It is not recommended to use this flag w/o valgrind though it will " "work in 99% of the cases. 
Once valgrind is fixed, this flag will " "most likely be removed."); namespace internal { GTEST_DEFINE_string_( internal_run_death_test, "", "Indicates the file, line number, temporal index of " "the single death test to run, and a file descriptor to " "which a success code may be sent, all separated by " "the '|' characters. This flag is specified if and only if the " "current process is a sub-process launched for running a thread-safe " "death test. FOR INTERNAL USE ONLY."); } // namespace internal #if GTEST_HAS_DEATH_TEST namespace internal { // Valid only for fast death tests. Indicates the code is running in the // child process of a fast style death test. # if !GTEST_OS_WINDOWS && !GTEST_OS_FUCHSIA static bool g_in_fast_death_test_child = false; # endif // Returns a Boolean value indicating whether the caller is currently // executing in the context of the death test child process. Tools such as // Valgrind heap checkers may need this to modify their behavior in death // tests. IMPORTANT: This is an internal utility. Using it may break the // implementation of death tests. User code MUST NOT use it. bool InDeathTestChild() { # if GTEST_OS_WINDOWS || GTEST_OS_FUCHSIA // On Windows and Fuchsia, death tests are thread-safe regardless of the value // of the death_test_style flag. return !GTEST_FLAG(internal_run_death_test).empty(); # else if (GTEST_FLAG(death_test_style) == "threadsafe") return !GTEST_FLAG(internal_run_death_test).empty(); else return g_in_fast_death_test_child; #endif } } // namespace internal // ExitedWithCode constructor. ExitedWithCode::ExitedWithCode(int exit_code) : exit_code_(exit_code) { } // ExitedWithCode function-call operator. 
bool ExitedWithCode::operator()(int exit_status) const { # if GTEST_OS_WINDOWS || GTEST_OS_FUCHSIA return exit_status == exit_code_; # else return WIFEXITED(exit_status) && WEXITSTATUS(exit_status) == exit_code_; # endif // GTEST_OS_WINDOWS || GTEST_OS_FUCHSIA } # if !GTEST_OS_WINDOWS && !GTEST_OS_FUCHSIA // KilledBySignal constructor. KilledBySignal::KilledBySignal(int signum) : signum_(signum) { } // KilledBySignal function-call operator. bool KilledBySignal::operator()(int exit_status) const { # if defined(GTEST_KILLED_BY_SIGNAL_OVERRIDE_) { bool result; if (GTEST_KILLED_BY_SIGNAL_OVERRIDE_(signum_, exit_status, &result)) { return result; } } # endif // defined(GTEST_KILLED_BY_SIGNAL_OVERRIDE_) return WIFSIGNALED(exit_status) && WTERMSIG(exit_status) == signum_; } # endif // !GTEST_OS_WINDOWS && !GTEST_OS_FUCHSIA namespace internal { // Utilities needed for death tests. // Generates a textual description of a given exit code, in the format // specified by wait(2). static std::string ExitSummary(int exit_code) { Message m; # if GTEST_OS_WINDOWS || GTEST_OS_FUCHSIA m << "Exited with exit status " << exit_code; # else if (WIFEXITED(exit_code)) { m << "Exited with exit status " << WEXITSTATUS(exit_code); } else if (WIFSIGNALED(exit_code)) { m << "Terminated by signal " << WTERMSIG(exit_code); } # ifdef WCOREDUMP if (WCOREDUMP(exit_code)) { m << " (core dumped)"; } # endif # endif // GTEST_OS_WINDOWS || GTEST_OS_FUCHSIA return m.GetString(); } // Returns true if exit_status describes a process that was terminated // by a signal, or exited normally with a nonzero exit code. bool ExitedUnsuccessfully(int exit_status) { return !ExitedWithCode(0)(exit_status); } # if !GTEST_OS_WINDOWS && !GTEST_OS_FUCHSIA // Generates a textual failure message when a death test finds more than // one thread running, or cannot determine the number of threads, prior // to executing the given statement. It is the responsibility of the // caller not to pass a thread_count of 1. 
static std::string DeathTestThreadWarning(size_t thread_count) { Message msg; msg << "Death tests use fork(), which is unsafe particularly" << " in a threaded context. For this test, " << GTEST_NAME_ << " "; if (thread_count == 0) { msg << "couldn't detect the number of threads."; } else { msg << "detected " << thread_count << " threads."; } msg << " See " "https://github.com/google/googletest/blob/master/googletest/docs/" "advanced.md#death-tests-and-threads" << " for more explanation and suggested solutions, especially if" << " this is the last message you see before your test times out."; return msg.GetString(); } # endif // !GTEST_OS_WINDOWS && !GTEST_OS_FUCHSIA // Flag characters for reporting a death test that did not die. static const char kDeathTestLived = 'L'; static const char kDeathTestReturned = 'R'; static const char kDeathTestThrew = 'T'; static const char kDeathTestInternalError = 'I'; #if GTEST_OS_FUCHSIA // File descriptor used for the pipe in the child process. static const int kFuchsiaReadPipeFd = 3; #endif // An enumeration describing all of the possible ways that a death test can // conclude. DIED means that the process died while executing the test // code; LIVED means that process lived beyond the end of the test code; // RETURNED means that the test statement attempted to execute a return // statement, which is not allowed; THREW means that the test statement // returned control by throwing an exception. IN_PROGRESS means the test // has not yet concluded. enum DeathTestOutcome { IN_PROGRESS, DIED, LIVED, RETURNED, THREW }; // Routine for aborting the program which is safe to call from an // exec-style death test child process, in which case the error // message is propagated back to the parent process. Otherwise, the // message is simply printed to stderr. In either case, the program // then exits with status 1. 
static void DeathTestAbort(const std::string& message) { // On a POSIX system, this function may be called from a threadsafe-style // death test child process, which operates on a very small stack. Use // the heap for any additional non-minuscule memory requirements. const InternalRunDeathTestFlag* const flag = GetUnitTestImpl()->internal_run_death_test_flag(); if (flag != nullptr) { FILE* parent = posix::FDOpen(flag->write_fd(), "w"); fputc(kDeathTestInternalError, parent); fprintf(parent, "%s", message.c_str()); fflush(parent); _exit(1); } else { fprintf(stderr, "%s", message.c_str()); fflush(stderr); posix::Abort(); } } // A replacement for CHECK that calls DeathTestAbort if the assertion // fails. # define GTEST_DEATH_TEST_CHECK_(expression) \ do { \ if (!::testing::internal::IsTrue(expression)) { \ DeathTestAbort( \ ::std::string("CHECK failed: File ") + __FILE__ + ", line " \ + ::testing::internal::StreamableToString(__LINE__) + ": " \ + #expression); \ } \ } while (::testing::internal::AlwaysFalse()) // This macro is similar to GTEST_DEATH_TEST_CHECK_, but it is meant for // evaluating any system call that fulfills two conditions: it must return // -1 on failure, and set errno to EINTR when it is interrupted and // should be tried again. The macro expands to a loop that repeatedly // evaluates the expression as long as it evaluates to -1 and sets // errno to EINTR. If the expression evaluates to -1 but errno is // something other than EINTR, DeathTestAbort is called. # define GTEST_DEATH_TEST_CHECK_SYSCALL_(expression) \ do { \ int gtest_retval; \ do { \ gtest_retval = (expression); \ } while (gtest_retval == -1 && errno == EINTR); \ if (gtest_retval == -1) { \ DeathTestAbort( \ ::std::string("CHECK failed: File ") + __FILE__ + ", line " \ + ::testing::internal::StreamableToString(__LINE__) + ": " \ + #expression + " != -1"); \ } \ } while (::testing::internal::AlwaysFalse()) // Returns the message describing the last system error in errno. 
std::string GetLastErrnoDescription() { return errno == 0 ? "" : posix::StrError(errno); } // This is called from a death test parent process to read a failure // message from the death test child process and log it with the FATAL // severity. On Windows, the message is read from a pipe handle. On other // platforms, it is read from a file descriptor. static void FailFromInternalError(int fd) { Message error; char buffer[256]; int num_read; do { while ((num_read = posix::Read(fd, buffer, 255)) > 0) { buffer[num_read] = '\0'; error << buffer; } } while (num_read == -1 && errno == EINTR); if (num_read == 0) { GTEST_LOG_(FATAL) << error.GetString(); } else { const int last_error = errno; GTEST_LOG_(FATAL) << "Error while reading death test internal: " << GetLastErrnoDescription() << " [" << last_error << "]"; } } // Death test constructor. Increments the running death test count // for the current test. DeathTest::DeathTest() { TestInfo* const info = GetUnitTestImpl()->current_test_info(); if (info == nullptr) { DeathTestAbort("Cannot run a death test outside of a TEST or " "TEST_F construct"); } } // Creates and returns a death test by dispatching to the current // death test factory. bool DeathTest::Create(const char* statement, Matcher matcher, const char* file, int line, DeathTest** test) { return GetUnitTestImpl()->death_test_factory()->Create( statement, std::move(matcher), file, line, test); } const char* DeathTest::LastMessage() { return last_death_test_message_.c_str(); } void DeathTest::set_last_death_test_message(const std::string& message) { last_death_test_message_ = message; } std::string DeathTest::last_death_test_message_; // Provides cross platform implementation for some death functionality. 
class DeathTestImpl : public DeathTest { protected: DeathTestImpl(const char* a_statement, Matcher matcher) : statement_(a_statement), matcher_(std::move(matcher)), spawned_(false), status_(-1), outcome_(IN_PROGRESS), read_fd_(-1), write_fd_(-1) {} // read_fd_ is expected to be closed and cleared by a derived class. ~DeathTestImpl() override { GTEST_DEATH_TEST_CHECK_(read_fd_ == -1); } void Abort(AbortReason reason) override; bool Passed(bool status_ok) override; const char* statement() const { return statement_; } bool spawned() const { return spawned_; } void set_spawned(bool is_spawned) { spawned_ = is_spawned; } int status() const { return status_; } void set_status(int a_status) { status_ = a_status; } DeathTestOutcome outcome() const { return outcome_; } void set_outcome(DeathTestOutcome an_outcome) { outcome_ = an_outcome; } int read_fd() const { return read_fd_; } void set_read_fd(int fd) { read_fd_ = fd; } int write_fd() const { return write_fd_; } void set_write_fd(int fd) { write_fd_ = fd; } // Called in the parent process only. Reads the result code of the death // test child process via a pipe, interprets it to set the outcome_ // member, and closes read_fd_. Outputs diagnostics and terminates in // case of unexpected codes. void ReadAndInterpretStatusByte(); // Returns stderr output from the child process. virtual std::string GetErrorLogs(); private: // The textual content of the code this object is testing. This class // doesn't own this string and should not attempt to delete it. const char* const statement_; // A matcher that's expected to match the stderr output by the child process. Matcher matcher_; // True if the death test child process has been successfully spawned. bool spawned_; // The exit status of the child process. int status_; // How the death test concluded. DeathTestOutcome outcome_; // Descriptor to the read end of the pipe to the child process. It is // always -1 in the child process. 
The child keeps its write end of the // pipe in write_fd_. int read_fd_; // Descriptor to the child's write end of the pipe to the parent process. // It is always -1 in the parent process. The parent keeps its end of the // pipe in read_fd_. int write_fd_; }; // Called in the parent process only. Reads the result code of the death // test child process via a pipe, interprets it to set the outcome_ // member, and closes read_fd_. Outputs diagnostics and terminates in // case of unexpected codes. void DeathTestImpl::ReadAndInterpretStatusByte() { char flag; int bytes_read; // The read() here blocks until data is available (signifying the // failure of the death test) or until the pipe is closed (signifying // its success), so it's okay to call this in the parent before // the child process has exited. do { bytes_read = posix::Read(read_fd(), &flag, 1); } while (bytes_read == -1 && errno == EINTR); if (bytes_read == 0) { set_outcome(DIED); } else if (bytes_read == 1) { switch (flag) { case kDeathTestReturned: set_outcome(RETURNED); break; case kDeathTestThrew: set_outcome(THREW); break; case kDeathTestLived: set_outcome(LIVED); break; case kDeathTestInternalError: FailFromInternalError(read_fd()); // Does not return. break; default: GTEST_LOG_(FATAL) << "Death test child process reported " << "unexpected status byte (" << static_cast(flag) << ")"; } } else { GTEST_LOG_(FATAL) << "Read from death test child process failed: " << GetLastErrnoDescription(); } GTEST_DEATH_TEST_CHECK_SYSCALL_(posix::Close(read_fd())); set_read_fd(-1); } std::string DeathTestImpl::GetErrorLogs() { return GetCapturedStderr(); } // Signals that the death test code which should have exited, didn't. // Should be called only in a death test child process. // Writes a status byte to the child's status file descriptor, then // calls _exit(1). void DeathTestImpl::Abort(AbortReason reason) { // The parent process considers the death test to be a failure if // it finds any data in our pipe. 
So, here we write a single flag byte // to the pipe, then exit. const char status_ch = reason == TEST_DID_NOT_DIE ? kDeathTestLived : reason == TEST_THREW_EXCEPTION ? kDeathTestThrew : kDeathTestReturned; GTEST_DEATH_TEST_CHECK_SYSCALL_(posix::Write(write_fd(), &status_ch, 1)); // We are leaking the descriptor here because on some platforms (i.e., // when built as Windows DLL), destructors of global objects will still // run after calling _exit(). On such systems, write_fd_ will be // indirectly closed from the destructor of UnitTestImpl, causing double // close if it is also closed here. On debug configurations, double close // may assert. As there are no in-process buffers to flush here, we are // relying on the OS to close the descriptor after the process terminates // when the destructors are not run. _exit(1); // Exits w/o any normal exit hooks (we were supposed to crash) } // Returns an indented copy of stderr output for a death test. // This makes distinguishing death test output lines from regular log lines // much easier. static ::std::string FormatDeathTestOutput(const ::std::string& output) { ::std::string ret; for (size_t at = 0; ; ) { const size_t line_end = output.find('\n', at); ret += "[ DEATH ] "; if (line_end == ::std::string::npos) { ret += output.substr(at); break; } ret += output.substr(at, line_end + 1 - at); at = line_end + 1; } return ret; } // Assesses the success or failure of a death test, using both private // members which have previously been set, and one argument: // // Private data members: // outcome: An enumeration describing how the death test // concluded: DIED, LIVED, THREW, or RETURNED. The death test // fails in the latter three cases. // status: The exit status of the child process. On *nix, it is in the // in the format specified by wait(2). On Windows, this is the // value supplied to the ExitProcess() API or a numeric code // of the exception that terminated the program. 
// matcher_: A matcher that's expected to match the stderr output by the child // process. // // Argument: // status_ok: true if exit_status is acceptable in the context of // this particular death test, which fails if it is false // // Returns true if and only if all of the above conditions are met. Otherwise, // the first failing condition, in the order given above, is the one that is // reported. Also sets the last death test message string. bool DeathTestImpl::Passed(bool status_ok) { if (!spawned()) return false; const std::string error_message = GetErrorLogs(); bool success = false; Message buffer; buffer << "Death test: " << statement() << "\n"; switch (outcome()) { case LIVED: buffer << " Result: failed to die.\n" << " Error msg:\n" << FormatDeathTestOutput(error_message); break; case THREW: buffer << " Result: threw an exception.\n" << " Error msg:\n" << FormatDeathTestOutput(error_message); break; case RETURNED: buffer << " Result: illegal return in test statement.\n" << " Error msg:\n" << FormatDeathTestOutput(error_message); break; case DIED: if (status_ok) { if (matcher_.Matches(error_message)) { success = true; } else { std::ostringstream stream; matcher_.DescribeTo(&stream); buffer << " Result: died but not with expected error.\n" << " Expected: " << stream.str() << "\n" << "Actual msg:\n" << FormatDeathTestOutput(error_message); } } else { buffer << " Result: died but not with expected exit code:\n" << " " << ExitSummary(status()) << "\n" << "Actual msg:\n" << FormatDeathTestOutput(error_message); } break; case IN_PROGRESS: default: GTEST_LOG_(FATAL) << "DeathTest::Passed somehow called before conclusion of test"; } DeathTest::set_last_death_test_message(buffer.GetString()); return success; } # if GTEST_OS_WINDOWS // WindowsDeathTest implements death tests on Windows. 
Due to the // specifics of starting new processes on Windows, death tests there are // always threadsafe, and Google Test considers the // --gtest_death_test_style=fast setting to be equivalent to // --gtest_death_test_style=threadsafe there. // // A few implementation notes: Like the Linux version, the Windows // implementation uses pipes for child-to-parent communication. But due to // the specifics of pipes on Windows, some extra steps are required: // // 1. The parent creates a communication pipe and stores handles to both // ends of it. // 2. The parent starts the child and provides it with the information // necessary to acquire the handle to the write end of the pipe. // 3. The child acquires the write end of the pipe and signals the parent // using a Windows event. // 4. Now the parent can release the write end of the pipe on its side. If // this is done before step 3, the object's reference count goes down to // 0 and it is destroyed, preventing the child from acquiring it. The // parent now has to release it, or read operations on the read end of // the pipe will not return when the child terminates. // 5. The parent reads child's output through the pipe (outcome code and // any possible error messages) from the pipe, and its stderr and then // determines whether to fail the test. // // Note: to distinguish Win32 API calls from the local method and function // calls, the former are explicitly resolved in the global namespace. // class WindowsDeathTest : public DeathTestImpl { public: WindowsDeathTest(const char* a_statement, Matcher matcher, const char* file, int line) : DeathTestImpl(a_statement, std::move(matcher)), file_(file), line_(line) {} // All of these virtual functions are inherited from DeathTest. virtual int Wait(); virtual TestRole AssumeRole(); private: // The name of the file in which the death test is located. const char* const file_; // The line number on which the death test is located. 
const int line_; // Handle to the write end of the pipe to the child process. AutoHandle write_handle_; // Child process handle. AutoHandle child_handle_; // Event the child process uses to signal the parent that it has // acquired the handle to the write end of the pipe. After seeing this // event the parent can release its own handles to make sure its // ReadFile() calls return when the child terminates. AutoHandle event_handle_; }; // Waits for the child in a death test to exit, returning its exit // status, or 0 if no child process exists. As a side effect, sets the // outcome data member. int WindowsDeathTest::Wait() { if (!spawned()) return 0; // Wait until the child either signals that it has acquired the write end // of the pipe or it dies. const HANDLE wait_handles[2] = { child_handle_.Get(), event_handle_.Get() }; switch (::WaitForMultipleObjects(2, wait_handles, FALSE, // Waits for any of the handles. INFINITE)) { case WAIT_OBJECT_0: case WAIT_OBJECT_0 + 1: break; default: GTEST_DEATH_TEST_CHECK_(false); // Should not get here. } // The child has acquired the write end of the pipe or exited. // We release the handle on our side and continue. write_handle_.Reset(); event_handle_.Reset(); ReadAndInterpretStatusByte(); // Waits for the child process to exit if it haven't already. This // returns immediately if the child has already exited, regardless of // whether previous calls to WaitForMultipleObjects synchronized on this // handle or not. GTEST_DEATH_TEST_CHECK_( WAIT_OBJECT_0 == ::WaitForSingleObject(child_handle_.Get(), INFINITE)); DWORD status_code; GTEST_DEATH_TEST_CHECK_( ::GetExitCodeProcess(child_handle_.Get(), &status_code) != FALSE); child_handle_.Reset(); set_status(static_cast(status_code)); return status(); } // The AssumeRole process for a Windows death test. It creates a child // process with the same executable as the current process to run the // death test. 
The child process is given the --gtest_filter and // --gtest_internal_run_death_test flags such that it knows to run the // current death test only. DeathTest::TestRole WindowsDeathTest::AssumeRole() { const UnitTestImpl* const impl = GetUnitTestImpl(); const InternalRunDeathTestFlag* const flag = impl->internal_run_death_test_flag(); const TestInfo* const info = impl->current_test_info(); const int death_test_index = info->result()->death_test_count(); if (flag != nullptr) { // ParseInternalRunDeathTestFlag() has performed all the necessary // processing. set_write_fd(flag->write_fd()); return EXECUTE_TEST; } // WindowsDeathTest uses an anonymous pipe to communicate results of // a death test. SECURITY_ATTRIBUTES handles_are_inheritable = {sizeof(SECURITY_ATTRIBUTES), nullptr, TRUE}; HANDLE read_handle, write_handle; GTEST_DEATH_TEST_CHECK_( ::CreatePipe(&read_handle, &write_handle, &handles_are_inheritable, 0) // Default buffer size. != FALSE); set_read_fd(::_open_osfhandle(reinterpret_cast(read_handle), O_RDONLY)); write_handle_.Reset(write_handle); event_handle_.Reset(::CreateEvent( &handles_are_inheritable, TRUE, // The event will automatically reset to non-signaled state. FALSE, // The initial state is non-signalled. nullptr)); // The even is unnamed. GTEST_DEATH_TEST_CHECK_(event_handle_.Get() != nullptr); const std::string filter_flag = std::string("--") + GTEST_FLAG_PREFIX_ + kFilterFlag + "=" + info->test_suite_name() + "." + info->name(); const std::string internal_flag = std::string("--") + GTEST_FLAG_PREFIX_ + kInternalRunDeathTestFlag + "=" + file_ + "|" + StreamableToString(line_) + "|" + StreamableToString(death_test_index) + "|" + StreamableToString(static_cast(::GetCurrentProcessId())) + // size_t has the same width as pointers on both 32-bit and 64-bit // Windows platforms. // See http://msdn.microsoft.com/en-us/library/tcxf1dw6.aspx. 
"|" + StreamableToString(reinterpret_cast(write_handle)) + "|" + StreamableToString(reinterpret_cast(event_handle_.Get())); char executable_path[_MAX_PATH + 1]; // NOLINT GTEST_DEATH_TEST_CHECK_(_MAX_PATH + 1 != ::GetModuleFileNameA(nullptr, executable_path, _MAX_PATH)); std::string command_line = std::string(::GetCommandLineA()) + " " + filter_flag + " \"" + internal_flag + "\""; DeathTest::set_last_death_test_message(""); CaptureStderr(); // Flush the log buffers since the log streams are shared with the child. FlushInfoLog(); // The child process will share the standard handles with the parent. STARTUPINFOA startup_info; memset(&startup_info, 0, sizeof(STARTUPINFO)); startup_info.dwFlags = STARTF_USESTDHANDLES; startup_info.hStdInput = ::GetStdHandle(STD_INPUT_HANDLE); startup_info.hStdOutput = ::GetStdHandle(STD_OUTPUT_HANDLE); startup_info.hStdError = ::GetStdHandle(STD_ERROR_HANDLE); PROCESS_INFORMATION process_info; GTEST_DEATH_TEST_CHECK_( ::CreateProcessA( executable_path, const_cast(command_line.c_str()), nullptr, // Retuned process handle is not inheritable. nullptr, // Retuned thread handle is not inheritable. TRUE, // Child inherits all inheritable handles (for write_handle_). 0x0, // Default creation flags. nullptr, // Inherit the parent's environment. UnitTest::GetInstance()->original_working_dir(), &startup_info, &process_info) != FALSE); child_handle_.Reset(process_info.hProcess); ::CloseHandle(process_info.hThread); set_spawned(true); return OVERSEE_TEST; } # elif GTEST_OS_FUCHSIA class FuchsiaDeathTest : public DeathTestImpl { public: FuchsiaDeathTest(const char* a_statement, Matcher matcher, const char* file, int line) : DeathTestImpl(a_statement, std::move(matcher)), file_(file), line_(line) {} // All of these virtual functions are inherited from DeathTest. int Wait() override; TestRole AssumeRole() override; std::string GetErrorLogs() override; private: // The name of the file in which the death test is located. 
const char* const file_; // The line number on which the death test is located. const int line_; // The stderr data captured by the child process. std::string captured_stderr_; zx::process child_process_; zx::channel exception_channel_; zx::socket stderr_socket_; }; // Utility class for accumulating command-line arguments. class Arguments { public: Arguments() { args_.push_back(nullptr); } ~Arguments() { for (std::vector::iterator i = args_.begin(); i != args_.end(); ++i) { free(*i); } } void AddArgument(const char* argument) { args_.insert(args_.end() - 1, posix::StrDup(argument)); } template void AddArguments(const ::std::vector& arguments) { for (typename ::std::vector::const_iterator i = arguments.begin(); i != arguments.end(); ++i) { args_.insert(args_.end() - 1, posix::StrDup(i->c_str())); } } char* const* Argv() { return &args_[0]; } int size() { return args_.size() - 1; } private: std::vector args_; }; // Waits for the child in a death test to exit, returning its exit // status, or 0 if no child process exists. As a side effect, sets the // outcome data member. int FuchsiaDeathTest::Wait() { const int kProcessKey = 0; const int kSocketKey = 1; const int kExceptionKey = 2; if (!spawned()) return 0; // Create a port to wait for socket/task/exception events. zx_status_t status_zx; zx::port port; status_zx = zx::port::create(0, &port); GTEST_DEATH_TEST_CHECK_(status_zx == ZX_OK); // Register to wait for the child process to terminate. status_zx = child_process_.wait_async( port, kProcessKey, ZX_PROCESS_TERMINATED, ZX_WAIT_ASYNC_ONCE); GTEST_DEATH_TEST_CHECK_(status_zx == ZX_OK); // Register to wait for the socket to be readable or closed. status_zx = stderr_socket_.wait_async( port, kSocketKey, ZX_SOCKET_READABLE | ZX_SOCKET_PEER_CLOSED, ZX_WAIT_ASYNC_ONCE); GTEST_DEATH_TEST_CHECK_(status_zx == ZX_OK); // Register to wait for an exception. 
status_zx = exception_channel_.wait_async( port, kExceptionKey, ZX_CHANNEL_READABLE, ZX_WAIT_ASYNC_ONCE); GTEST_DEATH_TEST_CHECK_(status_zx == ZX_OK); bool process_terminated = false; bool socket_closed = false; do { zx_port_packet_t packet = {}; status_zx = port.wait(zx::time::infinite(), &packet); GTEST_DEATH_TEST_CHECK_(status_zx == ZX_OK); if (packet.key == kExceptionKey) { // Process encountered an exception. Kill it directly rather than // letting other handlers process the event. We will get a kProcessKey // event when the process actually terminates. status_zx = child_process_.kill(); GTEST_DEATH_TEST_CHECK_(status_zx == ZX_OK); } else if (packet.key == kProcessKey) { // Process terminated. GTEST_DEATH_TEST_CHECK_(ZX_PKT_IS_SIGNAL_ONE(packet.type)); GTEST_DEATH_TEST_CHECK_(packet.signal.observed & ZX_PROCESS_TERMINATED); process_terminated = true; } else if (packet.key == kSocketKey) { GTEST_DEATH_TEST_CHECK_(ZX_PKT_IS_SIGNAL_ONE(packet.type)); if (packet.signal.observed & ZX_SOCKET_READABLE) { // Read data from the socket. 
constexpr size_t kBufferSize = 1024; do { size_t old_length = captured_stderr_.length(); size_t bytes_read = 0; captured_stderr_.resize(old_length + kBufferSize); status_zx = stderr_socket_.read( 0, &captured_stderr_.front() + old_length, kBufferSize, &bytes_read); captured_stderr_.resize(old_length + bytes_read); } while (status_zx == ZX_OK); if (status_zx == ZX_ERR_PEER_CLOSED) { socket_closed = true; } else { GTEST_DEATH_TEST_CHECK_(status_zx == ZX_ERR_SHOULD_WAIT); status_zx = stderr_socket_.wait_async( port, kSocketKey, ZX_SOCKET_READABLE | ZX_SOCKET_PEER_CLOSED, ZX_WAIT_ASYNC_ONCE); GTEST_DEATH_TEST_CHECK_(status_zx == ZX_OK); } } else { GTEST_DEATH_TEST_CHECK_(packet.signal.observed & ZX_SOCKET_PEER_CLOSED); socket_closed = true; } } } while (!process_terminated && !socket_closed); ReadAndInterpretStatusByte(); zx_info_process_t buffer; status_zx = child_process_.get_info( ZX_INFO_PROCESS, &buffer, sizeof(buffer), nullptr, nullptr); GTEST_DEATH_TEST_CHECK_(status_zx == ZX_OK); GTEST_DEATH_TEST_CHECK_(buffer.exited); set_status(buffer.return_code); return status(); } // The AssumeRole process for a Fuchsia death test. It creates a child // process with the same executable as the current process to run the // death test. The child process is given the --gtest_filter and // --gtest_internal_run_death_test flags such that it knows to run the // current death test only. DeathTest::TestRole FuchsiaDeathTest::AssumeRole() { const UnitTestImpl* const impl = GetUnitTestImpl(); const InternalRunDeathTestFlag* const flag = impl->internal_run_death_test_flag(); const TestInfo* const info = impl->current_test_info(); const int death_test_index = info->result()->death_test_count(); if (flag != nullptr) { // ParseInternalRunDeathTestFlag() has performed all the necessary // processing. set_write_fd(kFuchsiaReadPipeFd); return EXECUTE_TEST; } // Flush the log buffers since the log streams are shared with the child. FlushInfoLog(); // Build the child process command line. 
const std::string filter_flag = std::string("--") + GTEST_FLAG_PREFIX_ + kFilterFlag + "=" + info->test_suite_name() + "." + info->name(); const std::string internal_flag = std::string("--") + GTEST_FLAG_PREFIX_ + kInternalRunDeathTestFlag + "=" + file_ + "|" + StreamableToString(line_) + "|" + StreamableToString(death_test_index); Arguments args; args.AddArguments(GetInjectableArgvs()); args.AddArgument(filter_flag.c_str()); args.AddArgument(internal_flag.c_str()); // Build the pipe for communication with the child. zx_status_t status; zx_handle_t child_pipe_handle; int child_pipe_fd; status = fdio_pipe_half(&child_pipe_fd, &child_pipe_handle); GTEST_DEATH_TEST_CHECK_(status == ZX_OK); set_read_fd(child_pipe_fd); // Set the pipe handle for the child. fdio_spawn_action_t spawn_actions[2] = {}; fdio_spawn_action_t* add_handle_action = &spawn_actions[0]; add_handle_action->action = FDIO_SPAWN_ACTION_ADD_HANDLE; add_handle_action->h.id = PA_HND(PA_FD, kFuchsiaReadPipeFd); add_handle_action->h.handle = child_pipe_handle; // Create a socket pair will be used to receive the child process' stderr. zx::socket stderr_producer_socket; status = zx::socket::create(0, &stderr_producer_socket, &stderr_socket_); GTEST_DEATH_TEST_CHECK_(status >= 0); int stderr_producer_fd = -1; status = fdio_fd_create(stderr_producer_socket.release(), &stderr_producer_fd); GTEST_DEATH_TEST_CHECK_(status >= 0); // Make the stderr socket nonblocking. GTEST_DEATH_TEST_CHECK_(fcntl(stderr_producer_fd, F_SETFL, 0) == 0); fdio_spawn_action_t* add_stderr_action = &spawn_actions[1]; add_stderr_action->action = FDIO_SPAWN_ACTION_CLONE_FD; add_stderr_action->fd.local_fd = stderr_producer_fd; add_stderr_action->fd.target_fd = STDERR_FILENO; // Create a child job. 
zx_handle_t child_job = ZX_HANDLE_INVALID; status = zx_job_create(zx_job_default(), 0, & child_job); GTEST_DEATH_TEST_CHECK_(status == ZX_OK); zx_policy_basic_t policy; policy.condition = ZX_POL_NEW_ANY; policy.policy = ZX_POL_ACTION_ALLOW; status = zx_job_set_policy( child_job, ZX_JOB_POL_RELATIVE, ZX_JOB_POL_BASIC, &policy, 1); GTEST_DEATH_TEST_CHECK_(status == ZX_OK); // Create an exception channel attached to the |child_job|, to allow // us to suppress the system default exception handler from firing. status = zx_task_create_exception_channel( child_job, 0, exception_channel_.reset_and_get_address()); GTEST_DEATH_TEST_CHECK_(status == ZX_OK); // Spawn the child process. status = fdio_spawn_etc( child_job, FDIO_SPAWN_CLONE_ALL, args.Argv()[0], args.Argv(), nullptr, 2, spawn_actions, child_process_.reset_and_get_address(), nullptr); GTEST_DEATH_TEST_CHECK_(status == ZX_OK); set_spawned(true); return OVERSEE_TEST; } std::string FuchsiaDeathTest::GetErrorLogs() { return captured_stderr_; } #else // We are neither on Windows, nor on Fuchsia. // ForkingDeathTest provides implementations for most of the abstract // methods of the DeathTest interface. Only the AssumeRole method is // left undefined. class ForkingDeathTest : public DeathTestImpl { public: ForkingDeathTest(const char* statement, Matcher matcher); // All of these virtual functions are inherited from DeathTest. int Wait() override; protected: void set_child_pid(pid_t child_pid) { child_pid_ = child_pid; } private: // PID of child process during death test; 0 in the child process itself. pid_t child_pid_; }; // Constructs a ForkingDeathTest. ForkingDeathTest::ForkingDeathTest(const char* a_statement, Matcher matcher) : DeathTestImpl(a_statement, std::move(matcher)), child_pid_(-1) {} // Waits for the child in a death test to exit, returning its exit // status, or 0 if no child process exists. As a side effect, sets the // outcome data member. 
int ForkingDeathTest::Wait() { if (!spawned()) return 0; ReadAndInterpretStatusByte(); int status_value; GTEST_DEATH_TEST_CHECK_SYSCALL_(waitpid(child_pid_, &status_value, 0)); set_status(status_value); return status_value; } // A concrete death test class that forks, then immediately runs the test // in the child process. class NoExecDeathTest : public ForkingDeathTest { public: NoExecDeathTest(const char* a_statement, Matcher matcher) : ForkingDeathTest(a_statement, std::move(matcher)) {} TestRole AssumeRole() override; }; // The AssumeRole process for a fork-and-run death test. It implements a // straightforward fork, with a simple pipe to transmit the status byte. DeathTest::TestRole NoExecDeathTest::AssumeRole() { const size_t thread_count = GetThreadCount(); if (thread_count != 1) { GTEST_LOG_(WARNING) << DeathTestThreadWarning(thread_count); } int pipe_fd[2]; GTEST_DEATH_TEST_CHECK_(pipe(pipe_fd) != -1); DeathTest::set_last_death_test_message(""); CaptureStderr(); // When we fork the process below, the log file buffers are copied, but the // file descriptors are shared. We flush all log files here so that closing // the file descriptors in the child process doesn't throw off the // synchronization between descriptors and buffers in the parent process. // This is as close to the fork as possible to avoid a race condition in case // there are multiple threads running before the death test, and another // thread writes to the log file. FlushInfoLog(); const pid_t child_pid = fork(); GTEST_DEATH_TEST_CHECK_(child_pid != -1); set_child_pid(child_pid); if (child_pid == 0) { GTEST_DEATH_TEST_CHECK_SYSCALL_(close(pipe_fd[0])); set_write_fd(pipe_fd[1]); // Redirects all logging to stderr in the child process to prevent // concurrent writes to the log files. We capture stderr in the parent // process and append the child process' output to a log. LogToStderr(); // Event forwarding to the listeners of event listener API mush be shut // down in death test subprocesses. 
GetUnitTestImpl()->listeners()->SuppressEventForwarding(); g_in_fast_death_test_child = true; return EXECUTE_TEST; } else { GTEST_DEATH_TEST_CHECK_SYSCALL_(close(pipe_fd[1])); set_read_fd(pipe_fd[0]); set_spawned(true); return OVERSEE_TEST; } } // A concrete death test class that forks and re-executes the main // program from the beginning, with command-line flags set that cause // only this specific death test to be run. class ExecDeathTest : public ForkingDeathTest { public: ExecDeathTest(const char* a_statement, Matcher matcher, const char* file, int line) : ForkingDeathTest(a_statement, std::move(matcher)), file_(file), line_(line) {} TestRole AssumeRole() override; private: static ::std::vector GetArgvsForDeathTestChildProcess() { ::std::vector args = GetInjectableArgvs(); # if defined(GTEST_EXTRA_DEATH_TEST_COMMAND_LINE_ARGS_) ::std::vector extra_args = GTEST_EXTRA_DEATH_TEST_COMMAND_LINE_ARGS_(); args.insert(args.end(), extra_args.begin(), extra_args.end()); # endif // defined(GTEST_EXTRA_DEATH_TEST_COMMAND_LINE_ARGS_) return args; } // The name of the file in which the death test is located. const char* const file_; // The line number on which the death test is located. const int line_; }; // Utility class for accumulating command-line arguments. class Arguments { public: Arguments() { args_.push_back(nullptr); } ~Arguments() { for (std::vector::iterator i = args_.begin(); i != args_.end(); ++i) { free(*i); } } void AddArgument(const char* argument) { args_.insert(args_.end() - 1, posix::StrDup(argument)); } template void AddArguments(const ::std::vector& arguments) { for (typename ::std::vector::const_iterator i = arguments.begin(); i != arguments.end(); ++i) { args_.insert(args_.end() - 1, posix::StrDup(i->c_str())); } } char* const* Argv() { return &args_[0]; } private: std::vector args_; }; // A struct that encompasses the arguments to the child process of a // threadsafe-style death test process. 
struct ExecDeathTestArgs { char* const* argv; // Command-line arguments for the child's call to exec int close_fd; // File descriptor to close; the read end of a pipe }; # if GTEST_OS_MAC inline char** GetEnviron() { // When Google Test is built as a framework on MacOS X, the environ variable // is unavailable. Apple's documentation (man environ) recommends using // _NSGetEnviron() instead. return *_NSGetEnviron(); } # else // Some POSIX platforms expect you to declare environ. extern "C" makes // it reside in the global namespace. extern "C" char** environ; inline char** GetEnviron() { return environ; } # endif // GTEST_OS_MAC # if !GTEST_OS_QNX // The main function for a threadsafe-style death test child process. // This function is called in a clone()-ed process and thus must avoid // any potentially unsafe operations like malloc or libc functions. static int ExecDeathTestChildMain(void* child_arg) { ExecDeathTestArgs* const args = static_cast(child_arg); GTEST_DEATH_TEST_CHECK_SYSCALL_(close(args->close_fd)); // We need to execute the test program in the same environment where // it was originally invoked. Therefore we change to the original // working directory first. const char* const original_dir = UnitTest::GetInstance()->original_working_dir(); // We can safely call chdir() as it's a direct system call. if (chdir(original_dir) != 0) { DeathTestAbort(std::string("chdir(\"") + original_dir + "\") failed: " + GetLastErrnoDescription()); return EXIT_FAILURE; } // We can safely call execve() as it's a direct system call. We // cannot use execvp() as it's a libc function and thus potentially // unsafe. Since execve() doesn't search the PATH, the user must // invoke the test program via a valid path that contains at least // one path separator. execve(args->argv[0], args->argv, GetEnviron()); DeathTestAbort(std::string("execve(") + args->argv[0] + ", ...) 
in " + original_dir + " failed: " + GetLastErrnoDescription()); return EXIT_FAILURE; } # endif // !GTEST_OS_QNX # if GTEST_HAS_CLONE // Two utility routines that together determine the direction the stack // grows. // This could be accomplished more elegantly by a single recursive // function, but we want to guard against the unlikely possibility of // a smart compiler optimizing the recursion away. // // GTEST_NO_INLINE_ is required to prevent GCC 4.6 from inlining // StackLowerThanAddress into StackGrowsDown, which then doesn't give // correct answer. static void StackLowerThanAddress(const void* ptr, bool* result) GTEST_NO_INLINE_; // HWAddressSanitizer add a random tag to the MSB of the local variable address, // making comparison result unpredictable. GTEST_ATTRIBUTE_NO_SANITIZE_HWADDRESS_ static void StackLowerThanAddress(const void* ptr, bool* result) { int dummy; *result = (&dummy < ptr); } // Make sure AddressSanitizer does not tamper with the stack here. GTEST_ATTRIBUTE_NO_SANITIZE_ADDRESS_ GTEST_ATTRIBUTE_NO_SANITIZE_HWADDRESS_ static bool StackGrowsDown() { int dummy = 0; bool result = 0; StackLowerThanAddress(&dummy, &result); return result; } # endif // GTEST_HAS_CLONE // Spawns a child process with the same executable as the current process in // a thread-safe manner and instructs it to run the death test. The // implementation uses fork(2) + exec. On systems where clone(2) is // available, it is used instead, being slightly more thread-safe. On QNX, // fork supports only single-threaded environments, so this function uses // spawn(2) there instead. The function dies with an error message if // anything goes wrong. static pid_t ExecDeathTestSpawnChild(char* const* argv, int close_fd) { ExecDeathTestArgs args = { argv, close_fd }; pid_t child_pid = -1; # if GTEST_OS_QNX // Obtains the current directory and sets it to be closed in the child // process. 
const int cwd_fd = open(".", O_RDONLY); GTEST_DEATH_TEST_CHECK_(cwd_fd != -1); GTEST_DEATH_TEST_CHECK_SYSCALL_(fcntl(cwd_fd, F_SETFD, FD_CLOEXEC)); // We need to execute the test program in the same environment where // it was originally invoked. Therefore we change to the original // working directory first. const char* const original_dir = UnitTest::GetInstance()->original_working_dir(); // We can safely call chdir() as it's a direct system call. if (chdir(original_dir) != 0) { DeathTestAbort(std::string("chdir(\"") + original_dir + "\") failed: " + GetLastErrnoDescription()); return EXIT_FAILURE; } int fd_flags; // Set close_fd to be closed after spawn. GTEST_DEATH_TEST_CHECK_SYSCALL_(fd_flags = fcntl(close_fd, F_GETFD)); GTEST_DEATH_TEST_CHECK_SYSCALL_(fcntl(close_fd, F_SETFD, fd_flags | FD_CLOEXEC)); struct inheritance inherit = {0}; // spawn is a system call. child_pid = spawn(args.argv[0], 0, nullptr, &inherit, args.argv, GetEnviron()); // Restores the current working directory. GTEST_DEATH_TEST_CHECK_(fchdir(cwd_fd) != -1); GTEST_DEATH_TEST_CHECK_SYSCALL_(close(cwd_fd)); # else // GTEST_OS_QNX # if GTEST_OS_LINUX // When a SIGPROF signal is received while fork() or clone() are executing, // the process may hang. To avoid this, we ignore SIGPROF here and re-enable // it after the call to fork()/clone() is complete. struct sigaction saved_sigprof_action; struct sigaction ignore_sigprof_action; memset(&ignore_sigprof_action, 0, sizeof(ignore_sigprof_action)); sigemptyset(&ignore_sigprof_action.sa_mask); ignore_sigprof_action.sa_handler = SIG_IGN; GTEST_DEATH_TEST_CHECK_SYSCALL_(sigaction( SIGPROF, &ignore_sigprof_action, &saved_sigprof_action)); # endif // GTEST_OS_LINUX # if GTEST_HAS_CLONE const bool use_fork = GTEST_FLAG(death_test_use_fork); if (!use_fork) { static const bool stack_grows_down = StackGrowsDown(); const auto stack_size = static_cast(getpagesize()); // MMAP_ANONYMOUS is not defined on Mac, so we use MAP_ANON instead. 
void* const stack = mmap(nullptr, stack_size, PROT_READ | PROT_WRITE, MAP_ANON | MAP_PRIVATE, -1, 0); GTEST_DEATH_TEST_CHECK_(stack != MAP_FAILED); // Maximum stack alignment in bytes: For a downward-growing stack, this // amount is subtracted from size of the stack space to get an address // that is within the stack space and is aligned on all systems we care // about. As far as I know there is no ABI with stack alignment greater // than 64. We assume stack and stack_size already have alignment of // kMaxStackAlignment. const size_t kMaxStackAlignment = 64; void* const stack_top = static_cast(stack) + (stack_grows_down ? stack_size - kMaxStackAlignment : 0); GTEST_DEATH_TEST_CHECK_( static_cast(stack_size) > kMaxStackAlignment && reinterpret_cast(stack_top) % kMaxStackAlignment == 0); child_pid = clone(&ExecDeathTestChildMain, stack_top, SIGCHLD, &args); GTEST_DEATH_TEST_CHECK_(munmap(stack, stack_size) != -1); } # else const bool use_fork = true; # endif // GTEST_HAS_CLONE if (use_fork && (child_pid = fork()) == 0) { ExecDeathTestChildMain(&args); _exit(0); } # endif // GTEST_OS_QNX # if GTEST_OS_LINUX GTEST_DEATH_TEST_CHECK_SYSCALL_( sigaction(SIGPROF, &saved_sigprof_action, nullptr)); # endif // GTEST_OS_LINUX GTEST_DEATH_TEST_CHECK_(child_pid != -1); return child_pid; } // The AssumeRole process for a fork-and-exec death test. It re-executes the // main program from the beginning, setting the --gtest_filter // and --gtest_internal_run_death_test flags to cause only the current // death test to be re-run. 
DeathTest::TestRole ExecDeathTest::AssumeRole() {
  const UnitTestImpl* const impl = GetUnitTestImpl();
  const InternalRunDeathTestFlag* const flag =
      impl->internal_run_death_test_flag();
  const TestInfo* const info = impl->current_test_info();
  const int death_test_index = info->result()->death_test_count();

  // An internal-run flag is present: adopt the write descriptor it carries
  // and execute the test statement instead of spawning anything.
  if (flag != nullptr) {
    set_write_fd(flag->write_fd());
    return EXECUTE_TEST;
  }

  int pipe_fd[2];
  GTEST_DEATH_TEST_CHECK_(pipe(pipe_fd) != -1);
  // Clear the close-on-exec flag on the write end of the pipe, lest
  // it be closed when the child process does an exec:
  GTEST_DEATH_TEST_CHECK_(fcntl(pipe_fd[1], F_SETFD, 0) != -1);

  // Flags for the child: run only the current test, and identify this death
  // test by file, line, index and the descriptor of the pipe's write end.
  const std::string filter_flag = std::string("--") + GTEST_FLAG_PREFIX_ +
                                  kFilterFlag + "=" + info->test_suite_name() +
                                  "." + info->name();
  const std::string internal_flag =
      std::string("--") + GTEST_FLAG_PREFIX_ + kInternalRunDeathTestFlag + "="
      + file_ + "|" + StreamableToString(line_) + "|"
      + StreamableToString(death_test_index) + "|"
      + StreamableToString(pipe_fd[1]);
  Arguments args;
  args.AddArguments(GetArgvsForDeathTestChildProcess());
  args.AddArgument(filter_flag.c_str());
  args.AddArgument(internal_flag.c_str());

  DeathTest::set_last_death_test_message("");

  CaptureStderr();
  // See the comment in NoExecDeathTest::AssumeRole for why the next line
  // is necessary.
  FlushInfoLog();

  // pipe_fd[0] (the read end) is passed as the descriptor to close in the
  // child.
  const pid_t child_pid = ExecDeathTestSpawnChild(args.Argv(), pipe_fd[0]);
  // This process keeps only the read end of the pipe.
  GTEST_DEATH_TEST_CHECK_SYSCALL_(close(pipe_fd[1]));
  set_child_pid(child_pid);
  set_read_fd(pipe_fd[0]);
  set_spawned(true);
  return OVERSEE_TEST;
}

# endif  // !GTEST_OS_WINDOWS

// Creates a concrete DeathTest-derived class that depends on the
// --gtest_death_test_style flag, and sets the pointer pointed to
// by the "test" argument to its address.  If the test should be
// skipped, sets that pointer to NULL.  Returns true, unless the
// flag is set to an invalid value.
bool DefaultDeathTestFactory::Create(const char* statement, Matcher matcher, const char* file, int line, DeathTest** test) { UnitTestImpl* const impl = GetUnitTestImpl(); const InternalRunDeathTestFlag* const flag = impl->internal_run_death_test_flag(); const int death_test_index = impl->current_test_info() ->increment_death_test_count(); if (flag != nullptr) { if (death_test_index > flag->index()) { DeathTest::set_last_death_test_message( "Death test count (" + StreamableToString(death_test_index) + ") somehow exceeded expected maximum (" + StreamableToString(flag->index()) + ")"); return false; } if (!(flag->file() == file && flag->line() == line && flag->index() == death_test_index)) { *test = nullptr; return true; } } # if GTEST_OS_WINDOWS if (GTEST_FLAG(death_test_style) == "threadsafe" || GTEST_FLAG(death_test_style) == "fast") { *test = new WindowsDeathTest(statement, std::move(matcher), file, line); } # elif GTEST_OS_FUCHSIA if (GTEST_FLAG(death_test_style) == "threadsafe" || GTEST_FLAG(death_test_style) == "fast") { *test = new FuchsiaDeathTest(statement, std::move(matcher), file, line); } # else if (GTEST_FLAG(death_test_style) == "threadsafe") { *test = new ExecDeathTest(statement, std::move(matcher), file, line); } else if (GTEST_FLAG(death_test_style) == "fast") { *test = new NoExecDeathTest(statement, std::move(matcher)); } # endif // GTEST_OS_WINDOWS else { // NOLINT - this is more readable than unbalanced brackets inside #if. DeathTest::set_last_death_test_message( "Unknown death test style \"" + GTEST_FLAG(death_test_style) + "\" encountered"); return false; } return true; } # if GTEST_OS_WINDOWS // Recreates the pipe and event handles from the provided parameters, // signals the event, and returns a file descriptor wrapped around the pipe // handle. This function is called in the child process only. 
static int GetStatusFileDescriptor(unsigned int parent_process_id, size_t write_handle_as_size_t, size_t event_handle_as_size_t) { AutoHandle parent_process_handle(::OpenProcess(PROCESS_DUP_HANDLE, FALSE, // Non-inheritable. parent_process_id)); if (parent_process_handle.Get() == INVALID_HANDLE_VALUE) { DeathTestAbort("Unable to open parent process " + StreamableToString(parent_process_id)); } GTEST_CHECK_(sizeof(HANDLE) <= sizeof(size_t)); const HANDLE write_handle = reinterpret_cast(write_handle_as_size_t); HANDLE dup_write_handle; // The newly initialized handle is accessible only in the parent // process. To obtain one accessible within the child, we need to use // DuplicateHandle. if (!::DuplicateHandle(parent_process_handle.Get(), write_handle, ::GetCurrentProcess(), &dup_write_handle, 0x0, // Requested privileges ignored since // DUPLICATE_SAME_ACCESS is used. FALSE, // Request non-inheritable handler. DUPLICATE_SAME_ACCESS)) { DeathTestAbort("Unable to duplicate the pipe handle " + StreamableToString(write_handle_as_size_t) + " from the parent process " + StreamableToString(parent_process_id)); } const HANDLE event_handle = reinterpret_cast(event_handle_as_size_t); HANDLE dup_event_handle; if (!::DuplicateHandle(parent_process_handle.Get(), event_handle, ::GetCurrentProcess(), &dup_event_handle, 0x0, FALSE, DUPLICATE_SAME_ACCESS)) { DeathTestAbort("Unable to duplicate the event handle " + StreamableToString(event_handle_as_size_t) + " from the parent process " + StreamableToString(parent_process_id)); } const int write_fd = ::_open_osfhandle(reinterpret_cast(dup_write_handle), O_APPEND); if (write_fd == -1) { DeathTestAbort("Unable to convert pipe handle " + StreamableToString(write_handle_as_size_t) + " to a file descriptor"); } // Signals the parent that the write end of the pipe has been acquired // so the parent can release its own write end. 
::SetEvent(dup_event_handle); return write_fd; } # endif // GTEST_OS_WINDOWS // Returns a newly created InternalRunDeathTestFlag object with fields // initialized from the GTEST_FLAG(internal_run_death_test) flag if // the flag is specified; otherwise returns NULL. InternalRunDeathTestFlag* ParseInternalRunDeathTestFlag() { if (GTEST_FLAG(internal_run_death_test) == "") return nullptr; // GTEST_HAS_DEATH_TEST implies that we have ::std::string, so we // can use it here. int line = -1; int index = -1; ::std::vector< ::std::string> fields; SplitString(GTEST_FLAG(internal_run_death_test).c_str(), '|', &fields); int write_fd = -1; # if GTEST_OS_WINDOWS unsigned int parent_process_id = 0; size_t write_handle_as_size_t = 0; size_t event_handle_as_size_t = 0; if (fields.size() != 6 || !ParseNaturalNumber(fields[1], &line) || !ParseNaturalNumber(fields[2], &index) || !ParseNaturalNumber(fields[3], &parent_process_id) || !ParseNaturalNumber(fields[4], &write_handle_as_size_t) || !ParseNaturalNumber(fields[5], &event_handle_as_size_t)) { DeathTestAbort("Bad --gtest_internal_run_death_test flag: " + GTEST_FLAG(internal_run_death_test)); } write_fd = GetStatusFileDescriptor(parent_process_id, write_handle_as_size_t, event_handle_as_size_t); # elif GTEST_OS_FUCHSIA if (fields.size() != 3 || !ParseNaturalNumber(fields[1], &line) || !ParseNaturalNumber(fields[2], &index)) { DeathTestAbort("Bad --gtest_internal_run_death_test flag: " + GTEST_FLAG(internal_run_death_test)); } # else if (fields.size() != 4 || !ParseNaturalNumber(fields[1], &line) || !ParseNaturalNumber(fields[2], &index) || !ParseNaturalNumber(fields[3], &write_fd)) { DeathTestAbort("Bad --gtest_internal_run_death_test flag: " + GTEST_FLAG(internal_run_death_test)); } # endif // GTEST_OS_WINDOWS return new InternalRunDeathTestFlag(fields[0], line, index, write_fd); } } // namespace internal #endif // GTEST_HAS_DEATH_TEST } // namespace testing // Copyright 2008, Google Inc. // All rights reserved. 
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include #if GTEST_OS_WINDOWS_MOBILE # include #elif GTEST_OS_WINDOWS # include # include #else # include # include // Some Linux distributions define PATH_MAX here. 
#endif // GTEST_OS_WINDOWS_MOBILE #if GTEST_OS_WINDOWS # define GTEST_PATH_MAX_ _MAX_PATH #elif defined(PATH_MAX) # define GTEST_PATH_MAX_ PATH_MAX #elif defined(_XOPEN_PATH_MAX) # define GTEST_PATH_MAX_ _XOPEN_PATH_MAX #else # define GTEST_PATH_MAX_ _POSIX_PATH_MAX #endif // GTEST_OS_WINDOWS namespace testing { namespace internal { #if GTEST_OS_WINDOWS // On Windows, '\\' is the standard path separator, but many tools and the // Windows API also accept '/' as an alternate path separator. Unless otherwise // noted, a file path can contain either kind of path separators, or a mixture // of them. const char kPathSeparator = '\\'; const char kAlternatePathSeparator = '/'; const char kAlternatePathSeparatorString[] = "/"; # if GTEST_OS_WINDOWS_MOBILE // Windows CE doesn't have a current directory. You should not use // the current directory in tests on Windows CE, but this at least // provides a reasonable fallback. const char kCurrentDirectoryString[] = "\\"; // Windows CE doesn't define INVALID_FILE_ATTRIBUTES const DWORD kInvalidFileAttributes = 0xffffffff; # else const char kCurrentDirectoryString[] = ".\\"; # endif // GTEST_OS_WINDOWS_MOBILE #else const char kPathSeparator = '/'; const char kCurrentDirectoryString[] = "./"; #endif // GTEST_OS_WINDOWS // Returns whether the given character is a valid path separator. static bool IsPathSeparator(char c) { #if GTEST_HAS_ALT_PATH_SEP_ return (c == kPathSeparator) || (c == kAlternatePathSeparator); #else return c == kPathSeparator; #endif } // Returns the current working directory, or "" if unsuccessful. FilePath FilePath::GetCurrentDir() { #if GTEST_OS_WINDOWS_MOBILE || GTEST_OS_WINDOWS_PHONE || \ GTEST_OS_WINDOWS_RT || ARDUINO || defined(ESP_PLATFORM) // These platforms do not have a current directory, so we just return // something reasonable. return FilePath(kCurrentDirectoryString); #elif GTEST_OS_WINDOWS char cwd[GTEST_PATH_MAX_ + 1] = { '\0' }; return FilePath(_getcwd(cwd, sizeof(cwd)) == nullptr ? 
"" : cwd); #else char cwd[GTEST_PATH_MAX_ + 1] = { '\0' }; char* result = getcwd(cwd, sizeof(cwd)); # if GTEST_OS_NACL // getcwd will likely fail in NaCl due to the sandbox, so return something // reasonable. The user may have provided a shim implementation for getcwd, // however, so fallback only when failure is detected. return FilePath(result == nullptr ? kCurrentDirectoryString : cwd); # endif // GTEST_OS_NACL return FilePath(result == nullptr ? "" : cwd); #endif // GTEST_OS_WINDOWS_MOBILE } // Returns a copy of the FilePath with the case-insensitive extension removed. // Example: FilePath("dir/file.exe").RemoveExtension("EXE") returns // FilePath("dir/file"). If a case-insensitive extension is not // found, returns a copy of the original FilePath. FilePath FilePath::RemoveExtension(const char* extension) const { const std::string dot_extension = std::string(".") + extension; if (String::EndsWithCaseInsensitive(pathname_, dot_extension)) { return FilePath(pathname_.substr( 0, pathname_.length() - dot_extension.length())); } return *this; } // Returns a pointer to the last occurrence of a valid path separator in // the FilePath. On Windows, for example, both '/' and '\' are valid path // separators. Returns NULL if no path separator was found. const char* FilePath::FindLastPathSeparator() const { const char* const last_sep = strrchr(c_str(), kPathSeparator); #if GTEST_HAS_ALT_PATH_SEP_ const char* const last_alt_sep = strrchr(c_str(), kAlternatePathSeparator); // Comparing two pointers of which only one is NULL is undefined. if (last_alt_sep != nullptr && (last_sep == nullptr || last_alt_sep > last_sep)) { return last_alt_sep; } #endif return last_sep; } // Returns a copy of the FilePath with the directory part removed. // Example: FilePath("path/to/file").RemoveDirectoryName() returns // FilePath("file"). If there is no directory part ("just_a_file"), it returns // the FilePath unmodified. 
If there is no file part ("just_a_dir/") it // returns an empty FilePath (""). // On Windows platform, '\' is the path separator, otherwise it is '/'. FilePath FilePath::RemoveDirectoryName() const { const char* const last_sep = FindLastPathSeparator(); return last_sep ? FilePath(last_sep + 1) : *this; } // RemoveFileName returns the directory path with the filename removed. // Example: FilePath("path/to/file").RemoveFileName() returns "path/to/". // If the FilePath is "a_file" or "/a_file", RemoveFileName returns // FilePath("./") or, on Windows, FilePath(".\\"). If the filepath does // not have a file, like "just/a/dir/", it returns the FilePath unmodified. // On Windows platform, '\' is the path separator, otherwise it is '/'. FilePath FilePath::RemoveFileName() const { const char* const last_sep = FindLastPathSeparator(); std::string dir; if (last_sep) { dir = std::string(c_str(), static_cast(last_sep + 1 - c_str())); } else { dir = kCurrentDirectoryString; } return FilePath(dir); } // Helper functions for naming files in a directory for xml output. // Given directory = "dir", base_name = "test", number = 0, // extension = "xml", returns "dir/test.xml". If number is greater // than zero (e.g., 12), returns "dir/test_12.xml". // On Windows platform, uses \ as the separator rather than /. FilePath FilePath::MakeFileName(const FilePath& directory, const FilePath& base_name, int number, const char* extension) { std::string file; if (number == 0) { file = base_name.string() + "." + extension; } else { file = base_name.string() + "_" + StreamableToString(number) + "." + extension; } return ConcatPaths(directory, FilePath(file)); } // Given directory = "dir", relative_path = "test.xml", returns "dir/test.xml". // On Windows, uses \ as the separator rather than /. 
FilePath FilePath::ConcatPaths(const FilePath& directory, const FilePath& relative_path) { if (directory.IsEmpty()) return relative_path; const FilePath dir(directory.RemoveTrailingPathSeparator()); return FilePath(dir.string() + kPathSeparator + relative_path.string()); } // Returns true if pathname describes something findable in the file-system, // either a file, directory, or whatever. bool FilePath::FileOrDirectoryExists() const { #if GTEST_OS_WINDOWS_MOBILE LPCWSTR unicode = String::AnsiToUtf16(pathname_.c_str()); const DWORD attributes = GetFileAttributes(unicode); delete [] unicode; return attributes != kInvalidFileAttributes; #else posix::StatStruct file_stat; return posix::Stat(pathname_.c_str(), &file_stat) == 0; #endif // GTEST_OS_WINDOWS_MOBILE } // Returns true if pathname describes a directory in the file-system // that exists. bool FilePath::DirectoryExists() const { bool result = false; #if GTEST_OS_WINDOWS // Don't strip off trailing separator if path is a root directory on // Windows (like "C:\\"). const FilePath& path(IsRootDirectory() ? *this : RemoveTrailingPathSeparator()); #else const FilePath& path(*this); #endif #if GTEST_OS_WINDOWS_MOBILE LPCWSTR unicode = String::AnsiToUtf16(path.c_str()); const DWORD attributes = GetFileAttributes(unicode); delete [] unicode; if ((attributes != kInvalidFileAttributes) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) { result = true; } #else posix::StatStruct file_stat; result = posix::Stat(path.c_str(), &file_stat) == 0 && posix::IsDir(file_stat); #endif // GTEST_OS_WINDOWS_MOBILE return result; } // Returns true if pathname describes a root directory. (Windows has one // root directory per disk drive.) bool FilePath::IsRootDirectory() const { #if GTEST_OS_WINDOWS return pathname_.length() == 3 && IsAbsolutePath(); #else return pathname_.length() == 1 && IsPathSeparator(pathname_.c_str()[0]); #endif } // Returns true if pathname describes an absolute path. 
bool FilePath::IsAbsolutePath() const { const char* const name = pathname_.c_str(); #if GTEST_OS_WINDOWS return pathname_.length() >= 3 && ((name[0] >= 'a' && name[0] <= 'z') || (name[0] >= 'A' && name[0] <= 'Z')) && name[1] == ':' && IsPathSeparator(name[2]); #else return IsPathSeparator(name[0]); #endif } // Returns a pathname for a file that does not currently exist. The pathname // will be directory/base_name.extension or // directory/base_name_.extension if directory/base_name.extension // already exists. The number will be incremented until a pathname is found // that does not already exist. // Examples: 'dir/foo_test.xml' or 'dir/foo_test_1.xml'. // There could be a race condition if two or more processes are calling this // function at the same time -- they could both pick the same filename. FilePath FilePath::GenerateUniqueFileName(const FilePath& directory, const FilePath& base_name, const char* extension) { FilePath full_pathname; int number = 0; do { full_pathname.Set(MakeFileName(directory, base_name, number++, extension)); } while (full_pathname.FileOrDirectoryExists()); return full_pathname; } // Returns true if FilePath ends with a path separator, which indicates that // it is intended to represent a directory. Returns false otherwise. // This does NOT check that a directory (or file) actually exists. bool FilePath::IsDirectory() const { return !pathname_.empty() && IsPathSeparator(pathname_.c_str()[pathname_.length() - 1]); } // Create directories so that path exists. Returns true if successful or if // the directories already exist; returns false if unable to create directories // for any reason. 
bool FilePath::CreateDirectoriesRecursively() const { if (!this->IsDirectory()) { return false; } if (pathname_.length() == 0 || this->DirectoryExists()) { return true; } const FilePath parent(this->RemoveTrailingPathSeparator().RemoveFileName()); return parent.CreateDirectoriesRecursively() && this->CreateFolder(); } // Create the directory so that path exists. Returns true if successful or // if the directory already exists; returns false if unable to create the // directory for any reason, including if the parent directory does not // exist. Not named "CreateDirectory" because that's a macro on Windows. bool FilePath::CreateFolder() const { #if GTEST_OS_WINDOWS_MOBILE FilePath removed_sep(this->RemoveTrailingPathSeparator()); LPCWSTR unicode = String::AnsiToUtf16(removed_sep.c_str()); int result = CreateDirectory(unicode, nullptr) ? 0 : -1; delete [] unicode; #elif GTEST_OS_WINDOWS int result = _mkdir(pathname_.c_str()); #else int result = mkdir(pathname_.c_str(), 0777); #endif // GTEST_OS_WINDOWS_MOBILE if (result == -1) { return this->DirectoryExists(); // An error is OK if the directory exists. } return true; // No error. } // If input name has a trailing separator character, remove it and return the // name, otherwise return the name string unmodified. // On Windows platform, uses \ as the separator, other platforms use /. FilePath FilePath::RemoveTrailingPathSeparator() const { return IsDirectory() ? FilePath(pathname_.substr(0, pathname_.length() - 1)) : *this; } // Removes any redundant separators that might be in the pathname. // For example, "bar///foo" becomes "bar/foo". Does not eliminate other // redundancies that might be in a pathname involving "." or "..". 
void FilePath::Normalize() { if (pathname_.c_str() == nullptr) { pathname_ = ""; return; } const char* src = pathname_.c_str(); char* const dest = new char[pathname_.length() + 1]; char* dest_ptr = dest; memset(dest_ptr, 0, pathname_.length() + 1); while (*src != '\0') { *dest_ptr = *src; if (!IsPathSeparator(*src)) { src++; } else { #if GTEST_HAS_ALT_PATH_SEP_ if (*dest_ptr == kAlternatePathSeparator) { *dest_ptr = kPathSeparator; } #endif while (IsPathSeparator(*src)) src++; } dest_ptr++; } *dest_ptr = '\0'; pathname_ = dest; delete[] dest; } } // namespace internal } // namespace testing // Copyright 2007, Google Inc. // All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // The Google C++ Testing and Mocking Framework (Google Test) // // This file implements just enough of the matcher interface to allow // EXPECT_DEATH and friends to accept a matcher argument. #include namespace testing { // Constructs a matcher that matches a const std::string& whose value is // equal to s. Matcher::Matcher(const std::string& s) { *this = Eq(s); } // Constructs a matcher that matches a const std::string& whose value is // equal to s. Matcher::Matcher(const char* s) { *this = Eq(std::string(s)); } // Constructs a matcher that matches a std::string whose value is equal to // s. Matcher::Matcher(const std::string& s) { *this = Eq(s); } // Constructs a matcher that matches a std::string whose value is equal to // s. Matcher::Matcher(const char* s) { *this = Eq(std::string(s)); } #if GTEST_HAS_ABSL // Constructs a matcher that matches a const absl::string_view& whose value is // equal to s. Matcher::Matcher(const std::string& s) { *this = Eq(s); } // Constructs a matcher that matches a const absl::string_view& whose value is // equal to s. Matcher::Matcher(const char* s) { *this = Eq(std::string(s)); } // Constructs a matcher that matches a const absl::string_view& whose value is // equal to s. Matcher::Matcher(absl::string_view s) { *this = Eq(std::string(s)); } // Constructs a matcher that matches a absl::string_view whose value is equal to // s. 
Matcher::Matcher(const std::string& s) { *this = Eq(s); } // Constructs a matcher that matches a absl::string_view whose value is equal to // s. Matcher::Matcher(const char* s) { *this = Eq(std::string(s)); } // Constructs a matcher that matches a absl::string_view whose value is equal to // s. Matcher::Matcher(absl::string_view s) { *this = Eq(std::string(s)); } #endif // GTEST_HAS_ABSL } // namespace testing // Copyright 2008, Google Inc. // All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#include #include #include #include #include #include #if GTEST_OS_WINDOWS # include # include # include # include // Used in ThreadLocal. # ifdef _MSC_VER # include # endif // _MSC_VER #else # include #endif // GTEST_OS_WINDOWS #if GTEST_OS_MAC # include # include # include #endif // GTEST_OS_MAC #if GTEST_OS_DRAGONFLY || GTEST_OS_FREEBSD || GTEST_OS_GNU_KFREEBSD || \ GTEST_OS_NETBSD || GTEST_OS_OPENBSD # include # if GTEST_OS_DRAGONFLY || GTEST_OS_FREEBSD || GTEST_OS_GNU_KFREEBSD # include # endif #endif #if GTEST_OS_QNX # include # include # include #endif // GTEST_OS_QNX #if GTEST_OS_AIX # include # include #endif // GTEST_OS_AIX #if GTEST_OS_FUCHSIA # include # include #endif // GTEST_OS_FUCHSIA namespace testing { namespace internal { #if defined(_MSC_VER) || defined(__BORLANDC__) // MSVC and C++Builder do not provide a definition of STDERR_FILENO. const int kStdOutFileno = 1; const int kStdErrFileno = 2; #else const int kStdOutFileno = STDOUT_FILENO; const int kStdErrFileno = STDERR_FILENO; #endif // _MSC_VER #if GTEST_OS_LINUX namespace { template T ReadProcFileField(const std::string& filename, int field) { std::string dummy; std::ifstream file(filename.c_str()); while (field-- > 0) { file >> dummy; } T output = 0; file >> output; return output; } } // namespace // Returns the number of active threads, or 0 when there is an error. size_t GetThreadCount() { const std::string filename = (Message() << "/proc/" << getpid() << "/stat").GetString(); return ReadProcFileField(filename, 19); } #elif GTEST_OS_MAC size_t GetThreadCount() { const task_t task = mach_task_self(); mach_msg_type_number_t thread_count; thread_act_array_t thread_list; const kern_return_t status = task_threads(task, &thread_list, &thread_count); if (status == KERN_SUCCESS) { // task_threads allocates resources in thread_list and we need to free them // to avoid leaks. 
vm_deallocate(task, reinterpret_cast(thread_list), sizeof(thread_t) * thread_count); return static_cast(thread_count); } else { return 0; } } #elif GTEST_OS_DRAGONFLY || GTEST_OS_FREEBSD || GTEST_OS_GNU_KFREEBSD || \ GTEST_OS_NETBSD #if GTEST_OS_NETBSD #undef KERN_PROC #define KERN_PROC KERN_PROC2 #define kinfo_proc kinfo_proc2 #endif #if GTEST_OS_DRAGONFLY #define KP_NLWP(kp) (kp.kp_nthreads) #elif GTEST_OS_FREEBSD || GTEST_OS_GNU_KFREEBSD #define KP_NLWP(kp) (kp.ki_numthreads) #elif GTEST_OS_NETBSD #define KP_NLWP(kp) (kp.p_nlwps) #endif // Returns the number of threads running in the process, or 0 to indicate that // we cannot detect it. size_t GetThreadCount() { int mib[] = { CTL_KERN, KERN_PROC, KERN_PROC_PID, getpid(), #if GTEST_OS_NETBSD sizeof(struct kinfo_proc), 1, #endif }; u_int miblen = sizeof(mib) / sizeof(mib[0]); struct kinfo_proc info; size_t size = sizeof(info); if (sysctl(mib, miblen, &info, &size, NULL, 0)) { return 0; } return static_cast(KP_NLWP(info)); } #elif GTEST_OS_OPENBSD // Returns the number of threads running in the process, or 0 to indicate that // we cannot detect it. size_t GetThreadCount() { int mib[] = { CTL_KERN, KERN_PROC, KERN_PROC_PID | KERN_PROC_SHOW_THREADS, getpid(), sizeof(struct kinfo_proc), 0, }; u_int miblen = sizeof(mib) / sizeof(mib[0]); // get number of structs size_t size; if (sysctl(mib, miblen, NULL, &size, NULL, 0)) { return 0; } mib[5] = size / mib[4]; // populate array of structs struct kinfo_proc info[mib[5]]; if (sysctl(mib, miblen, &info, &size, NULL, 0)) { return 0; } // exclude empty members int nthreads = 0; for (int i = 0; i < size / mib[4]; i++) { if (info[i].p_tid != -1) nthreads++; } return nthreads; } #elif GTEST_OS_QNX // Returns the number of threads running in the process, or 0 to indicate that // we cannot detect it. 
size_t GetThreadCount() { const int fd = open("/proc/self/as", O_RDONLY); if (fd < 0) { return 0; } procfs_info process_info; const int status = devctl(fd, DCMD_PROC_INFO, &process_info, sizeof(process_info), nullptr); close(fd); if (status == EOK) { return static_cast(process_info.num_threads); } else { return 0; } } #elif GTEST_OS_AIX size_t GetThreadCount() { struct procentry64 entry; pid_t pid = getpid(); int status = getprocs64(&entry, sizeof(entry), nullptr, 0, &pid, 1); if (status == 1) { return entry.pi_thcount; } else { return 0; } } #elif GTEST_OS_FUCHSIA size_t GetThreadCount() { int dummy_buffer; size_t avail; zx_status_t status = zx_object_get_info( zx_process_self(), ZX_INFO_PROCESS_THREADS, &dummy_buffer, 0, nullptr, &avail); if (status == ZX_OK) { return avail; } else { return 0; } } #else size_t GetThreadCount() { // There's no portable way to detect the number of threads, so we just // return 0 to indicate that we cannot detect it. return 0; } #endif // GTEST_OS_LINUX #if GTEST_IS_THREADSAFE && GTEST_OS_WINDOWS void SleepMilliseconds(int n) { ::Sleep(static_cast(n)); } AutoHandle::AutoHandle() : handle_(INVALID_HANDLE_VALUE) {} AutoHandle::AutoHandle(Handle handle) : handle_(handle) {} AutoHandle::~AutoHandle() { Reset(); } AutoHandle::Handle AutoHandle::Get() const { return handle_; } void AutoHandle::Reset() { Reset(INVALID_HANDLE_VALUE); } void AutoHandle::Reset(HANDLE handle) { // Resetting with the same handle we already own is invalid. if (handle_ != handle) { if (IsCloseable()) { ::CloseHandle(handle_); } handle_ = handle; } else { GTEST_CHECK_(!IsCloseable()) << "Resetting a valid handle to itself is likely a programmer error " "and thus not allowed."; } } bool AutoHandle::IsCloseable() const { // Different Windows APIs may use either of these values to represent an // invalid handle. return handle_ != nullptr && handle_ != INVALID_HANDLE_VALUE; } Notification::Notification() : event_(::CreateEvent(nullptr, // Default security attributes. 
TRUE, // Do not reset automatically. FALSE, // Initially unset. nullptr)) { // Anonymous event. GTEST_CHECK_(event_.Get() != nullptr); } void Notification::Notify() { GTEST_CHECK_(::SetEvent(event_.Get()) != FALSE); } void Notification::WaitForNotification() { GTEST_CHECK_( ::WaitForSingleObject(event_.Get(), INFINITE) == WAIT_OBJECT_0); } Mutex::Mutex() : owner_thread_id_(0), type_(kDynamic), critical_section_init_phase_(0), critical_section_(new CRITICAL_SECTION) { ::InitializeCriticalSection(critical_section_); } Mutex::~Mutex() { // Static mutexes are leaked intentionally. It is not thread-safe to try // to clean them up. if (type_ == kDynamic) { ::DeleteCriticalSection(critical_section_); delete critical_section_; critical_section_ = nullptr; } } void Mutex::Lock() { ThreadSafeLazyInit(); ::EnterCriticalSection(critical_section_); owner_thread_id_ = ::GetCurrentThreadId(); } void Mutex::Unlock() { ThreadSafeLazyInit(); // We don't protect writing to owner_thread_id_ here, as it's the // caller's responsibility to ensure that the current thread holds the // mutex when this is called. owner_thread_id_ = 0; ::LeaveCriticalSection(critical_section_); } // Does nothing if the current thread holds the mutex. Otherwise, crashes // with high probability. void Mutex::AssertHeld() { ThreadSafeLazyInit(); GTEST_CHECK_(owner_thread_id_ == ::GetCurrentThreadId()) << "The current thread is not holding the mutex @" << this; } namespace { #ifdef _MSC_VER // Use the RAII idiom to flag mem allocs that are intentionally never // deallocated. The motivation is to silence the false positive mem leaks // that are reported by the debug version of MS's CRT which can only detect // if an alloc is missing a matching deallocation. 
// Example: // MemoryIsNotDeallocated memory_is_not_deallocated; // critical_section_ = new CRITICAL_SECTION; // class MemoryIsNotDeallocated { public: MemoryIsNotDeallocated() : old_crtdbg_flag_(0) { old_crtdbg_flag_ = _CrtSetDbgFlag(_CRTDBG_REPORT_FLAG); // Set heap allocation block type to _IGNORE_BLOCK so that MS debug CRT // doesn't report mem leak if there's no matching deallocation. _CrtSetDbgFlag(old_crtdbg_flag_ & ~_CRTDBG_ALLOC_MEM_DF); } ~MemoryIsNotDeallocated() { // Restore the original _CRTDBG_ALLOC_MEM_DF flag _CrtSetDbgFlag(old_crtdbg_flag_); } private: int old_crtdbg_flag_; GTEST_DISALLOW_COPY_AND_ASSIGN_(MemoryIsNotDeallocated); }; #endif // _MSC_VER } // namespace // Initializes owner_thread_id_ and critical_section_ in static mutexes. void Mutex::ThreadSafeLazyInit() { // Dynamic mutexes are initialized in the constructor. if (type_ == kStatic) { switch ( ::InterlockedCompareExchange(&critical_section_init_phase_, 1L, 0L)) { case 0: // If critical_section_init_phase_ was 0 before the exchange, we // are the first to test it and need to perform the initialization. owner_thread_id_ = 0; { // Use RAII to flag that following mem alloc is never deallocated. #ifdef _MSC_VER MemoryIsNotDeallocated memory_is_not_deallocated; #endif // _MSC_VER critical_section_ = new CRITICAL_SECTION; } ::InitializeCriticalSection(critical_section_); // Updates the critical_section_init_phase_ to 2 to signal // initialization complete. GTEST_CHECK_(::InterlockedCompareExchange( &critical_section_init_phase_, 2L, 1L) == 1L); break; case 1: // Somebody else is already initializing the mutex; spin until they // are done. while (::InterlockedCompareExchange(&critical_section_init_phase_, 2L, 2L) != 2L) { // Possibly yields the rest of the thread's time slice to other // threads. ::Sleep(0); } break; case 2: break; // The mutex is already initialized and ready for use. 
default: GTEST_CHECK_(false) << "Unexpected value of critical_section_init_phase_ " << "while initializing a static mutex."; } } } namespace { class ThreadWithParamSupport : public ThreadWithParamBase { public: static HANDLE CreateThread(Runnable* runnable, Notification* thread_can_start) { ThreadMainParam* param = new ThreadMainParam(runnable, thread_can_start); DWORD thread_id; HANDLE thread_handle = ::CreateThread( nullptr, // Default security. 0, // Default stack size. &ThreadWithParamSupport::ThreadMain, param, // Parameter to ThreadMainStatic 0x0, // Default creation flags. &thread_id); // Need a valid pointer for the call to work under Win98. GTEST_CHECK_(thread_handle != nullptr) << "CreateThread failed with error " << ::GetLastError() << "."; if (thread_handle == nullptr) { delete param; } return thread_handle; } private: struct ThreadMainParam { ThreadMainParam(Runnable* runnable, Notification* thread_can_start) : runnable_(runnable), thread_can_start_(thread_can_start) { } std::unique_ptr runnable_; // Does not own. Notification* thread_can_start_; }; static DWORD WINAPI ThreadMain(void* ptr) { // Transfers ownership. std::unique_ptr param(static_cast(ptr)); if (param->thread_can_start_ != nullptr) param->thread_can_start_->WaitForNotification(); param->runnable_->Run(); return 0; } // Prohibit instantiation. ThreadWithParamSupport(); GTEST_DISALLOW_COPY_AND_ASSIGN_(ThreadWithParamSupport); }; } // namespace ThreadWithParamBase::ThreadWithParamBase(Runnable *runnable, Notification* thread_can_start) : thread_(ThreadWithParamSupport::CreateThread(runnable, thread_can_start)) { } ThreadWithParamBase::~ThreadWithParamBase() { Join(); } void ThreadWithParamBase::Join() { GTEST_CHECK_(::WaitForSingleObject(thread_.Get(), INFINITE) == WAIT_OBJECT_0) << "Failed to join the thread with error " << ::GetLastError() << "."; } // Maps a thread to a set of ThreadIdToThreadLocals that have values // instantiated on that thread and notifies them when the thread exits. 
A // ThreadLocal instance is expected to persist until all threads it has // values on have terminated. class ThreadLocalRegistryImpl { public: // Registers thread_local_instance as having value on the current thread. // Returns a value that can be used to identify the thread from other threads. static ThreadLocalValueHolderBase* GetValueOnCurrentThread( const ThreadLocalBase* thread_local_instance) { DWORD current_thread = ::GetCurrentThreadId(); MutexLock lock(&mutex_); ThreadIdToThreadLocals* const thread_to_thread_locals = GetThreadLocalsMapLocked(); ThreadIdToThreadLocals::iterator thread_local_pos = thread_to_thread_locals->find(current_thread); if (thread_local_pos == thread_to_thread_locals->end()) { thread_local_pos = thread_to_thread_locals->insert( std::make_pair(current_thread, ThreadLocalValues())).first; StartWatcherThreadFor(current_thread); } ThreadLocalValues& thread_local_values = thread_local_pos->second; ThreadLocalValues::iterator value_pos = thread_local_values.find(thread_local_instance); if (value_pos == thread_local_values.end()) { value_pos = thread_local_values .insert(std::make_pair( thread_local_instance, std::shared_ptr( thread_local_instance->NewValueForCurrentThread()))) .first; } return value_pos->second.get(); } static void OnThreadLocalDestroyed( const ThreadLocalBase* thread_local_instance) { std::vector > value_holders; // Clean up the ThreadLocalValues data structure while holding the lock, but // defer the destruction of the ThreadLocalValueHolderBases. 
{ MutexLock lock(&mutex_); ThreadIdToThreadLocals* const thread_to_thread_locals = GetThreadLocalsMapLocked(); for (ThreadIdToThreadLocals::iterator it = thread_to_thread_locals->begin(); it != thread_to_thread_locals->end(); ++it) { ThreadLocalValues& thread_local_values = it->second; ThreadLocalValues::iterator value_pos = thread_local_values.find(thread_local_instance); if (value_pos != thread_local_values.end()) { value_holders.push_back(value_pos->second); thread_local_values.erase(value_pos); // This 'if' can only be successful at most once, so theoretically we // could break out of the loop here, but we don't bother doing so. } } } // Outside the lock, let the destructor for 'value_holders' deallocate the // ThreadLocalValueHolderBases. } static void OnThreadExit(DWORD thread_id) { GTEST_CHECK_(thread_id != 0) << ::GetLastError(); std::vector > value_holders; // Clean up the ThreadIdToThreadLocals data structure while holding the // lock, but defer the destruction of the ThreadLocalValueHolderBases. { MutexLock lock(&mutex_); ThreadIdToThreadLocals* const thread_to_thread_locals = GetThreadLocalsMapLocked(); ThreadIdToThreadLocals::iterator thread_local_pos = thread_to_thread_locals->find(thread_id); if (thread_local_pos != thread_to_thread_locals->end()) { ThreadLocalValues& thread_local_values = thread_local_pos->second; for (ThreadLocalValues::iterator value_pos = thread_local_values.begin(); value_pos != thread_local_values.end(); ++value_pos) { value_holders.push_back(value_pos->second); } thread_to_thread_locals->erase(thread_local_pos); } } // Outside the lock, let the destructor for 'value_holders' deallocate the // ThreadLocalValueHolderBases. } private: // In a particular thread, maps a ThreadLocal object to its value. typedef std::map > ThreadLocalValues; // Stores all ThreadIdToThreadLocals having values in a thread, indexed by // thread's ID. 
typedef std::map ThreadIdToThreadLocals; // Holds the thread id and thread handle that we pass from // StartWatcherThreadFor to WatcherThreadFunc. typedef std::pair ThreadIdAndHandle; static void StartWatcherThreadFor(DWORD thread_id) { // The returned handle will be kept in thread_map and closed by // watcher_thread in WatcherThreadFunc. HANDLE thread = ::OpenThread(SYNCHRONIZE | THREAD_QUERY_INFORMATION, FALSE, thread_id); GTEST_CHECK_(thread != nullptr); // We need to pass a valid thread ID pointer into CreateThread for it // to work correctly under Win98. DWORD watcher_thread_id; HANDLE watcher_thread = ::CreateThread( nullptr, // Default security. 0, // Default stack size &ThreadLocalRegistryImpl::WatcherThreadFunc, reinterpret_cast(new ThreadIdAndHandle(thread_id, thread)), CREATE_SUSPENDED, &watcher_thread_id); GTEST_CHECK_(watcher_thread != nullptr); // Give the watcher thread the same priority as ours to avoid being // blocked by it. ::SetThreadPriority(watcher_thread, ::GetThreadPriority(::GetCurrentThread())); ::ResumeThread(watcher_thread); ::CloseHandle(watcher_thread); } // Monitors exit from a given thread and notifies those // ThreadIdToThreadLocals about thread termination. static DWORD WINAPI WatcherThreadFunc(LPVOID param) { const ThreadIdAndHandle* tah = reinterpret_cast(param); GTEST_CHECK_( ::WaitForSingleObject(tah->second, INFINITE) == WAIT_OBJECT_0); OnThreadExit(tah->first); ::CloseHandle(tah->second); delete tah; return 0; } // Returns map of thread local instances. static ThreadIdToThreadLocals* GetThreadLocalsMapLocked() { mutex_.AssertHeld(); #ifdef _MSC_VER MemoryIsNotDeallocated memory_is_not_deallocated; #endif // _MSC_VER static ThreadIdToThreadLocals* map = new ThreadIdToThreadLocals(); return map; } // Protects access to GetThreadLocalsMapLocked() and its return value. static Mutex mutex_; // Protects access to GetThreadMapLocked() and its return value. 
static Mutex thread_map_mutex_; }; Mutex ThreadLocalRegistryImpl::mutex_(Mutex::kStaticMutex); Mutex ThreadLocalRegistryImpl::thread_map_mutex_(Mutex::kStaticMutex); ThreadLocalValueHolderBase* ThreadLocalRegistry::GetValueOnCurrentThread( const ThreadLocalBase* thread_local_instance) { return ThreadLocalRegistryImpl::GetValueOnCurrentThread( thread_local_instance); } void ThreadLocalRegistry::OnThreadLocalDestroyed( const ThreadLocalBase* thread_local_instance) { ThreadLocalRegistryImpl::OnThreadLocalDestroyed(thread_local_instance); } #endif // GTEST_IS_THREADSAFE && GTEST_OS_WINDOWS #if GTEST_USES_POSIX_RE // Implements RE. Currently only needed for death tests. RE::~RE() { if (is_valid_) { // regfree'ing an invalid regex might crash because the content // of the regex is undefined. Since the regex's are essentially // the same, one cannot be valid (or invalid) without the other // being so too. regfree(&partial_regex_); regfree(&full_regex_); } free(const_cast(pattern_)); } // Returns true if and only if regular expression re matches the entire str. bool RE::FullMatch(const char* str, const RE& re) { if (!re.is_valid_) return false; regmatch_t match; return regexec(&re.full_regex_, str, 1, &match, 0) == 0; } // Returns true if and only if regular expression re matches a substring of // str (including str itself). bool RE::PartialMatch(const char* str, const RE& re) { if (!re.is_valid_) return false; regmatch_t match; return regexec(&re.partial_regex_, str, 1, &match, 0) == 0; } // Initializes an RE from its string representation. void RE::Init(const char* regex) { pattern_ = posix::StrDup(regex); // Reserves enough bytes to hold the regular expression used for a // full match. const size_t full_regex_len = strlen(regex) + 10; char* const full_pattern = new char[full_regex_len]; snprintf(full_pattern, full_regex_len, "^(%s)$", regex); is_valid_ = regcomp(&full_regex_, full_pattern, REG_EXTENDED) == 0; // We want to call regcomp(&partial_regex_, ...) 
even if the // previous expression returns false. Otherwise partial_regex_ may // not be properly initialized can may cause trouble when it's // freed. // // Some implementation of POSIX regex (e.g. on at least some // versions of Cygwin) doesn't accept the empty string as a valid // regex. We change it to an equivalent form "()" to be safe. if (is_valid_) { const char* const partial_regex = (*regex == '\0') ? "()" : regex; is_valid_ = regcomp(&partial_regex_, partial_regex, REG_EXTENDED) == 0; } EXPECT_TRUE(is_valid_) << "Regular expression \"" << regex << "\" is not a valid POSIX Extended regular expression."; delete[] full_pattern; } #elif GTEST_USES_SIMPLE_RE // Returns true if and only if ch appears anywhere in str (excluding the // terminating '\0' character). bool IsInSet(char ch, const char* str) { return ch != '\0' && strchr(str, ch) != nullptr; } // Returns true if and only if ch belongs to the given classification. // Unlike similar functions in , these aren't affected by the // current locale. bool IsAsciiDigit(char ch) { return '0' <= ch && ch <= '9'; } bool IsAsciiPunct(char ch) { return IsInSet(ch, "^-!\"#$%&'()*+,./:;<=>?@[\\]_`{|}~"); } bool IsRepeat(char ch) { return IsInSet(ch, "?*+"); } bool IsAsciiWhiteSpace(char ch) { return IsInSet(ch, " \f\n\r\t\v"); } bool IsAsciiWordChar(char ch) { return ('a' <= ch && ch <= 'z') || ('A' <= ch && ch <= 'Z') || ('0' <= ch && ch <= '9') || ch == '_'; } // Returns true if and only if "\\c" is a supported escape sequence. bool IsValidEscape(char c) { return (IsAsciiPunct(c) || IsInSet(c, "dDfnrsStvwW")); } // Returns true if and only if the given atom (specified by escaped and // pattern) matches ch. The result is undefined if the atom is invalid. bool AtomMatchesChar(bool escaped, char pattern_char, char ch) { if (escaped) { // "\\p" where p is pattern_char. 
switch (pattern_char) { case 'd': return IsAsciiDigit(ch); case 'D': return !IsAsciiDigit(ch); case 'f': return ch == '\f'; case 'n': return ch == '\n'; case 'r': return ch == '\r'; case 's': return IsAsciiWhiteSpace(ch); case 'S': return !IsAsciiWhiteSpace(ch); case 't': return ch == '\t'; case 'v': return ch == '\v'; case 'w': return IsAsciiWordChar(ch); case 'W': return !IsAsciiWordChar(ch); } return IsAsciiPunct(pattern_char) && pattern_char == ch; } return (pattern_char == '.' && ch != '\n') || pattern_char == ch; } // Helper function used by ValidateRegex() to format error messages. static std::string FormatRegexSyntaxError(const char* regex, int index) { return (Message() << "Syntax error at index " << index << " in simple regular expression \"" << regex << "\": ").GetString(); } // Generates non-fatal failures and returns false if regex is invalid; // otherwise returns true. bool ValidateRegex(const char* regex) { if (regex == nullptr) { ADD_FAILURE() << "NULL is not a valid simple regular expression."; return false; } bool is_valid = true; // True if and only if ?, *, or + can follow the previous atom. bool prev_repeatable = false; for (int i = 0; regex[i]; i++) { if (regex[i] == '\\') { // An escape sequence i++; if (regex[i] == '\0') { ADD_FAILURE() << FormatRegexSyntaxError(regex, i - 1) << "'\\' cannot appear at the end."; return false; } if (!IsValidEscape(regex[i])) { ADD_FAILURE() << FormatRegexSyntaxError(regex, i - 1) << "invalid escape sequence \"\\" << regex[i] << "\"."; is_valid = false; } prev_repeatable = true; } else { // Not an escape sequence. 
const char ch = regex[i]; if (ch == '^' && i > 0) { ADD_FAILURE() << FormatRegexSyntaxError(regex, i) << "'^' can only appear at the beginning."; is_valid = false; } else if (ch == '$' && regex[i + 1] != '\0') { ADD_FAILURE() << FormatRegexSyntaxError(regex, i) << "'$' can only appear at the end."; is_valid = false; } else if (IsInSet(ch, "()[]{}|")) { ADD_FAILURE() << FormatRegexSyntaxError(regex, i) << "'" << ch << "' is unsupported."; is_valid = false; } else if (IsRepeat(ch) && !prev_repeatable) { ADD_FAILURE() << FormatRegexSyntaxError(regex, i) << "'" << ch << "' can only follow a repeatable token."; is_valid = false; } prev_repeatable = !IsInSet(ch, "^$?*+"); } } return is_valid; } // Matches a repeated regex atom followed by a valid simple regular // expression. The regex atom is defined as c if escaped is false, // or \c otherwise. repeat is the repetition meta character (?, *, // or +). The behavior is undefined if str contains too many // characters to be indexable by size_t, in which case the test will // probably time out anyway. We are fine with this limitation as // std::string has it too. bool MatchRepetitionAndRegexAtHead( bool escaped, char c, char repeat, const char* regex, const char* str) { const size_t min_count = (repeat == '+') ? 1 : 0; const size_t max_count = (repeat == '?') ? 1 : static_cast(-1) - 1; // We cannot call numeric_limits::max() as it conflicts with the // max() macro on Windows. for (size_t i = 0; i <= max_count; ++i) { // We know that the atom matches each of the first i characters in str. if (i >= min_count && MatchRegexAtHead(regex, str + i)) { // We have enough matches at the head, and the tail matches too. // Since we only care about *whether* the pattern matches str // (as opposed to *how* it matches), there is no need to find a // greedy match. return true; } if (str[i] == '\0' || !AtomMatchesChar(escaped, c, str[i])) return false; } return false; } // Returns true if and only if regex matches a prefix of str. 
regex must // be a valid simple regular expression and not start with "^", or the // result is undefined. bool MatchRegexAtHead(const char* regex, const char* str) { if (*regex == '\0') // An empty regex matches a prefix of anything. return true; // "$" only matches the end of a string. Note that regex being // valid guarantees that there's nothing after "$" in it. if (*regex == '$') return *str == '\0'; // Is the first thing in regex an escape sequence? const bool escaped = *regex == '\\'; if (escaped) ++regex; if (IsRepeat(regex[1])) { // MatchRepetitionAndRegexAtHead() calls MatchRegexAtHead(), so // here's an indirect recursion. It terminates as the regex gets // shorter in each recursion. return MatchRepetitionAndRegexAtHead( escaped, regex[0], regex[1], regex + 2, str); } else { // regex isn't empty, isn't "$", and doesn't start with a // repetition. We match the first atom of regex with the first // character of str and recurse. return (*str != '\0') && AtomMatchesChar(escaped, *regex, *str) && MatchRegexAtHead(regex + 1, str + 1); } } // Returns true if and only if regex matches any substring of str. regex must // be a valid simple regular expression, or the result is undefined. // // The algorithm is recursive, but the recursion depth doesn't exceed // the regex length, so we won't need to worry about running out of // stack space normally. In rare cases the time complexity can be // exponential with respect to the regex length + the string length, // but usually it's must faster (often close to linear). bool MatchRegexAnywhere(const char* regex, const char* str) { if (regex == nullptr || str == nullptr) return false; if (*regex == '^') return MatchRegexAtHead(regex + 1, str); // A successful match can be anywhere in str. do { if (MatchRegexAtHead(regex, str)) return true; } while (*str++ != '\0'); return false; } // Implements the RE class. 
RE::~RE() { free(const_cast(pattern_)); free(const_cast(full_pattern_)); } // Returns true if and only if regular expression re matches the entire str. bool RE::FullMatch(const char* str, const RE& re) { return re.is_valid_ && MatchRegexAnywhere(re.full_pattern_, str); } // Returns true if and only if regular expression re matches a substring of // str (including str itself). bool RE::PartialMatch(const char* str, const RE& re) { return re.is_valid_ && MatchRegexAnywhere(re.pattern_, str); } // Initializes an RE from its string representation. void RE::Init(const char* regex) { pattern_ = full_pattern_ = nullptr; if (regex != nullptr) { pattern_ = posix::StrDup(regex); } is_valid_ = ValidateRegex(regex); if (!is_valid_) { // No need to calculate the full pattern when the regex is invalid. return; } const size_t len = strlen(regex); // Reserves enough bytes to hold the regular expression used for a // full match: we need space to prepend a '^', append a '$', and // terminate the string with '\0'. char* buffer = static_cast(malloc(len + 3)); full_pattern_ = buffer; if (*regex != '^') *buffer++ = '^'; // Makes sure full_pattern_ starts with '^'. // We don't use snprintf or strncpy, as they trigger a warning when // compiled with VC++ 8.0. memcpy(buffer, regex, len); buffer += len; if (len == 0 || regex[len - 1] != '$') *buffer++ = '$'; // Makes sure full_pattern_ ends with '$'. *buffer = '\0'; } #endif // GTEST_USES_POSIX_RE const char kUnknownFile[] = "unknown file"; // Formats a source file path and a line number as they would appear // in an error message from the compiler used to compile this code. GTEST_API_ ::std::string FormatFileLocation(const char* file, int line) { const std::string file_name(file == nullptr ? 
kUnknownFile : file); if (line < 0) { return file_name + ":"; } #ifdef _MSC_VER return file_name + "(" + StreamableToString(line) + "):"; #else return file_name + ":" + StreamableToString(line) + ":"; #endif // _MSC_VER } // Formats a file location for compiler-independent XML output. // Although this function is not platform dependent, we put it next to // FormatFileLocation in order to contrast the two functions. // Note that FormatCompilerIndependentFileLocation() does NOT append colon // to the file location it produces, unlike FormatFileLocation(). GTEST_API_ ::std::string FormatCompilerIndependentFileLocation( const char* file, int line) { const std::string file_name(file == nullptr ? kUnknownFile : file); if (line < 0) return file_name; else return file_name + ":" + StreamableToString(line); } GTestLog::GTestLog(GTestLogSeverity severity, const char* file, int line) : severity_(severity) { const char* const marker = severity == GTEST_INFO ? "[ INFO ]" : severity == GTEST_WARNING ? "[WARNING]" : severity == GTEST_ERROR ? "[ ERROR ]" : "[ FATAL ]"; GetStream() << ::std::endl << marker << " " << FormatFileLocation(file, line).c_str() << ": "; } // Flushes the buffers and, if severity is GTEST_FATAL, aborts the program. GTestLog::~GTestLog() { GetStream() << ::std::endl; if (severity_ == GTEST_FATAL) { fflush(stderr); posix::Abort(); } } // Disable Microsoft deprecation warnings for POSIX functions called from // this class (creat, dup, dup2, and close) GTEST_DISABLE_MSC_DEPRECATED_PUSH_() #if GTEST_HAS_STREAM_REDIRECTION // Object that captures an output stream (stdout/stderr). class CapturedStream { public: // The ctor redirects the stream to a temporary file. 
explicit CapturedStream(int fd) : fd_(fd), uncaptured_fd_(dup(fd)) { # if GTEST_OS_WINDOWS char temp_dir_path[MAX_PATH + 1] = { '\0' }; // NOLINT char temp_file_path[MAX_PATH + 1] = { '\0' }; // NOLINT ::GetTempPathA(sizeof(temp_dir_path), temp_dir_path); const UINT success = ::GetTempFileNameA(temp_dir_path, "gtest_redir", 0, // Generate unique file name. temp_file_path); GTEST_CHECK_(success != 0) << "Unable to create a temporary file in " << temp_dir_path; const int captured_fd = creat(temp_file_path, _S_IREAD | _S_IWRITE); GTEST_CHECK_(captured_fd != -1) << "Unable to open temporary file " << temp_file_path; filename_ = temp_file_path; # else // There's no guarantee that a test has write access to the current // directory, so we create the temporary file in the /tmp directory // instead. We use /tmp on most systems, and /sdcard on Android. // That's because Android doesn't have /tmp. # if GTEST_OS_LINUX_ANDROID // Note: Android applications are expected to call the framework's // Context.getExternalStorageDirectory() method through JNI to get // the location of the world-writable SD Card directory. However, // this requires a Context handle, which cannot be retrieved // globally from native code. Doing so also precludes running the // code as part of a regular standalone executable, which doesn't // run in a Dalvik process (e.g. when running it through 'adb shell'). // // The location /data/local/tmp is directly accessible from native code. // '/sdcard' and other variants cannot be relied on, as they are not // guaranteed to be mounted, or may have a delay in mounting. 
char name_template[] = "/data/local/tmp/gtest_captured_stream.XXXXXX"; # else char name_template[] = "/tmp/captured_stream.XXXXXX"; # endif // GTEST_OS_LINUX_ANDROID const int captured_fd = mkstemp(name_template); if (captured_fd == -1) { GTEST_LOG_(WARNING) << "Failed to create tmp file " << name_template << " for test; does the test have access to the /tmp directory?"; } filename_ = name_template; # endif // GTEST_OS_WINDOWS fflush(nullptr); dup2(captured_fd, fd_); close(captured_fd); } ~CapturedStream() { remove(filename_.c_str()); } std::string GetCapturedString() { if (uncaptured_fd_ != -1) { // Restores the original stream. fflush(nullptr); dup2(uncaptured_fd_, fd_); close(uncaptured_fd_); uncaptured_fd_ = -1; } FILE* const file = posix::FOpen(filename_.c_str(), "r"); if (file == nullptr) { GTEST_LOG_(FATAL) << "Failed to open tmp file " << filename_ << " for capturing stream."; } const std::string content = ReadEntireFile(file); posix::FClose(file); return content; } private: const int fd_; // A stream to capture. int uncaptured_fd_; // Name of the temporary file holding the stderr output. ::std::string filename_; GTEST_DISALLOW_COPY_AND_ASSIGN_(CapturedStream); }; GTEST_DISABLE_MSC_DEPRECATED_POP_() static CapturedStream* g_captured_stderr = nullptr; static CapturedStream* g_captured_stdout = nullptr; // Starts capturing an output stream (stdout/stderr). static void CaptureStream(int fd, const char* stream_name, CapturedStream** stream) { if (*stream != nullptr) { GTEST_LOG_(FATAL) << "Only one " << stream_name << " capturer can exist at a time."; } *stream = new CapturedStream(fd); } // Stops capturing the output stream and returns the captured string. static std::string GetCapturedStream(CapturedStream** captured_stream) { const std::string content = (*captured_stream)->GetCapturedString(); delete *captured_stream; *captured_stream = nullptr; return content; } // Starts capturing stdout. 
void CaptureStdout() { CaptureStream(kStdOutFileno, "stdout", &g_captured_stdout); } // Starts capturing stderr. void CaptureStderr() { CaptureStream(kStdErrFileno, "stderr", &g_captured_stderr); } // Stops capturing stdout and returns the captured string. std::string GetCapturedStdout() { return GetCapturedStream(&g_captured_stdout); } // Stops capturing stderr and returns the captured string. std::string GetCapturedStderr() { return GetCapturedStream(&g_captured_stderr); } #endif // GTEST_HAS_STREAM_REDIRECTION size_t GetFileSize(FILE* file) { fseek(file, 0, SEEK_END); return static_cast(ftell(file)); } std::string ReadEntireFile(FILE* file) { const size_t file_size = GetFileSize(file); char* const buffer = new char[file_size]; size_t bytes_last_read = 0; // # of bytes read in the last fread() size_t bytes_read = 0; // # of bytes read so far fseek(file, 0, SEEK_SET); // Keeps reading the file until we cannot read further or the // pre-determined file size is reached. do { bytes_last_read = fread(buffer+bytes_read, 1, file_size-bytes_read, file); bytes_read += bytes_last_read; } while (bytes_last_read > 0 && bytes_read < file_size); const std::string content(buffer, bytes_read); delete[] buffer; return content; } #if GTEST_HAS_DEATH_TEST static const std::vector* g_injected_test_argvs = nullptr; // Owned. 
std::vector GetInjectableArgvs() { if (g_injected_test_argvs != nullptr) { return *g_injected_test_argvs; } return GetArgvs(); } void SetInjectableArgvs(const std::vector* new_argvs) { if (g_injected_test_argvs != new_argvs) delete g_injected_test_argvs; g_injected_test_argvs = new_argvs; } void SetInjectableArgvs(const std::vector& new_argvs) { SetInjectableArgvs( new std::vector(new_argvs.begin(), new_argvs.end())); } void ClearInjectableArgvs() { delete g_injected_test_argvs; g_injected_test_argvs = nullptr; } #endif // GTEST_HAS_DEATH_TEST #if GTEST_OS_WINDOWS_MOBILE namespace posix { void Abort() { DebugBreak(); TerminateProcess(GetCurrentProcess(), 1); } } // namespace posix #endif // GTEST_OS_WINDOWS_MOBILE // Returns the name of the environment variable corresponding to the // given flag. For example, FlagToEnvVar("foo") will return // "GTEST_FOO" in the open-source version. static std::string FlagToEnvVar(const char* flag) { const std::string full_flag = (Message() << GTEST_FLAG_PREFIX_ << flag).GetString(); Message env_var; for (size_t i = 0; i != full_flag.length(); i++) { env_var << ToUpper(full_flag.c_str()[i]); } return env_var.GetString(); } // Parses 'str' for a 32-bit signed integer. If successful, writes // the result to *value and returns true; otherwise leaves *value // unchanged and returns false. bool ParseInt32(const Message& src_text, const char* str, Int32* value) { // Parses the environment variable as a decimal integer. char* end = nullptr; const long long_value = strtol(str, &end, 10); // NOLINT // Has strtol() consumed all characters in the string? if (*end != '\0') { // No - an invalid character was encountered. Message msg; msg << "WARNING: " << src_text << " is expected to be a 32-bit integer, but actually" << " has value \"" << str << "\".\n"; printf("%s", msg.GetString().c_str()); fflush(stdout); return false; } // Is the parsed value in the range of an Int32? 
const Int32 result = static_cast(long_value); if (long_value == LONG_MAX || long_value == LONG_MIN || // The parsed value overflows as a long. (strtol() returns // LONG_MAX or LONG_MIN when the input overflows.) result != long_value // The parsed value overflows as an Int32. ) { Message msg; msg << "WARNING: " << src_text << " is expected to be a 32-bit integer, but actually" << " has value " << str << ", which overflows.\n"; printf("%s", msg.GetString().c_str()); fflush(stdout); return false; } *value = result; return true; } // Reads and returns the Boolean environment variable corresponding to // the given flag; if it's not set, returns default_value. // // The value is considered true if and only if it's not "0". bool BoolFromGTestEnv(const char* flag, bool default_value) { #if defined(GTEST_GET_BOOL_FROM_ENV_) return GTEST_GET_BOOL_FROM_ENV_(flag, default_value); #else const std::string env_var = FlagToEnvVar(flag); const char* const string_value = posix::GetEnv(env_var.c_str()); return string_value == nullptr ? default_value : strcmp(string_value, "0") != 0; #endif // defined(GTEST_GET_BOOL_FROM_ENV_) } // Reads and returns a 32-bit integer stored in the environment // variable corresponding to the given flag; if it isn't set or // doesn't represent a valid 32-bit integer, returns default_value. Int32 Int32FromGTestEnv(const char* flag, Int32 default_value) { #if defined(GTEST_GET_INT32_FROM_ENV_) return GTEST_GET_INT32_FROM_ENV_(flag, default_value); #else const std::string env_var = FlagToEnvVar(flag); const char* const string_value = posix::GetEnv(env_var.c_str()); if (string_value == nullptr) { // The environment variable is not set. 
return default_value; } Int32 result = default_value; if (!ParseInt32(Message() << "Environment variable " << env_var, string_value, &result)) { printf("The default value %s is used.\n", (Message() << default_value).GetString().c_str()); fflush(stdout); return default_value; } return result; #endif // defined(GTEST_GET_INT32_FROM_ENV_) } // As a special case for the 'output' flag, if GTEST_OUTPUT is not // set, we look for XML_OUTPUT_FILE, which is set by the Bazel build // system. The value of XML_OUTPUT_FILE is a filename without the // "xml:" prefix of GTEST_OUTPUT. // Note that this is meant to be called at the call site so it does // not check that the flag is 'output' // In essence this checks an env variable called XML_OUTPUT_FILE // and if it is set we prepend "xml:" to its value, if it not set we return "" std::string OutputFlagAlsoCheckEnvVar(){ std::string default_value_for_output_flag = ""; const char* xml_output_file_env = posix::GetEnv("XML_OUTPUT_FILE"); if (nullptr != xml_output_file_env) { default_value_for_output_flag = std::string("xml:") + xml_output_file_env; } return default_value_for_output_flag; } // Reads and returns the string environment variable corresponding to // the given flag; if it's not set, returns default_value. const char* StringFromGTestEnv(const char* flag, const char* default_value) { #if defined(GTEST_GET_STRING_FROM_ENV_) return GTEST_GET_STRING_FROM_ENV_(flag, default_value); #else const std::string env_var = FlagToEnvVar(flag); const char* const value = posix::GetEnv(env_var.c_str()); return value == nullptr ? default_value : value; #endif // defined(GTEST_GET_STRING_FROM_ENV_) } } // namespace internal } // namespace testing // Copyright 2007, Google Inc. // All rights reserved. 
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // Google Test - The Google C++ Testing and Mocking Framework // // This file implements a universal value printer that can print a // value of any type T: // // void ::testing::internal::UniversalPrinter::Print(value, ostream_ptr); // // It uses the << operator when possible, and prints the bytes in the // object otherwise. 
A user can override its behavior for a class // type Foo by defining either operator<<(::std::ostream&, const Foo&) // or void PrintTo(const Foo&, ::std::ostream*) in the namespace that // defines Foo. #include #include #include #include // NOLINT #include namespace testing { namespace { using ::std::ostream; // Prints a segment of bytes in the given object. GTEST_ATTRIBUTE_NO_SANITIZE_MEMORY_ GTEST_ATTRIBUTE_NO_SANITIZE_ADDRESS_ GTEST_ATTRIBUTE_NO_SANITIZE_HWADDRESS_ GTEST_ATTRIBUTE_NO_SANITIZE_THREAD_ void PrintByteSegmentInObjectTo(const unsigned char* obj_bytes, size_t start, size_t count, ostream* os) { char text[5] = ""; for (size_t i = 0; i != count; i++) { const size_t j = start + i; if (i != 0) { // Organizes the bytes into groups of 2 for easy parsing by // human. if ((j % 2) == 0) *os << ' '; else *os << '-'; } GTEST_SNPRINTF_(text, sizeof(text), "%02X", obj_bytes[j]); *os << text; } } // Prints the bytes in the given value to the given ostream. void PrintBytesInObjectToImpl(const unsigned char* obj_bytes, size_t count, ostream* os) { // Tells the user how big the object is. *os << count << "-byte object <"; const size_t kThreshold = 132; const size_t kChunkSize = 64; // If the object size is bigger than kThreshold, we'll have to omit // some details by printing only the first and the last kChunkSize // bytes. if (count < kThreshold) { PrintByteSegmentInObjectTo(obj_bytes, 0, count, os); } else { PrintByteSegmentInObjectTo(obj_bytes, 0, kChunkSize, os); *os << " ... "; // Rounds up to 2-byte boundary. const size_t resume_pos = (count - kChunkSize + 1)/2*2; PrintByteSegmentInObjectTo(obj_bytes, resume_pos, count - resume_pos, os); } *os << ">"; } } // namespace namespace internal2 { // Delegates to PrintBytesInObjectToImpl() to print the bytes in the // given object. 
The delegation simplifies the implementation, which // uses the << operator and thus is easier done outside of the // ::testing::internal namespace, which contains a << operator that // sometimes conflicts with the one in STL. void PrintBytesInObjectTo(const unsigned char* obj_bytes, size_t count, ostream* os) { PrintBytesInObjectToImpl(obj_bytes, count, os); } } // namespace internal2 namespace internal { // Depending on the value of a char (or wchar_t), we print it in one // of three formats: // - as is if it's a printable ASCII (e.g. 'a', '2', ' '), // - as a hexadecimal escape sequence (e.g. '\x7F'), or // - as a special escape sequence (e.g. '\r', '\n'). enum CharFormat { kAsIs, kHexEscape, kSpecialEscape }; // Returns true if c is a printable ASCII character. We test the // value of c directly instead of calling isprint(), which is buggy on // Windows Mobile. inline bool IsPrintableAscii(wchar_t c) { return 0x20 <= c && c <= 0x7E; } // Prints a wide or narrow char c as a character literal without the // quotes, escaping it when necessary; returns how c was formatted. // The template argument UnsignedChar is the unsigned version of Char, // which is the type of c. 
template static CharFormat PrintAsCharLiteralTo(Char c, ostream* os) { wchar_t w_c = static_cast(c); switch (w_c) { case L'\0': *os << "\\0"; break; case L'\'': *os << "\\'"; break; case L'\\': *os << "\\\\"; break; case L'\a': *os << "\\a"; break; case L'\b': *os << "\\b"; break; case L'\f': *os << "\\f"; break; case L'\n': *os << "\\n"; break; case L'\r': *os << "\\r"; break; case L'\t': *os << "\\t"; break; case L'\v': *os << "\\v"; break; default: if (IsPrintableAscii(w_c)) { *os << static_cast(c); return kAsIs; } else { ostream::fmtflags flags = os->flags(); *os << "\\x" << std::hex << std::uppercase << static_cast(static_cast(c)); os->flags(flags); return kHexEscape; } } return kSpecialEscape; } // Prints a wchar_t c as if it's part of a string literal, escaping it when // necessary; returns how c was formatted. static CharFormat PrintAsStringLiteralTo(wchar_t c, ostream* os) { switch (c) { case L'\'': *os << "'"; return kAsIs; case L'"': *os << "\\\""; return kSpecialEscape; default: return PrintAsCharLiteralTo(c, os); } } // Prints a char c as if it's part of a string literal, escaping it when // necessary; returns how c was formatted. static CharFormat PrintAsStringLiteralTo(char c, ostream* os) { return PrintAsStringLiteralTo( static_cast(static_cast(c)), os); } // Prints a wide or narrow character c and its code. '\0' is printed // as "'\\0'", other unprintable characters are also properly escaped // using the standard C++ escape sequence. The template argument // UnsignedChar is the unsigned version of Char, which is the type of c. template void PrintCharAndCodeTo(Char c, ostream* os) { // First, print c as a literal in the most readable form we can find. *os << ((sizeof(c) > 1) ? "L'" : "'"); const CharFormat format = PrintAsCharLiteralTo(c, os); *os << "'"; // To aid user debugging, we also print c's code in decimal, unless // it's 0 (in which case c was printed as '\\0', making the code // obvious). 
if (c == 0) return; *os << " (" << static_cast(c); // For more convenience, we print c's code again in hexadecimal, // unless c was already printed in the form '\x##' or the code is in // [1, 9]. if (format == kHexEscape || (1 <= c && c <= 9)) { // Do nothing. } else { *os << ", 0x" << String::FormatHexInt(static_cast(c)); } *os << ")"; } void PrintTo(unsigned char c, ::std::ostream* os) { PrintCharAndCodeTo(c, os); } void PrintTo(signed char c, ::std::ostream* os) { PrintCharAndCodeTo(c, os); } // Prints a wchar_t as a symbol if it is printable or as its internal // code otherwise and also as its code. L'\0' is printed as "L'\\0'". void PrintTo(wchar_t wc, ostream* os) { PrintCharAndCodeTo(wc, os); } // Prints the given array of characters to the ostream. CharType must be either // char or wchar_t. // The array starts at begin, the length is len, it may include '\0' characters // and may not be NUL-terminated. template GTEST_ATTRIBUTE_NO_SANITIZE_MEMORY_ GTEST_ATTRIBUTE_NO_SANITIZE_ADDRESS_ GTEST_ATTRIBUTE_NO_SANITIZE_HWADDRESS_ GTEST_ATTRIBUTE_NO_SANITIZE_THREAD_ static CharFormat PrintCharsAsStringTo( const CharType* begin, size_t len, ostream* os) { const char* const kQuoteBegin = sizeof(CharType) == 1 ? "\"" : "L\""; *os << kQuoteBegin; bool is_previous_hex = false; CharFormat print_format = kAsIs; for (size_t index = 0; index < len; ++index) { const CharType cur = begin[index]; if (is_previous_hex && IsXDigit(cur)) { // Previous character is of '\x..' form and this character can be // interpreted as another hexadecimal digit in its number. Break string to // disambiguate. *os << "\" " << kQuoteBegin; } is_previous_hex = PrintAsStringLiteralTo(cur, os) == kHexEscape; // Remember if any characters required hex escaping. if (is_previous_hex) { print_format = kHexEscape; } } *os << "\""; return print_format; } // Prints a (const) char/wchar_t array of 'len' elements, starting at address // 'begin'. CharType must be either char or wchar_t. 
template GTEST_ATTRIBUTE_NO_SANITIZE_MEMORY_ GTEST_ATTRIBUTE_NO_SANITIZE_ADDRESS_ GTEST_ATTRIBUTE_NO_SANITIZE_HWADDRESS_ GTEST_ATTRIBUTE_NO_SANITIZE_THREAD_ static void UniversalPrintCharArray( const CharType* begin, size_t len, ostream* os) { // The code // const char kFoo[] = "foo"; // generates an array of 4, not 3, elements, with the last one being '\0'. // // Therefore when printing a char array, we don't print the last element if // it's '\0', such that the output matches the string literal as it's // written in the source code. if (len > 0 && begin[len - 1] == '\0') { PrintCharsAsStringTo(begin, len - 1, os); return; } // If, however, the last element in the array is not '\0', e.g. // const char kFoo[] = { 'f', 'o', 'o' }; // we must print the entire array. We also print a message to indicate // that the array is not NUL-terminated. PrintCharsAsStringTo(begin, len, os); *os << " (no terminating NUL)"; } // Prints a (const) char array of 'len' elements, starting at address 'begin'. void UniversalPrintArray(const char* begin, size_t len, ostream* os) { UniversalPrintCharArray(begin, len, os); } // Prints a (const) wchar_t array of 'len' elements, starting at address // 'begin'. void UniversalPrintArray(const wchar_t* begin, size_t len, ostream* os) { UniversalPrintCharArray(begin, len, os); } // Prints the given C string to the ostream. void PrintTo(const char* s, ostream* os) { if (s == nullptr) { *os << "NULL"; } else { *os << ImplicitCast_(s) << " pointing to "; PrintCharsAsStringTo(s, strlen(s), os); } } // MSVC compiler can be configured to define whar_t as a typedef // of unsigned short. Defining an overload for const wchar_t* in that case // would cause pointers to unsigned shorts be printed as wide strings, // possibly accessing more memory than intended and causing invalid // memory accesses. MSVC defines _NATIVE_WCHAR_T_DEFINED symbol when // wchar_t is implemented as a native type. 
#if !defined(_MSC_VER) || defined(_NATIVE_WCHAR_T_DEFINED) // Prints the given wide C string to the ostream. void PrintTo(const wchar_t* s, ostream* os) { if (s == nullptr) { *os << "NULL"; } else { *os << ImplicitCast_(s) << " pointing to "; PrintCharsAsStringTo(s, wcslen(s), os); } } #endif // wchar_t is native namespace { bool ContainsUnprintableControlCodes(const char* str, size_t length) { const unsigned char *s = reinterpret_cast(str); for (size_t i = 0; i < length; i++) { unsigned char ch = *s++; if (std::iscntrl(ch)) { switch (ch) { case '\t': case '\n': case '\r': break; default: return true; } } } return false; } bool IsUTF8TrailByte(unsigned char t) { return 0x80 <= t && t<= 0xbf; } bool IsValidUTF8(const char* str, size_t length) { const unsigned char *s = reinterpret_cast(str); for (size_t i = 0; i < length;) { unsigned char lead = s[i++]; if (lead <= 0x7f) { continue; // single-byte character (ASCII) 0..7F } if (lead < 0xc2) { return false; // trail byte or non-shortest form } else if (lead <= 0xdf && (i + 1) <= length && IsUTF8TrailByte(s[i])) { ++i; // 2-byte character } else if (0xe0 <= lead && lead <= 0xef && (i + 2) <= length && IsUTF8TrailByte(s[i]) && IsUTF8TrailByte(s[i + 1]) && // check for non-shortest form and surrogate (lead != 0xe0 || s[i] >= 0xa0) && (lead != 0xed || s[i] < 0xa0)) { i += 2; // 3-byte character } else if (0xf0 <= lead && lead <= 0xf4 && (i + 3) <= length && IsUTF8TrailByte(s[i]) && IsUTF8TrailByte(s[i + 1]) && IsUTF8TrailByte(s[i + 2]) && // check for non-shortest form (lead != 0xf0 || s[i] >= 0x90) && (lead != 0xf4 || s[i] < 0x90)) { i += 3; // 4-byte character } else { return false; } } return true; } void ConditionalPrintAsText(const char* str, size_t length, ostream* os) { if (!ContainsUnprintableControlCodes(str, length) && IsValidUTF8(str, length)) { *os << "\n As Text: \"" << str << "\""; } } } // anonymous namespace void PrintStringTo(const ::std::string& s, ostream* os) { if (PrintCharsAsStringTo(s.data(), 
s.size(), os) == kHexEscape) { if (GTEST_FLAG(print_utf8)) { ConditionalPrintAsText(s.data(), s.size(), os); } } } #if GTEST_HAS_STD_WSTRING void PrintWideStringTo(const ::std::wstring& s, ostream* os) { PrintCharsAsStringTo(s.data(), s.size(), os); } #endif // GTEST_HAS_STD_WSTRING } // namespace internal } // namespace testing // Copyright 2008, Google Inc. // All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
// // The Google C++ Testing and Mocking Framework (Google Test) namespace testing { using internal::GetUnitTestImpl; // Gets the summary of the failure message by omitting the stack trace // in it. std::string TestPartResult::ExtractSummary(const char* message) { const char* const stack_trace = strstr(message, internal::kStackTraceMarker); return stack_trace == nullptr ? message : std::string(message, stack_trace); } // Prints a TestPartResult object. std::ostream& operator<<(std::ostream& os, const TestPartResult& result) { return os << result.file_name() << ":" << result.line_number() << ": " << (result.type() == TestPartResult::kSuccess ? "Success" : result.type() == TestPartResult::kSkip ? "Skipped" : result.type() == TestPartResult::kFatalFailure ? "Fatal failure" : "Non-fatal failure") << ":\n" << result.message() << std::endl; } // Appends a TestPartResult to the array. void TestPartResultArray::Append(const TestPartResult& result) { array_.push_back(result); } // Returns the TestPartResult at the given index (0-based). const TestPartResult& TestPartResultArray::GetTestPartResult(int index) const { if (index < 0 || index >= size()) { printf("\nInvalid index (%d) into TestPartResultArray.\n", index); internal::posix::Abort(); } return array_[static_cast(index)]; } // Returns the number of TestPartResult objects in the array. 
int TestPartResultArray::size() const { return static_cast(array_.size()); } namespace internal { HasNewFatalFailureHelper::HasNewFatalFailureHelper() : has_new_fatal_failure_(false), original_reporter_(GetUnitTestImpl()-> GetTestPartResultReporterForCurrentThread()) { GetUnitTestImpl()->SetTestPartResultReporterForCurrentThread(this); } HasNewFatalFailureHelper::~HasNewFatalFailureHelper() { GetUnitTestImpl()->SetTestPartResultReporterForCurrentThread( original_reporter_); } void HasNewFatalFailureHelper::ReportTestPartResult( const TestPartResult& result) { if (result.fatally_failed()) has_new_fatal_failure_ = true; original_reporter_->ReportTestPartResult(result); } } // namespace internal } // namespace testing // Copyright 2008 Google Inc. // All Rights Reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. namespace testing { namespace internal { #if GTEST_HAS_TYPED_TEST_P // Skips to the first non-space char in str. Returns an empty string if str // contains only whitespace characters. static const char* SkipSpaces(const char* str) { while (IsSpace(*str)) str++; return str; } static std::vector SplitIntoTestNames(const char* src) { std::vector name_vec; src = SkipSpaces(src); for (; src != nullptr; src = SkipComma(src)) { name_vec.push_back(StripTrailingSpaces(GetPrefixUntilComma(src))); } return name_vec; } // Verifies that registered_tests match the test names in // registered_tests_; returns registered_tests if successful, or // aborts the program otherwise. 
const char* TypedTestSuitePState::VerifyRegisteredTestNames( const char* file, int line, const char* registered_tests) { typedef RegisteredTestsMap::const_iterator RegisteredTestIter; registered_ = true; std::vector name_vec = SplitIntoTestNames(registered_tests); Message errors; std::set tests; for (std::vector::const_iterator name_it = name_vec.begin(); name_it != name_vec.end(); ++name_it) { const std::string& name = *name_it; if (tests.count(name) != 0) { errors << "Test " << name << " is listed more than once.\n"; continue; } bool found = false; for (RegisteredTestIter it = registered_tests_.begin(); it != registered_tests_.end(); ++it) { if (name == it->first) { found = true; break; } } if (found) { tests.insert(name); } else { errors << "No test named " << name << " can be found in this test suite.\n"; } } for (RegisteredTestIter it = registered_tests_.begin(); it != registered_tests_.end(); ++it) { if (tests.count(it->first) == 0) { errors << "You forgot to list test " << it->first << ".\n"; } } const std::string& errors_str = errors.GetString(); if (errors_str != "") { fprintf(stderr, "%s %s", FormatFileLocation(file, line).c_str(), errors_str.c_str()); fflush(stderr); posix::Abort(); } return registered_tests; } #endif // GTEST_HAS_TYPED_TEST_P } // namespace internal } // namespace testing openucx-ucc-ec0bc8a/test/gtest/common/main.cc0000664000175000017500000000112015133731560021524 0ustar alastairalastair/** * Copyright (c) 2001-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (C) UT-Battelle, LLC. 2014. ALL RIGHTS RESERVED. * Copyright (C) Huawei Technologies Co., Ltd. 2020. All rights reserved. * See file LICENSE for terms. 
*/ #ifdef HAVE_CONFIG_H # include "config.h" #endif #ifdef HAVE_CUDA #include #endif #include "test_ucc.h" int main(int argc, char **argv) { int ret; #ifdef HAVE_CUDA cudaSetDevice(0); #endif ::testing::InitGoogleTest(&argc, argv); ret = RUN_ALL_TESTS(); UccJob::cleanup(); return ret; } openucx-ucc-ec0bc8a/test/gtest/common/test_obj_size.cc0000664000175000017500000000102215133731560023444 0ustar alastairalastair/** * Copyright (c) 2001-2019, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (C) Huawei Technologies Co., Ltd. 2020. All rights reserved. * * See file LICENSE for terms. */ #ifdef HAVE_CONFIG_H # include "config.h" #endif #include extern "C" { #include } class test_obj_size : public ucc::test { }; UCC_TEST_F(test_obj_size, size) { /* lets try to keep it within 8 cache lines currently 480b */ EXPECT_LT(sizeof(ucc_coll_task_t), 64 * 8); } openucx-ucc-ec0bc8a/test/gtest/common/test_ucc.cc0000664000175000017500000005327115133731560022427 0ustar alastairalastair/** * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * See file LICENSE for terms. 
*/ #include "test_ucc.h" extern "C" { #include "core/ucc_team.h" #include "components/tl/ucc_tl.h" } constexpr ucc_lib_params_t UccProcess::default_lib_params; constexpr ucc_context_params_t UccProcess::default_ctx_params; constexpr int UccJob::staticTeamSizes[]; UccProcess::UccProcess(int _job_rank, const ucc_lib_params_t &lib_params, const ucc_context_params_t &_ctx_params) { ucc_lib_config_h lib_config; ucc_status_t status; std::stringstream err_msg; job_rank = _job_rank; ctx_params = _ctx_params; status = ucc_lib_config_read(NULL, NULL, &lib_config); if (status != UCC_OK) { err_msg << "ucc_lib_config_read failed"; goto exit_err; } status = ucc_init(&lib_params, lib_config, &lib_h); ucc_lib_config_release(lib_config); if (status != UCC_OK) { err_msg << "ucc_init failed"; goto exit_err; } return; exit_err: err_msg << ": "<< ucc_status_string(status) << " (" << status << ")"; throw std::runtime_error(err_msg.str()); } UccProcess::~UccProcess() { EXPECT_EQ(UCC_OK, ucc_context_destroy(ctx_h)); EXPECT_EQ(UCC_OK, ucc_finalize(lib_h)); if (ctx_params.mask & UCC_CONTEXT_PARAM_FIELD_MEM_PARAMS) { for (auto i = 0; i < UCC_TEST_N_MEM_SEGMENTS; i++) { ucc_free(onesided_buf[i]); } } } ucc_status_t UccTeam::allgather(void *src_buf, void *recv_buf, size_t size, void *coll_info, void **request) { UccTeam::allgather_coll_info_t *ci = (UccTeam::allgather_coll_info_t *)coll_info; int my_rank = ci->my_rank; ci->self->ag[my_rank].sbuf = src_buf; ci->self->ag[my_rank].rbuf = recv_buf; ci->self->ag[my_rank].len = size; ci->self->ag[my_rank].phase = UccTeam::AG_READY; *request = (void *)ci; return UCC_OK; } void UccTeam::test_allgather(size_t msglen) { int *sbufs[n_procs]; int *rbufs[n_procs]; size_t count = msglen/sizeof(int); std::vector cis; std::vector reqs; for (int i=0; iself->n_procs; switch (ci->self->ag[ci->my_rank].phase) { case UccTeam::AG_READY: for (int i = 0; i < n_procs; i++) { if ((ci->self->ag[i].phase == UccTeam::AG_INIT) || (ci->self->ag[i].phase == 
UccTeam::AG_COMPLETE)) { return UCC_INPROGRESS; } } for (int i = 0; i < n_procs; i++) { memcpy((void *)((ptrdiff_t)ci->self->ag[ci->my_rank].rbuf + i * ci->self->ag[i].len), ci->self->ag[i].sbuf, ci->self->ag[i].len); } ci->self->ag[ci->my_rank].phase = UccTeam::AG_COPY_DONE; ; ci->self->copy_complete_count++; break; case UccTeam::AG_COPY_DONE: if (ci->my_rank == 0 && ci->self->copy_complete_count == n_procs) { for (int i = 0; i < n_procs; i++) { ci->self->ag[i].phase = UccTeam::AG_COMPLETE; } ci->self->copy_complete_count = 0; } break; case UccTeam::AG_COMPLETE: return UCC_OK; default: break; } return UCC_INPROGRESS; } ucc_status_t UccTeam::req_free(void *request) { UccTeam::allgather_coll_info_t *ci = (UccTeam::allgather_coll_info_t *)request; ci->self->ag[ci->my_rank].phase = UccTeam::AG_INIT; return UCC_OK; } uint64_t rank_map_cb(uint64_t ep, void *cb_ctx) { UccTeam *team = (UccTeam*)cb_ctx; return (uint64_t)team->procs[(int)ep].p.get()->job_rank; } void UccTeam::init_team(bool use_team_ep_map, bool use_ep_range, bool is_onesided) { ucc_team_params_t team_params; std::vector cis; ucc_status_t status; for (int i = 0; i < n_procs; i++) { cis.push_back(new allgather_coll_info); cis.back()->self = this; cis.back()->my_rank = i; if (use_ep_range) { team_params.ep = i; team_params.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG; team_params.mask = UCC_TEAM_PARAM_FIELD_EP | UCC_TEAM_PARAM_FIELD_EP_RANGE; } else { team_params.mask = 0; } if (use_team_ep_map) { team_params.mask |= UCC_TEAM_PARAM_FIELD_EP_MAP; team_params.ep_map.type = UCC_EP_MAP_CB; team_params.ep_map.ep_num = n_procs; team_params.ep_map.cb.cb = rank_map_cb; team_params.ep_map.cb.cb_ctx = (void*)this; } else { team_params.oob.allgather = allgather; team_params.oob.req_test = req_test; team_params.oob.req_free = req_free; team_params.oob.coll_info = (void *)cis.back(); team_params.oob.n_oob_eps = n_procs; team_params.oob.oob_ep = i; team_params.mask |= UCC_TEAM_PARAM_FIELD_OOB; } if (is_onesided) { 
team_params.mask |= UCC_TEAM_PARAM_FIELD_FLAGS; team_params.flags = UCC_TEAM_FLAG_COLL_WORK_BUFFER; } EXPECT_EQ(UCC_OK, ucc_team_create_post(&(procs[i].p.get()->ctx_h), 1, &team_params, &(procs[i].team))); } int all_done = 0; while (!all_done) { all_done = 1; for (int i = 0; i < n_procs; i++) { ucc_context_progress(procs[i].p.get()->ctx_h); status = ucc_team_create_test(procs[i].team); ASSERT_GE(status, 0); if (UCC_INPROGRESS == status) { all_done = 0; } } } for (auto c : cis) { delete c; } } void UccTeam::destroy_team() { ucc_status_t status; bool all_done; do { all_done = true; for (auto &p : procs) { if (p.team) { status = ucc_team_destroy(p.team); if (UCC_OK == status) { p.team = NULL; } else if (status < 0) { return; } else { all_done = false; } } } } while (!all_done); } void UccTeam::progress() { for (auto &p : procs) { ucc_context_progress(p.p->ctx_h); } } UccTeam::UccTeam(std::vector &_procs, bool use_team_ep_map, bool use_ep_range, bool is_onesided) { n_procs = _procs.size(); ag.resize(n_procs); for (auto &p : _procs) { procs.push_back(proc(p)); } for (auto &a : ag) { a.phase = AG_INIT; } copy_complete_count = 0; init_team(use_team_ep_map, use_ep_range, is_onesided); // test_allgather(128); } UccTeam::~UccTeam() { destroy_team(); } UccJob::UccJob(int _n_procs, ucc_job_ctx_mode_t _ctx_mode, ucc_job_env_t vars) : ta(_n_procs), n_procs(_n_procs), ctx_mode(_ctx_mode) { ucc_job_env_t env_bkp; char *var; /* NCCL TL is disabled since it currently can not support non-blocking team creation. */ vars.push_back({"UCC_TL_NCCL_TUNE", "0"}); vars.push_back({"UCC_TL_RCCL_TUNE", "0"}); /* CUDA TL is disabled since cuda context is not initialized in threads. 
*/ vars.push_back({"UCC_TL_CUDA_TUNE", "0"}); /* GDR is temporarily disabled due to known issue that may result in a hang in the destruction flow */ vars.push_back({"UCX_IB_GPU_DIRECT_RDMA", "no"}); for (auto &v : vars) { var = std::getenv(v.first.c_str()); if (var) { /* found env - back it up for later restore after processes creation */ env_bkp.push_back(ucc_env_var_t(v.first, var)); } setenv(v.first.c_str(), v.second.c_str(), 1); } for (int i = 0; i < n_procs; i++) { procs.push_back(std::make_shared(i)); } create_context(); for (auto &v : env_bkp) { /*restore original env */ setenv(v.first.c_str(), v.second.c_str(), 1); } } void thread_allgather(void *src_buf, void *recv_buf, size_t size, ThreadAllgatherReq *ta_req) { ThreadAllgather *ta = ta_req->ta; while (ta->ready_count > ta->n_procs) { std::this_thread::yield(); } ta->lock.lock(); if (!ta->buffer) { ucc_assert(0 == ta->ready_count); ta->buffer = malloc(size * ta->n_procs); ta->ready_count = 0; } memcpy((void*)((ptrdiff_t)ta->buffer + size * ta_req->rank), src_buf, size); ta->ready_count++; ta->lock.unlock(); while (ta->ready_count < ta->n_procs) { std::this_thread::yield(); } memcpy(recv_buf, ta->buffer, size * ta->n_procs); ta->lock.lock(); ta->ready_count++; if (ta->ready_count == 2 * ta->n_procs) { free(ta->buffer); ta->buffer = NULL; ta->ready_count = 0; } ta->lock.unlock(); ta_req->status = UCC_OK; } ucc_status_t thread_allgather_start(void *src_buf, void *recv_buf, size_t size, void *coll_info, void **request) { ThreadAllgatherReq *ta_req = (ThreadAllgatherReq*)coll_info; *request = coll_info; while (ta_req->status != UCC_OPERATION_INITIALIZED) { std::this_thread::yield(); } ta_req->status = UCC_INPROGRESS; ta_req->t = std::thread(thread_allgather, src_buf, recv_buf, size, ta_req); return UCC_OK; } ucc_status_t thread_allgather_req_test(void *request) { ThreadAllgatherReq *ta_req = (ThreadAllgatherReq*)request; return ta_req->status; } ucc_status_t thread_allgather_req_free(void *request) { 
ThreadAllgatherReq *ta_req = (ThreadAllgatherReq*)request; ta_req->t.join(); ta_req->status = UCC_OPERATION_INITIALIZED; return UCC_OK; } void proc_context_create(UccProcess_h proc, int id, ThreadAllgather *ta, bool is_global) { const int nnodes = 2; const int nsockets = 2; const int nnumas = 3; ucc_status_t status; ucc_context_config_h ctx_config; std::stringstream err_msg; ucc_proc_info_t proc_info; int node, local_ppn, local_rank, job_size, block; status = ucc_context_config_read(proc->lib_h, NULL, &ctx_config); if (status != UCC_OK) { err_msg << "ucc_context_config_read failed"; goto exit_err; } if (is_global) { proc->ctx_params.mask |= UCC_CONTEXT_PARAM_FIELD_OOB; proc->ctx_params.oob.allgather = thread_allgather_start; proc->ctx_params.oob.req_test = thread_allgather_req_test; proc->ctx_params.oob.req_free = thread_allgather_req_free; proc->ctx_params.oob.coll_info = (void*) &ta->reqs[id]; proc->ctx_params.oob.n_oob_eps = ta->n_procs; proc->ctx_params.oob.oob_ep = id; /* Simulate multi-node topology for larger gtest coverage */ job_size = ta->n_procs; block = ucc_buffer_block_count(job_size, nnodes, 0); node = id / block; local_ppn = ucc_buffer_block_count(job_size, nnodes, node); local_rank = id - ucc_buffer_block_offset(job_size, nnodes, node); proc_info.host_hash = node + 1; block = ucc_buffer_block_count(local_ppn, nsockets, 0); proc_info.socket_id = local_rank / block; block = ucc_buffer_block_count(local_ppn, nnumas, 0); proc_info.numa_id = local_rank / block; proc_info.pid = id + 1; } else { proc_info = ucc_local_proc; } status = ucc_context_create_proc_info(proc->lib_h, &proc->ctx_params, ctx_config, &proc->ctx_h, &proc_info); ucc_context_config_release(ctx_config); if (status != UCC_OK) { err_msg << "ucc_context_create failed"; goto exit_err; } return; exit_err: err_msg << ": "<< ucc_status_string(status) << " (" << status << ")"; throw std::runtime_error(err_msg.str()); } void proc_context_create_mem_params(UccProcess_h proc, int id, ThreadAllgather 
*ta) { ucc_status_t status; ucc_context_config_h ctx_config; std::stringstream err_msg; ucc_mem_map_t map[UCC_TEST_N_MEM_SEGMENTS]; status = ucc_context_config_read(proc->lib_h, NULL, &ctx_config); if (status != UCC_OK) { err_msg << "ucc_context_config_read failed"; goto exit_err; } for (auto i = 0; i < UCC_TEST_N_MEM_SEGMENTS; i++) { proc->onesided_buf[i] = ucc_calloc(UCC_TEST_MEM_SEGMENT_SIZE, 1, "onesided_buffer"); EXPECT_NE(proc->onesided_buf[i], nullptr); map[i].address = proc->onesided_buf[i]; map[i].len = UCC_TEST_MEM_SEGMENT_SIZE; } proc->ctx_params.mask = UCC_CONTEXT_PARAM_FIELD_OOB; proc->ctx_params.mask |= UCC_CONTEXT_PARAM_FIELD_MEM_PARAMS; proc->ctx_params.oob.allgather = thread_allgather_start; proc->ctx_params.oob.req_test = thread_allgather_req_test; proc->ctx_params.oob.req_free = thread_allgather_req_free; proc->ctx_params.oob.coll_info = (void *)&ta->reqs[id]; proc->ctx_params.oob.n_oob_eps = ta->n_procs; proc->ctx_params.oob.oob_ep = id; proc->ctx_params.mem_params.segments = map; proc->ctx_params.mem_params.n_segments = UCC_TEST_N_MEM_SEGMENTS; status = ucc_context_create(proc->lib_h, &proc->ctx_params, ctx_config, &proc->ctx_h); ucc_context_config_release(ctx_config); if (status != UCC_OK) { err_msg << "ucc_context_create for one-sided context failed"; goto exit_err; } return; exit_err: err_msg << ": " << ucc_status_string(status) << " (" << status << ")"; throw std::runtime_error(err_msg.str()); } void UccJob::create_context() { std::vector workers; for (auto i = 0; i < procs.size(); i++) { if (ctx_mode == UCC_JOB_CTX_GLOBAL_ONESIDED) { workers.push_back( std::thread(proc_context_create_mem_params, procs[i], i, &ta)); } else { workers.push_back(std::thread(proc_context_create, procs[i], i, &ta, ctx_mode == UCC_JOB_CTX_GLOBAL)); } } for (auto i = 0; i < procs.size(); i++) { workers[i].join(); } } void thread_proc_destruct(std::vector *procs, int i) { ucc_assert(true == (*procs)[i].unique()); (*procs)[i] = NULL; } UccJob::~UccJob() { 
std::vector workers; if (this == UccJob::staticUccJob) { staticTeams.clear(); } for (int i = 0; i < n_procs; i++) { workers.push_back(std::thread(thread_proc_destruct, &procs, i)); } for (int i = 0; i < n_procs; i++) { workers[i].join(); } } UccJob* UccJob::staticUccJob = NULL; UccJob* UccJob::getStaticJob() { if (!staticUccJob) { staticUccJob = new UccJob(UccJob::staticUccJobSize); } return staticUccJob; } std::vector UccJob::staticTeams; const std::vector &UccJob::getStaticTeams() { std::vector teamSizes(std::begin(staticTeamSizes), std::end(staticTeamSizes)); if (0 == staticTeams.size()) { for (auto ts : teamSizes) { if (ts == 1 && !tl_self_available()) { /* don't use team_size = 1 if there is no tl/self. we can't modify nStaticTeams, so just use some other team_size */ ts = 3; } staticTeams.push_back(getStaticJob()->create_team(ts)); } /* Create one more team with reversed ranks order */ std::vector ranks; for (auto r = staticUccJobSize - 1; r >= 0; r--) { ranks.push_back(r); } staticTeams.push_back(getStaticJob()->create_team(ranks, true)); } return staticTeams; } void UccJob::cleanup() { if (staticUccJob) { delete staticUccJob; } } UccTeam_h UccJob::create_team(int _n_procs, bool use_team_ep_map, bool use_ep_range, bool is_onesided) { EXPECT_GE(n_procs, _n_procs); std::vector team_procs; for (int i = 0; i < _n_procs; i++) { team_procs.push_back(procs[i]); } return std::make_shared(team_procs, use_team_ep_map, use_ep_range, is_onesided); } UccTeam_h UccJob::create_team(std::vector &ranks, bool use_team_ep_map, bool use_ep_range, bool is_onesided) { EXPECT_GE(n_procs, ranks.size()); std::vector team_procs; for (int i = 0; i < ranks.size(); i++) { team_procs.push_back(procs[ranks[i]]); } return std::make_shared(team_procs, use_team_ep_map, use_ep_range, is_onesided); } UccReq::UccReq(UccTeam_h _team, ucc_coll_args_t *args) : team(_team) { ucc_coll_req_h req; for (auto &p : team->procs) { if (UCC_OK != ucc_collective_init(args, &req, p.team)) { goto err; } 
reqs.push_back(req); } return; err: reqs.clear(); } UccReq::UccReq(UccTeam_h _team, UccCollCtxVec ctxs) : team(_team) { std::vector err_st; ucc_coll_req_h req; ucc_status_t st; EXPECT_EQ(team->procs.size(), ctxs.size()); status = UCC_OK; for (auto i = 0; i < team->procs.size(); i++) { if (!ctxs[i]) { continue; } if (UCC_OK !=(st = ucc_collective_init(ctxs[i]->args, &req, team->procs[i].team))) { err_st.push_back(st); } else { reqs.push_back(req); } } if (err_st.size() > 0) { /* All error status should be equal, otherwise it is real fatal error. Only expected error is NOT_SUPPORTED. If collective init returns NOT_SUPPORTED it has to be symmetric for all ranks */ if (!std::equal(err_st.begin() + 1, err_st.end(), err_st.begin()) || err_st.size() != team->procs.size() || err_st[0] != UCC_ERR_NOT_SUPPORTED) { status = UCC_ERR_NO_MESSAGE; } else { ucc_assert(err_st[0] = UCC_ERR_NOT_SUPPORTED); status = err_st[0]; } } } UccReq::~UccReq() { for (auto r : reqs) { EXPECT_EQ(UCC_OK, ucc_collective_finalize(r)); } } void UccReq::start() { ucc_status_t st; for (auto r : reqs) { st = ucc_collective_post(r); ASSERT_EQ(UCC_OK, st); st = ucc_collective_test(r); ASSERT_NE(UCC_OPERATION_INITIALIZED, st); } } ucc_status_t UccReq::test() { ucc_status_t st = UCC_OK; for (auto r : reqs) { st = ucc_collective_test(r); if (UCC_OK != st) { break; } } return st; } ucc_status_t UccReq::wait() { ucc_status_t st; while (UCC_OK != (st = test())) { if (st < 0) { break; } team->progress(); } return st; } void UccReq::waitall(std::vector &reqs) { bool alldone = false; ucc_status_t status; while (!alldone) { alldone = true; for (auto &r : reqs) { if (UCC_OK != (status = r.test())) { if (status < 0) { return; } alldone = false; r.team->progress(); } } } } void UccReq::startall(std::vector &reqs) { for (auto &r : reqs) { r.start(); } } void UccCollArgs::set_mem_type(ucc_memory_type_t _mt) { mem_type = _mt; } void UccCollArgs::set_inplace(gtest_ucc_inplace_t _inplace) { inplace = _inplace; } void 
UccCollArgs::set_contig(bool _is_contig) { is_contig = _is_contig; } void clear_buffer(void *_buf, size_t size, ucc_memory_type_t mt, uint8_t value) { void *buf = _buf; if (mt != UCC_MEMORY_TYPE_HOST) { buf = ucc_malloc(size, "buf"); ASSERT_NE(0, (uintptr_t)buf); } memset(buf, value, size); if (UCC_MEMORY_TYPE_HOST != mt) { UCC_CHECK(ucc_mc_memcpy(_buf, buf, size, mt, UCC_MEMORY_TYPE_HOST)); ucc_free(buf); } } bool tl_self_available() { ucc_tl_context_t *tl_ctx; ucc_status_t status; status = ucc_tl_context_get(UccJob::getStaticJob()->procs[0]->ctx_h, "self", &tl_ctx); if (UCC_OK != status) { return false; } ucc_tl_context_put(tl_ctx); return true; } openucx-ucc-ec0bc8a/test/gtest/common/test_ucc.h0000664000175000017500000002173315133731560022267 0ustar alastairalastair/** * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * See file LICENSE for terms. */ #ifndef TEST_UCC_H #define TEST_UCC_H #include "test.h" extern "C" { #include "components/mc/ucc_mc.h" #include "utils/ucc_malloc.h" #include #include } #include #include #include #include #include #include #include typedef struct { ucc_mc_buffer_header_t *dst_mc_header; ucc_mc_buffer_header_t *src_mc_header; void *init_buf; size_t rbuf_size; ucc_coll_args_t *args; } gtest_ucc_coll_ctx_t; typedef std::vector UccCollCtxVec; typedef enum { TEST_NO_INPLACE, TEST_INPLACE } gtest_ucc_inplace_t; class UccCollArgs { protected: ucc_memory_type_t mem_type; gtest_ucc_inplace_t inplace; bool is_contig; void alltoallx_init_buf(int src_rank, int dst_rank, uint8_t *buf, size_t len) { for (int i = 0; i < len; i++) { buf[i] = (uint8_t)(((src_rank + len - i) * (dst_rank + 1)) % UINT8_MAX); } } int alltoallx_validate_buf(int src_rank, int dst_rank, uint8_t *buf, size_t len) { int err = 0; for (int i = 0; i < len; i ++) { uint8_t expected = (uint8_t) (((dst_rank + len - i) * (src_rank + 1)) % UINT8_MAX); if (buf[i] != expected) { err++; } } return err; } public: UccCollArgs() { // defaults mem_type = 
UCC_MEMORY_TYPE_HOST; inplace = TEST_NO_INPLACE; } virtual ~UccCollArgs() {} virtual void data_init(int nprocs, ucc_datatype_t dtype, size_t count, UccCollCtxVec &args, bool persistent = false) = 0; virtual void data_fini(UccCollCtxVec args) = 0; virtual bool data_validate(UccCollCtxVec args) = 0; void set_mem_type(ucc_memory_type_t _mt); void set_inplace(gtest_ucc_inplace_t _inplace); void set_contig(bool _contig); }; #define SET_MEM_TYPE(_mt) do { \ if (UCC_OK != ucc_mc_available(_mt)) { \ GTEST_SKIP(); \ } \ this->mem_type = _mt; \ } while (0) class ThreadAllgather; class ThreadAllgatherReq { public: ThreadAllgather *ta; int rank; ucc_status_t status; std::thread t; ThreadAllgatherReq(ThreadAllgather *_ta, int _rank) : ta(_ta), rank(_rank) { status = UCC_OPERATION_INITIALIZED; }; }; class ThreadAllgather { public: int n_procs; std::atomic ready_count; void *buffer; std::mutex lock; std::vector reqs; ThreadAllgather(int _n_procs) : n_procs(_n_procs), ready_count(0), buffer(NULL) { for (auto i = 0; i < _n_procs; i++) { reqs.push_back(ThreadAllgatherReq(this, i)); } }; ~ThreadAllgather() { buffer = NULL; ready_count = 0; } }; /* A single processes in a Job that runs UCC. 
It has context and lib object */
class UccProcess {
public:
    /* Context parameters used for this process; the context-creation
       helpers in test_ucc.cc add the OOB / mem_params fields to it. */
    ucc_context_params_t ctx_params;
    /* Defaults: single-threaded library with the collective types listed
       below enabled. */
    static constexpr ucc_lib_params_t default_lib_params = {
        .mask = UCC_LIB_PARAM_FIELD_THREAD_MODE |
                UCC_LIB_PARAM_FIELD_COLL_TYPES,
        .thread_mode = UCC_THREAD_SINGLE,
        .coll_types = UCC_COLL_TYPE_BARRIER |
                      UCC_COLL_TYPE_ALLTOALL |
                      UCC_COLL_TYPE_ALLTOALLV |
                      UCC_COLL_TYPE_ALLREDUCE |
                      UCC_COLL_TYPE_ALLGATHER |
                      UCC_COLL_TYPE_ALLGATHERV |
                      UCC_COLL_TYPE_REDUCE |
                      UCC_COLL_TYPE_BCAST |
                      UCC_COLL_TYPE_GATHER |
                      UCC_COLL_TYPE_SCATTER};
    /* Defaults: exclusive context, no other fields set. */
    static constexpr ucc_context_params_t default_ctx_params = {
        .mask = UCC_CONTEXT_PARAM_FIELD_TYPE,
        .type = UCC_CONTEXT_EXCLUSIVE
    };
    ucc_lib_h     lib_h;
    ucc_context_h ctx_h;
    /* Buffers registered for one-sided contexts.
       NOTE(review): the literal 3 has to stay equal to
       UCC_TEST_N_MEM_SEGMENTS, which is defined further down in this
       header; the one-sided context helper indexes this array up to that
       constant. */
    void *        onesided_buf[3];
    int           job_rank;
    UccProcess(int _job_rank,
               const ucc_lib_params_t &lp = default_lib_params,
               const ucc_context_params_t &cp = default_ctx_params);
    ~UccProcess();
};
/* NOTE(review): the template argument lists of the std::shared_ptr,
   std::vector and std::pair uses below are missing from the text under
   review; restore them from the upstream header before building. */
typedef std::shared_ptr UccProcess_h;

/* Ucc team that consists of several processes.
   The team is created from UccJob environment */
class UccTeam {
    /* One team member: the process handle and its UCC team handle.
       Neither constructor initialises `team`. */
    struct proc {
        UccProcess_h p;
        ucc_team_h   team;
        proc(){};
        proc(UccProcess_h _p) : p(_p) {};
    };
    /* Phases of the team-internal allgather (see allgather / req_test /
       req_free below). */
    typedef enum {
        AG_INIT,
        AG_READY,
        AG_COPY_DONE,
        AG_COMPLETE
    } allgather_phase_t;
    /* Buffers, length and current phase of one allgather participant. */
    struct allgather_data {
        void *sbuf;
        void *rbuf;
        size_t len;
        allgather_phase_t phase;
    };
    /* Value handed to the allgather callback as coll_info: the caller's
       rank and the owning team. */
    typedef struct allgather_coll_info {
        int my_rank;
        UccTeam *self;
    } allgather_coll_info_t;
    std::vector ag;
    void init_team(bool use_team_ep_map, bool use_ep_range,
                   bool is_onesided);
    void destroy_team();
    void test_allgather(size_t msglen);
    /* Callbacks with the OOB allgather / test / free signatures. */
    static ucc_status_t allgather(void *src_buf, void *recv_buf, size_t size,
                                  void *coll_info, void **request);
    static ucc_status_t req_test(void *request);
    static ucc_status_t req_free(void *request);
    int copy_complete_count;
public:
    int n_procs;
    void progress();
    std::vector procs;
    UccTeam(std::vector &_procs, bool use_team_ep_map = false,
            bool use_ep_range = true, bool is_onesided = false);
    ~UccTeam();
};
typedef std::shared_ptr UccTeam_h;
/* NOTE(review): presumably a (name, value) environment variable pair -
   the pair's type arguments are not visible here; confirm upstream. */
typedef std::pair ucc_env_var_t;
typedef std::vector ucc_job_env_t; /* UccJob - environent that has n_procs processes. Multiple UccTeams can be created from UccJob */ class UccJob { static UccJob* staticUccJob; static std::vector staticTeams; ThreadAllgather ta; public: typedef enum { UCC_JOB_CTX_LOCAL, UCC_JOB_CTX_GLOBAL, /*< ucc ctx create with OOB */ UCC_JOB_CTX_GLOBAL_ONESIDED } ucc_job_ctx_mode_t; static const int nStaticTeams = 5; static const int staticUccJobSize = 16; static constexpr int staticTeamSizes[nStaticTeams] = {1, 2, 8, 11, staticUccJobSize}; static void cleanup(); static UccJob* getStaticJob(); static const std::vector &getStaticTeams(); int n_procs; UccJob(int _n_procs = 2, ucc_job_ctx_mode_t _ctx_mode = UCC_JOB_CTX_GLOBAL, ucc_job_env_t vars = ucc_job_env_t()); ~UccJob(); std::vector procs; UccTeam_h create_team(int n_procs, bool use_team_ep_map = false, bool use_ep_range = true, bool is_onesided = false); UccTeam_h create_team(std::vector &ranks, bool use_team_ep_map = false, bool use_ep_range = true, bool is_onesided = false); void create_context(); ucc_job_ctx_mode_t ctx_mode; }; class UccReq { UccTeam_h team; /* Make copy constructor and = private, to avoid req leak */ public: ucc_status_t status; UccReq(const UccReq&) = delete; UccReq& operator=(const UccReq&) = delete; UccReq(UccReq&& source) : team(source.team), status(source.status) { reqs.swap(source.reqs); }; std::vector reqs; UccReq(UccTeam_h _team, ucc_coll_args_t *args); UccReq(UccTeam_h _team, UccCollCtxVec args); ~UccReq(); void start(void); ucc_status_t wait(); ucc_status_t test(void); static void waitall(std::vector &reqs); static void startall(std::vector &reqs); }; #define DATA_FINI_ALL(_test, _ctx) for (auto &c : ctxs) { _test->data_fini(c); } #define CHECK_REQ_NOT_SUPPORTED_SKIP(_UccReq, _action) do{ \ if ((_UccReq).status == UCC_ERR_NOT_SUPPORTED) { \ _action; \ GTEST_SKIP(); \ } \ ASSERT_EQ(UCC_OK, (_UccReq).status); \ } while(0) void clear_buffer(void *_buf, size_t size, ucc_memory_type_t mt, uint8_t 
value); #define PREDEFINED_DTYPES \ ::testing::Values( \ UCC_DT_INT8, UCC_DT_INT16, UCC_DT_INT32, UCC_DT_INT64, UCC_DT_INT128, \ UCC_DT_UINT8, UCC_DT_UINT16, UCC_DT_UINT32, UCC_DT_UINT64, \ UCC_DT_UINT128, UCC_DT_FLOAT16, UCC_DT_FLOAT32, UCC_DT_FLOAT64, \ UCC_DT_BFLOAT16, UCC_DT_FLOAT128, UCC_DT_FLOAT32_COMPLEX, \ UCC_DT_FLOAT64_COMPLEX, UCC_DT_FLOAT128_COMPLEX) #define UCC_TEST_N_MEM_SEGMENTS 3 #define UCC_TEST_MEM_SEGMENT_SIZE (1 << 20) bool tl_self_available(); #endif openucx-ucc-ec0bc8a/test/gtest/common/test.h0000664000175000017500000001310215133731560021424 0ustar alastairalastair/** * Copyright (c) 2001-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (C) Huawei Technologies Co., Ltd. 2020. All rights reserved. * * See file LICENSE for terms. */ #ifndef UCC_TEST_BASE_H #define UCC_TEST_BASE_H extern "C" { } #include #undef I #define _Imag _Imaginary_I #include "gtest.h" #define UCC_CHECK(_call) EXPECT_EQ(UCC_OK, (_call)) #define UCC_TEST_SKIP_R(_str) GTEST_SKIP_(_str) namespace ucc { class test : public testing::Test { }; #define UCC_TEST_F(...) TEST_F(__VA_ARGS__) #define UCC_TEST_P(...) 
TEST_P(__VA_ARGS__) } #define ASSERT_FLOAT32_COMPLEX_EQ(expected, actual) \ do { \ static float expected_real = crealf(expected); \ static float expected_imaginary = cimagf(expected); \ static float actual_real = crealf(actual); \ static float actual_imaginary = cimagf(actual); \ ASSERT_PRED_FORMAT2( \ ::testing::internal::CmpHelperFloatingPointEQ, \ expected_real, actual_real); \ ASSERT_PRED_FORMAT2( \ ::testing::internal::CmpHelperFloatingPointEQ, \ expected_imaginary, actual_imaginary); \ } while (0) #define ASSERT_FLOAT64_COMPLEX_EQ(expected, actual) \ do { \ static double expected_real = creal(expected); \ static double expected_imaginary = cimag(expected); \ static double actual_real = creal(actual); \ static double actual_imaginary = cimag(actual); \ ASSERT_PRED_FORMAT2( \ ::testing::internal::CmpHelperFloatingPointEQ, \ expected_real, actual_real); \ ASSERT_PRED_FORMAT2( \ ::testing::internal::CmpHelperFloatingPointEQ, \ expected_imaginary, actual_imaginary); \ } while (0) #define ASSERT_FLOAT128_COMPLEX_EQ(expected, actual) \ do { \ static long double expected_real = creall(expected); \ static long double expected_imaginary = cimagl(expected); \ static long double actual_real = creall(actual); \ static long double actual_imaginary = cimagl(actual); \ ASSERT_PRED_FORMAT2( \ ::testing::internal::CmpHelperFloatingPointEQ, \ expected_real, actual_real); \ ASSERT_PRED_FORMAT2( \ ::testing::internal::CmpHelperFloatingPointEQ, \ expected_imaginary, actual_imaginary); \ } while (0) #define EXPECT_FLOAT32_COMPLEX_EQ(expected, actual) \ static float expected_real = crealf(expected); \ static float expected_imaginary = cimagf(expected); \ static float actual_real = crealf(actual); \ static float actual_imaginary = cimagf(actual); \ EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ expected_real, actual_real) \ EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ expected_imaginary, actual_imaginary) #define 
EXPECT_FLOAT64_COMPLEX_EQ(expected, actual) \ static double expected_real = creal(expected); \ static double expected_imaginary = cimag(expected); \ static double actual_real = creal(actual); \ static double actual_imaginary = cimag(actual); \ EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ expected_real, actual_real) \ EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ expected_imaginary, actual_imaginary) #define EXPECT_FLOAT128_COMPLEX_EQ(expected, actual) \ static long double expected_real = creall(expected); \ static long double expected_imaginary = cimagl(expected); \ static long double actual_real = creall(actual); \ static long double actual_imaginary = cimagl(actual); \ EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ expected_real, actual_real) \ EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ expected_imaginary, actual_imaginary) #endif openucx-ucc-ec0bc8a/test/gtest/common/gtest.h0000664000175000017500000231324115133731560021604 0ustar alastairalastair// Copyright 2005, Google Inc. // All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. 
// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // // The Google C++ Testing and Mocking Framework (Google Test) // // This header file defines the public API for Google Test. It should be // included by any test program that uses Google Test. // // IMPORTANT NOTE: Due to limitation of the C++ language, we have to // leave some internal implementation details in this header file. // They are clearly marked by comments like this: // // // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. // // Such code is NOT meant to be used by a user directly, and is subject // to CHANGE WITHOUT NOTICE. Therefore DO NOT DEPEND ON IT in a user // program! // // Acknowledgment: Google Test borrowed the idea of automatic test // registration from Barthelemy Dagenais' (barthelemy@prologique.com) // easyUnit framework. // GOOGLETEST_CM0001 DO NOT DELETE #ifndef GTEST_INCLUDE_GTEST_GTEST_H_ #define GTEST_INCLUDE_GTEST_GTEST_H_ #include #include #include #include #include #include // Copyright 2005, Google Inc. // All rights reserved. 
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // // The Google C++ Testing and Mocking Framework (Google Test) // // This header file declares functions and macros used internally by // Google Test. They are subject to change without notice. // GOOGLETEST_CM0001 DO NOT DELETE #ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_INTERNAL_H_ #define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_INTERNAL_H_ // Copyright 2005, Google Inc. // All rights reserved. 
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // // Low-level types and utilities for porting Google Test to various // platforms. All macros ending with _ and symbols defined in an // internal namespace are subject to change without notice. Code // outside Google Test MUST NOT USE THEM DIRECTLY. Macros that don't // end with _ are part of Google Test's public API and can be used by // code outside Google Test. // // This file is fundamental to Google Test. All other Google Test source // files are expected to #include this. 
Therefore, it cannot #include // any other Google Test header. // GOOGLETEST_CM0001 DO NOT DELETE #ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PORT_H_ #define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PORT_H_ // Environment-describing macros // ----------------------------- // // Google Test can be used in many different environments. Macros in // this section tell Google Test what kind of environment it is being // used in, such that Google Test can provide environment-specific // features and implementations. // // Google Test tries to automatically detect the properties of its // environment, so users usually don't need to worry about these // macros. However, the automatic detection is not perfect. // Sometimes it's necessary for a user to define some of the following // macros in the build script to override Google Test's decisions. // // If the user doesn't define a macro in the list, Google Test will // provide a default definition. After this header is #included, all // macros in this list will be defined to either 1 or 0. // // Notes to maintainers: // - Each macro here is a user-tweakable knob; do not grow the list // lightly. // - Use #if to key off these macros. Don't use #ifdef or "#if // defined(...)", which will not work as these macros are ALWAYS // defined. // // GTEST_HAS_CLONE - Define it to 1/0 to indicate that clone(2) // is/isn't available. // GTEST_HAS_EXCEPTIONS - Define it to 1/0 to indicate that exceptions // are enabled. // GTEST_HAS_POSIX_RE - Define it to 1/0 to indicate that POSIX regular // expressions are/aren't available. // GTEST_HAS_PTHREAD - Define it to 1/0 to indicate that // is/isn't available. // GTEST_HAS_RTTI - Define it to 1/0 to indicate that RTTI is/isn't // enabled. // GTEST_HAS_STD_WSTRING - Define it to 1/0 to indicate that // std::wstring does/doesn't work (Google Test can // be used where std::wstring is unavailable). 
// GTEST_HAS_SEH - Define it to 1/0 to indicate whether the // compiler supports Microsoft's "Structured // Exception Handling". // GTEST_HAS_STREAM_REDIRECTION // - Define it to 1/0 to indicate whether the // platform supports I/O stream redirection using // dup() and dup2(). // GTEST_LINKED_AS_SHARED_LIBRARY // - Define to 1 when compiling tests that use // Google Test as a shared library (known as // DLL on Windows). // GTEST_CREATE_SHARED_LIBRARY // - Define to 1 when compiling Google Test itself // as a shared library. // GTEST_DEFAULT_DEATH_TEST_STYLE // - The default value of --gtest_death_test_style. // The legacy default has been "fast" in the open // source version since 2008. The recommended value // is "threadsafe", and can be set in // custom/gtest-port.h. // Platform-indicating macros // -------------------------- // // Macros indicating the platform on which Google Test is being used // (a macro is defined to 1 if compiled on the given platform; // otherwise UNDEFINED -- it's never defined to 0.). Google Test // defines these macros automatically. Code outside Google Test MUST // NOT define them. 
// // GTEST_OS_AIX - IBM AIX // GTEST_OS_CYGWIN - Cygwin // GTEST_OS_DRAGONFLY - DragonFlyBSD // GTEST_OS_FREEBSD - FreeBSD // GTEST_OS_FUCHSIA - Fuchsia // GTEST_OS_GNU_KFREEBSD - GNU/kFreeBSD // GTEST_OS_HAIKU - Haiku // GTEST_OS_HPUX - HP-UX // GTEST_OS_LINUX - Linux // GTEST_OS_LINUX_ANDROID - Google Android // GTEST_OS_MAC - Mac OS X // GTEST_OS_IOS - iOS // GTEST_OS_NACL - Google Native Client (NaCl) // GTEST_OS_NETBSD - NetBSD // GTEST_OS_OPENBSD - OpenBSD // GTEST_OS_OS2 - OS/2 // GTEST_OS_QNX - QNX // GTEST_OS_SOLARIS - Sun Solaris // GTEST_OS_WINDOWS - Windows (Desktop, MinGW, or Mobile) // GTEST_OS_WINDOWS_DESKTOP - Windows Desktop // GTEST_OS_WINDOWS_MINGW - MinGW // GTEST_OS_WINDOWS_MOBILE - Windows Mobile // GTEST_OS_WINDOWS_PHONE - Windows Phone // GTEST_OS_WINDOWS_RT - Windows Store App/WinRT // GTEST_OS_ZOS - z/OS // // Among the platforms, Cygwin, Linux, Mac OS X, and Windows have the // most stable support. Since core members of the Google Test project // don't have access to other platforms, support for them may be less // stable. If you notice any problems on your platform, please notify // googletestframework@googlegroups.com (patches for fixing them are // even more welcome!). // // It is possible that none of the GTEST_OS_* macros are defined. // Feature-indicating macros // ------------------------- // // Macros indicating which Google Test features are available (a macro // is defined to 1 if the corresponding feature is supported; // otherwise UNDEFINED -- it's never defined to 0.). Google Test // defines these macros automatically. Code outside Google Test MUST // NOT define them. // // These macros are public so that portable tests can be written. // Such tests typically surround code using a feature with an #if // which controls that code. 
For example: // // #if GTEST_HAS_DEATH_TEST // EXPECT_DEATH(DoSomethingDeadly()); // #endif // // GTEST_HAS_DEATH_TEST - death tests // GTEST_HAS_TYPED_TEST - typed tests // GTEST_HAS_TYPED_TEST_P - type-parameterized tests // GTEST_IS_THREADSAFE - Google Test is thread-safe. // GOOGLETEST_CM0007 DO NOT DELETE // GTEST_USES_POSIX_RE - enhanced POSIX regex is used. Do not confuse with // GTEST_HAS_POSIX_RE (see above) which users can // define themselves. // GTEST_USES_SIMPLE_RE - our own simple regex is used; // the above RE\b(s) are mutually exclusive. // Misc public macros // ------------------ // // GTEST_FLAG(flag_name) - references the variable corresponding to // the given Google Test flag. // Internal utilities // ------------------ // // The following macros and utilities are for Google Test's INTERNAL // use only. Code outside Google Test MUST NOT USE THEM DIRECTLY. // // Macros for basic C++ coding: // GTEST_AMBIGUOUS_ELSE_BLOCKER_ - for disabling a gcc warning. // GTEST_ATTRIBUTE_UNUSED_ - declares that a class' instances or a // variable don't have to be used. // GTEST_DISALLOW_ASSIGN_ - disables operator=. // GTEST_DISALLOW_COPY_AND_ASSIGN_ - disables copy ctor and operator=. // GTEST_MUST_USE_RESULT_ - declares that a function's result must be used. // GTEST_INTENTIONAL_CONST_COND_PUSH_ - start code section where MSVC C4127 is // suppressed (constant conditional). // GTEST_INTENTIONAL_CONST_COND_POP_ - finish code section where MSVC C4127 // is suppressed. // // Synchronization: // Mutex, MutexLock, ThreadLocal, GetThreadCount() // - synchronization primitives. // // Regular expressions: // RE - a simple regular expression class using the POSIX // Extended Regular Expression syntax on UNIX-like platforms // GOOGLETEST_CM0008 DO NOT DELETE // or a reduced regular exception syntax on other // platforms, including Windows. // Logging: // GTEST_LOG_() - logs messages at the specified severity level. // LogToStderr() - directs all log messages to stderr. 
// FlushInfoLog() - flushes informational log messages. // // Stdout and stderr capturing: // CaptureStdout() - starts capturing stdout. // GetCapturedStdout() - stops capturing stdout and returns the captured // string. // CaptureStderr() - starts capturing stderr. // GetCapturedStderr() - stops capturing stderr and returns the captured // string. // // Integer types: // TypeWithSize - maps an integer to a int type. // Int32, UInt32, Int64, UInt64, TimeInMillis // - integers of known sizes. // BiggestInt - the biggest signed integer type. // // Command-line utilities: // GTEST_DECLARE_*() - declares a flag. // GTEST_DEFINE_*() - defines a flag. // GetInjectableArgvs() - returns the command line as a vector of strings. // // Environment variable utilities: // GetEnv() - gets the value of an environment variable. // BoolFromGTestEnv() - parses a bool environment variable. // Int32FromGTestEnv() - parses an Int32 environment variable. // StringFromGTestEnv() - parses a string environment variable. // // Deprecation warnings: // GTEST_INTERNAL_DEPRECATED(message) - attribute marking a function as // deprecated; calling a marked function // should generate a compiler warning #include // for isspace, etc #include // for ptrdiff_t #include #include #include #include #include #ifndef _WIN32_WCE # include # include #endif // !_WIN32_WCE #if defined __APPLE__ # include # include #endif #include // NOLINT #include // NOLINT #include // NOLINT #include // NOLINT #include #include #include // NOLINT // Copyright 2015, Google Inc. // All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. 
//     * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
//     * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// The Google C++ Testing and Mocking Framework (Google Test)
//
// This header file defines the GTEST_OS_* macro.
// It is separate from gtest-port.h so that custom/gtest-port.h can include it.

#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PORT_ARCH_H_
#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PORT_ARCH_H_

// Determines the platform on which Google Test is compiled.
// Exactly one branch of the chain below is taken; each branch defines the
// GTEST_OS_* macro(s) for that platform to 1 and leaves the others undefined.
#ifdef __CYGWIN__
# define GTEST_OS_CYGWIN 1
# elif defined(__MINGW__) || defined(__MINGW32__) || defined(__MINGW64__)
#  define GTEST_OS_WINDOWS_MINGW 1
#  define GTEST_OS_WINDOWS 1
#elif defined _WIN32
# define GTEST_OS_WINDOWS 1
# ifdef _WIN32_WCE
#  define GTEST_OS_WINDOWS_MOBILE 1
# elif defined(WINAPI_FAMILY)
// <winapifamily.h> supplies WINAPI_FAMILY_PARTITION() and the
// WINAPI_PARTITION_* constants tested below.
#  include <winapifamily.h>
#  if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
#   define GTEST_OS_WINDOWS_DESKTOP 1
#  elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_PHONE_APP)
#   define GTEST_OS_WINDOWS_PHONE 1
#  elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP)
#   define GTEST_OS_WINDOWS_RT 1
#  elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_TV_TITLE)
#   define GTEST_OS_WINDOWS_PHONE 1
#   define GTEST_OS_WINDOWS_TV_TITLE 1
#  else
    // WINAPI_FAMILY defined but no known partition matched.
    // Default to desktop.
#   define GTEST_OS_WINDOWS_DESKTOP 1
#  endif
# else
#  define GTEST_OS_WINDOWS_DESKTOP 1
# endif  // _WIN32_WCE
#elif defined __OS2__
# define GTEST_OS_OS2 1
#elif defined __APPLE__
# define GTEST_OS_MAC 1
# if TARGET_OS_IPHONE
#  define GTEST_OS_IOS 1
# endif
#elif defined __DragonFly__
# define GTEST_OS_DRAGONFLY 1
#elif defined __FreeBSD__
# define GTEST_OS_FREEBSD 1
#elif defined __Fuchsia__
# define GTEST_OS_FUCHSIA 1
#elif defined(__GLIBC__) && defined(__FreeBSD_kernel__)
# define GTEST_OS_GNU_KFREEBSD 1
#elif defined __linux__
# define GTEST_OS_LINUX 1
# if defined __ANDROID__
#  define GTEST_OS_LINUX_ANDROID 1
# endif
#elif defined __MVS__
# define GTEST_OS_ZOS 1
#elif defined(__sun) && defined(__SVR4)
# define GTEST_OS_SOLARIS 1
#elif defined(_AIX)
# define GTEST_OS_AIX 1
#elif defined(__hpux)
# define GTEST_OS_HPUX 1
#elif defined __native_client__
# define GTEST_OS_NACL 1
#elif defined __NetBSD__
# define GTEST_OS_NETBSD 1
#elif defined __OpenBSD__
# define GTEST_OS_OPENBSD 1
#elif defined __QNX__
# define GTEST_OS_QNX 1
#elif defined(__HAIKU__)
#define GTEST_OS_HAIKU 1
#endif  // __CYGWIN__

#endif  // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PORT_ARCH_H_
// Copyright 2015, Google Inc.
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
//     * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// Injection point for custom user configurations. See README for details
//
// ** Custom implementation starts here **

#ifndef GTEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_PORT_H_
#define GTEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_PORT_H_

#endif  // GTEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_PORT_H_

// Default project identity strings. They are only defined here when the
// custom port header above has not already defined GTEST_DEV_EMAIL_.
#if !defined(GTEST_DEV_EMAIL_)
# define GTEST_DEV_EMAIL_ "googletestframework@@googlegroups.com"
# define GTEST_FLAG_PREFIX_ "gtest_"
# define GTEST_FLAG_PREFIX_DASH_ "gtest-"
# define GTEST_FLAG_PREFIX_UPPER_ "GTEST_"
# define GTEST_NAME_ "Google Test"
# define GTEST_PROJECT_URL_ "https://github.com/google/googletest/"
#endif  // !defined(GTEST_DEV_EMAIL_)

#if !defined(GTEST_INIT_GOOGLE_TEST_NAME_)
# define GTEST_INIT_GOOGLE_TEST_NAME_ "testing::InitGoogleTest"
#endif  // !defined(GTEST_INIT_GOOGLE_TEST_NAME_)

// Determines the version of gcc that is used to compile this.
#ifdef __GNUC__
// 40302 means version 4.3.2.
# define GTEST_GCC_VER_ \
    (__GNUC__*10000 + __GNUC_MINOR__*100 + __GNUC_PATCHLEVEL__)
#endif  // __GNUC__

// Macros for disabling Microsoft Visual C++ warnings.
//
//   GTEST_DISABLE_MSC_WARNINGS_PUSH_(4800 4385)
//   /* code that triggers warnings C4800 and C4385 */
//   GTEST_DISABLE_MSC_WARNINGS_POP_()
#if defined(_MSC_VER)
# define GTEST_DISABLE_MSC_WARNINGS_PUSH_(warnings) \
    __pragma(warning(push))                        \
    __pragma(warning(disable: warnings))
# define GTEST_DISABLE_MSC_WARNINGS_POP_()          \
    __pragma(warning(pop))
#else
// Not all compilers are MSVC
# define GTEST_DISABLE_MSC_WARNINGS_PUSH_(warnings)
# define GTEST_DISABLE_MSC_WARNINGS_POP_()
#endif

// Clang on Windows does not understand MSVC's pragma warning.
// We need clang-specific way to disable function deprecation warning.
#ifdef __clang__
# define GTEST_DISABLE_MSC_DEPRECATED_PUSH_() \
    _Pragma("clang diagnostic push") \
    _Pragma("clang diagnostic ignored \"-Wdeprecated-declarations\"") \
    _Pragma("clang diagnostic ignored \"-Wdeprecated-implementations\"")
#define GTEST_DISABLE_MSC_DEPRECATED_POP_() \
    _Pragma("clang diagnostic pop")
#else
# define GTEST_DISABLE_MSC_DEPRECATED_PUSH_() \
    GTEST_DISABLE_MSC_WARNINGS_PUSH_(4996)
# define GTEST_DISABLE_MSC_DEPRECATED_POP_() \
    GTEST_DISABLE_MSC_WARNINGS_POP_()
#endif

// Brings in definitions for functions used in the testing::internal::posix
// namespace (read, write, close, chdir, isatty, stat). We do not currently
// use them on Windows Mobile.
#if GTEST_OS_WINDOWS
# if !GTEST_OS_WINDOWS_MOBILE
#  include <direct.h>
#  include <io.h>
# endif
// In order to avoid having to include <windows.h>, use forward declaration
#if GTEST_OS_WINDOWS_MINGW && !defined(__MINGW64_VERSION_MAJOR)
// MinGW defined _CRITICAL_SECTION and _RTL_CRITICAL_SECTION as two
// separate (equivalent) structs, instead of using typedef
typedef struct _CRITICAL_SECTION GTEST_CRITICAL_SECTION;
#else
// Assume CRITICAL_SECTION is a typedef of _RTL_CRITICAL_SECTION.
// This assumption is verified by
// WindowsTypesTest.CRITICAL_SECTIONIs_RTL_CRITICAL_SECTION.
typedef struct _RTL_CRITICAL_SECTION GTEST_CRITICAL_SECTION;
#endif
#else
// This assumes that non-Windows OSes provide unistd.h. For OSes where this
// is not the case, we need to include headers that provide the functions
// mentioned above.
# include <unistd.h>
# include <strings.h>
#endif  // GTEST_OS_WINDOWS

#if GTEST_OS_LINUX_ANDROID
// Used to define __ANDROID_API__ matching the target NDK API level.
#  include <android/api-level.h>  // NOLINT
#endif

// Defines this to true if and only if Google Test can use POSIX regular
// expressions.
#ifndef GTEST_HAS_POSIX_RE
# if GTEST_OS_LINUX_ANDROID
// On Android, <regex.h> is only available starting with Gingerbread.
#  define GTEST_HAS_POSIX_RE (__ANDROID_API__ >= 9)
# else
#  define GTEST_HAS_POSIX_RE (!GTEST_OS_WINDOWS)
# endif
#endif

#if GTEST_USES_PCRE
// The appropriate headers have already been included.

#elif GTEST_HAS_POSIX_RE

// On some platforms, <regex.h> needs someone to define size_t, and
// won't compile otherwise.  We can #include it here as we already
// included <stdlib.h>, which is guaranteed to define size_t through
// <stddef.h>.
# include <regex.h>  // NOLINT

# define GTEST_USES_POSIX_RE 1

#elif GTEST_OS_WINDOWS

// <regex.h> is not available on Windows.  Use our own simple regex
// implementation instead.
# define GTEST_USES_SIMPLE_RE 1

#else

// <regex.h> may not be available on this platform.  Use our own
// simple regex implementation instead.
# define GTEST_USES_SIMPLE_RE 1

#endif  // GTEST_USES_PCRE

#ifndef GTEST_HAS_EXCEPTIONS
// The user didn't tell us whether exceptions are enabled, so we need
// to figure it out.
# if defined(_MSC_VER) && defined(_CPPUNWIND)
// MSVC defines _CPPUNWIND to 1 if and only if exceptions are enabled.
#  define GTEST_HAS_EXCEPTIONS 1
# elif defined(__BORLANDC__)
// C++Builder's implementation of the STL uses the _HAS_EXCEPTIONS
// macro to enable exceptions, so we'll do the same.
// Assumes that exceptions are enabled by default.
#  ifndef _HAS_EXCEPTIONS
#   define _HAS_EXCEPTIONS 1
#  endif  // _HAS_EXCEPTIONS
#  define GTEST_HAS_EXCEPTIONS _HAS_EXCEPTIONS
# elif defined(__clang__)
// clang defines __EXCEPTIONS if and only if exceptions are enabled before clang
// 220714, but if and only if cleanups are enabled after that. In Obj-C++ files,
// there can be cleanups for ObjC exceptions which also need cleanups, even if
// C++ exceptions are disabled. clang has __has_feature(cxx_exceptions) which
// checks for C++ exceptions starting at clang r206352, but which checked for
// cleanups prior to that. To reliably check for C++ exception availability with
// clang, check for
// __EXCEPTIONS && __has_feature(cxx_exceptions).
#  define GTEST_HAS_EXCEPTIONS (__EXCEPTIONS && __has_feature(cxx_exceptions))
# elif defined(__GNUC__) && __EXCEPTIONS
// gcc defines __EXCEPTIONS to 1 if and only if exceptions are enabled.
#  define GTEST_HAS_EXCEPTIONS 1
# elif defined(__SUNPRO_CC)
// Sun Pro CC supports exceptions.  However, there is no compile-time way of
// detecting whether they are enabled or not.  Therefore, we assume that
// they are enabled unless the user tells us otherwise.
#  define GTEST_HAS_EXCEPTIONS 1
# elif defined(__IBMCPP__) && __EXCEPTIONS
// xlC defines __EXCEPTIONS to 1 if and only if exceptions are enabled.
#  define GTEST_HAS_EXCEPTIONS 1
# elif defined(__HP_aCC)
// Exception handling is in effect by default in HP aCC compiler. It has to
// be turned off by +noeh compiler option if desired.
#  define GTEST_HAS_EXCEPTIONS 1
# else
// For other compilers, we assume exceptions are disabled to be
// conservative.
#  define GTEST_HAS_EXCEPTIONS 0
# endif  // defined(_MSC_VER) || defined(__BORLANDC__)
#endif  // GTEST_HAS_EXCEPTIONS

#if !defined(GTEST_HAS_STD_STRING)
// Even though we don't use this macro any longer, we keep it in case
// some clients still depend on it.
# define GTEST_HAS_STD_STRING 1
#elif !GTEST_HAS_STD_STRING
// The user told us that ::std::string isn't available.
# error "::std::string isn't available."
#endif  // !defined(GTEST_HAS_STD_STRING)

#ifndef GTEST_HAS_STD_WSTRING
// The user didn't tell us whether ::std::wstring is available, so we need
// to figure it out.
// Cygwin 1.7 and below doesn't support ::std::wstring.
// Solaris' libc++ doesn't support it either.  Android has
// no support for it at least as recent as Froyo (2.2).
#define GTEST_HAS_STD_WSTRING \
  (!(GTEST_OS_LINUX_ANDROID || GTEST_OS_CYGWIN || GTEST_OS_SOLARIS || \
     GTEST_OS_HAIKU))

#endif  // GTEST_HAS_STD_WSTRING

// Determines whether RTTI is available.
#ifndef GTEST_HAS_RTTI
// The user didn't tell us whether RTTI is enabled, so we need to
// figure it out.

# ifdef _MSC_VER

#  ifdef _CPPRTTI  // MSVC defines this macro if and only if RTTI is enabled.
#   define GTEST_HAS_RTTI 1
#  else
#   define GTEST_HAS_RTTI 0
#  endif

// Starting with version 4.3.2, gcc defines __GXX_RTTI if and only if RTTI is
// enabled.
# elif defined(__GNUC__)

#  ifdef __GXX_RTTI
// When building against STLport with the Android NDK and with
// -frtti -fno-exceptions, the build fails at link time with undefined
// references to __cxa_bad_typeid. Not sure if STL or toolchain bug,
// so disable RTTI when detected.
#   if GTEST_OS_LINUX_ANDROID && defined(_STLPORT_MAJOR) && \
       !defined(__EXCEPTIONS)
#    define GTEST_HAS_RTTI 0
#   else
#    define GTEST_HAS_RTTI 1
#   endif  // GTEST_OS_LINUX_ANDROID && __STLPORT_MAJOR && !__EXCEPTIONS
#  else
#   define GTEST_HAS_RTTI 0
#  endif  // __GXX_RTTI

// Clang defines __GXX_RTTI starting with version 3.0, but its manual recommends
// using has_feature instead. has_feature(cxx_rtti) is supported since 2.7, the
// first version with C++ support.
# elif defined(__clang__)

#  define GTEST_HAS_RTTI __has_feature(cxx_rtti)

// Starting with version 9.0 IBM Visual Age defines __RTTI_ALL__ to 1 if
// both the typeid and dynamic_cast features are present.
# elif defined(__IBMCPP__) && (__IBMCPP__ >= 900)

#  ifdef __RTTI_ALL__
#   define GTEST_HAS_RTTI 1
#  else
#   define GTEST_HAS_RTTI 0
#  endif

# else

// For all other compilers, we assume RTTI is enabled.
#  define GTEST_HAS_RTTI 1

# endif  // _MSC_VER

#endif  // GTEST_HAS_RTTI

// It's this header's responsibility to #include <typeinfo> when RTTI
// is enabled.
#if GTEST_HAS_RTTI
# include <typeinfo>
#endif

// Determines whether Google Test can use the pthreads library.
#ifndef GTEST_HAS_PTHREAD
// The user didn't tell us explicitly, so we make reasonable assumptions about
// which platforms have pthreads support.
//
// To disable threading support in Google Test, add -DGTEST_HAS_PTHREAD=0
// to your compiler flags.
#define GTEST_HAS_PTHREAD \
  (GTEST_OS_LINUX || GTEST_OS_MAC || GTEST_OS_HPUX || GTEST_OS_QNX || \
   GTEST_OS_FREEBSD || GTEST_OS_NACL || GTEST_OS_NETBSD || GTEST_OS_FUCHSIA || \
   GTEST_OS_DRAGONFLY || GTEST_OS_GNU_KFREEBSD || GTEST_OS_OPENBSD || \
   GTEST_OS_HAIKU)
#endif  // GTEST_HAS_PTHREAD

#if GTEST_HAS_PTHREAD
// gtest-port.h guarantees to #include <pthread.h> when GTEST_HAS_PTHREAD is
// true.
# include <pthread.h>  // NOLINT

// For timespec and nanosleep, used below.
# include <time.h>  // NOLINT
#endif

// Determines whether clone(2) is supported.
// Usually it will only be available on Linux, excluding
// Linux on the Itanium architecture.
// Also see http://linux.die.net/man/2/clone.
#ifndef GTEST_HAS_CLONE
// The user didn't tell us, so we need to figure it out.

# if GTEST_OS_LINUX && !defined(__ia64__)
#  if GTEST_OS_LINUX_ANDROID
// On Android, clone() became available at different API levels for each 32-bit
// architecture.
#    if defined(__LP64__) || \
        (defined(__arm__) && __ANDROID_API__ >= 9) || \
        (defined(__mips__) && __ANDROID_API__ >= 12) || \
        (defined(__i386__) && __ANDROID_API__ >= 17)
#     define GTEST_HAS_CLONE 1
#    else
#     define GTEST_HAS_CLONE 0
#    endif
#  else
#   define GTEST_HAS_CLONE 1
#  endif
# else
#  define GTEST_HAS_CLONE 0
# endif  // GTEST_OS_LINUX && !defined(__ia64__)

#endif  // GTEST_HAS_CLONE

// Determines whether to support stream redirection. This is used to test
// output correctness and to implement death tests.
#ifndef GTEST_HAS_STREAM_REDIRECTION
// By default, we assume that stream redirection is supported on all
// platforms except known mobile ones.
# if GTEST_OS_WINDOWS_MOBILE || GTEST_OS_WINDOWS_PHONE || GTEST_OS_WINDOWS_RT
#  define GTEST_HAS_STREAM_REDIRECTION 0
# else
#  define GTEST_HAS_STREAM_REDIRECTION 1
# endif  // !GTEST_OS_WINDOWS_MOBILE
#endif  // GTEST_HAS_STREAM_REDIRECTION

// Determines whether to support death tests.
// pops up a dialog window that cannot be suppressed programmatically.
#if (GTEST_OS_LINUX || GTEST_OS_CYGWIN || GTEST_OS_SOLARIS || \
     (GTEST_OS_MAC && !GTEST_OS_IOS) || \
     (GTEST_OS_WINDOWS_DESKTOP && _MSC_VER) || GTEST_OS_WINDOWS_MINGW || \
     GTEST_OS_AIX || GTEST_OS_HPUX || GTEST_OS_OPENBSD || GTEST_OS_QNX || \
     GTEST_OS_FREEBSD || GTEST_OS_NETBSD || GTEST_OS_FUCHSIA || \
     GTEST_OS_DRAGONFLY || GTEST_OS_GNU_KFREEBSD || GTEST_OS_HAIKU)
# define GTEST_HAS_DEATH_TEST 1
#endif

// Determines whether to support type-driven tests.

// Typed tests need <typeinfo> and variadic macros, which GCC, VC++ 8.0,
// Sun Pro CC, IBM Visual Age, and HP aCC support.
#if defined(__GNUC__) || defined(_MSC_VER) || defined(__SUNPRO_CC) || \
    defined(__IBMCPP__) || defined(__HP_aCC)
# define GTEST_HAS_TYPED_TEST 1
# define GTEST_HAS_TYPED_TEST_P 1
#endif

// Determines whether the system compiler uses UTF-16 for encoding wide strings.
#define GTEST_WIDE_STRING_USES_UTF16_ \
  (GTEST_OS_WINDOWS || GTEST_OS_CYGWIN || GTEST_OS_AIX || GTEST_OS_OS2)

// Determines whether test results can be streamed to a socket.
#if GTEST_OS_LINUX || GTEST_OS_GNU_KFREEBSD || GTEST_OS_DRAGONFLY || \
    GTEST_OS_FREEBSD || GTEST_OS_NETBSD || GTEST_OS_OPENBSD
# define GTEST_CAN_STREAM_RESULTS_ 1
#endif

// Defines some utility macros.

// The GNU compiler emits a warning if nested "if" statements are followed by
// an "else" statement and braces are not used to explicitly disambiguate the
// "else" binding.  This leads to problems with code like:
//
//   if (gate)
//     ASSERT_*(condition) << "Some message";
//
// The "switch (0) case 0:" idiom is used to suppress this.
#ifdef __INTEL_COMPILER
# define GTEST_AMBIGUOUS_ELSE_BLOCKER_
#else
# define GTEST_AMBIGUOUS_ELSE_BLOCKER_ switch (0) case 0: default:  // NOLINT
#endif

// Use this annotation at the end of a struct/class definition to
// prevent the compiler from optimizing away instances that are never
// used.  This is useful when all interesting logic happens inside the
// c'tor and / or d'tor.  Example:
//
//   struct Foo {
//     Foo() { ... }
//   } GTEST_ATTRIBUTE_UNUSED_;
//
// Also use it after a variable or parameter declaration to tell the
// compiler the variable/parameter does not have to be used.
#if defined(__GNUC__) && !defined(COMPILER_ICC)
# define GTEST_ATTRIBUTE_UNUSED_ __attribute__ ((unused))
#elif defined(__clang__)
# if __has_attribute(unused)
#  define GTEST_ATTRIBUTE_UNUSED_ __attribute__ ((unused))
# endif
#endif
// Fallback: expands to nothing when no branch above defined the macro.
#ifndef GTEST_ATTRIBUTE_UNUSED_
# define GTEST_ATTRIBUTE_UNUSED_
#endif

// Use this annotation before a function that takes a printf format string.
#if (defined(__GNUC__) || defined(__clang__)) && !defined(COMPILER_ICC)
# if defined(__MINGW_PRINTF_FORMAT)
// MinGW has two different printf implementations. Ensure the format macro
// matches the selected implementation. See
// https://sourceforge.net/p/mingw-w64/wiki2/gnu%20printf/.
#  define GTEST_ATTRIBUTE_PRINTF_(string_index, first_to_check) \
       __attribute__((__format__(__MINGW_PRINTF_FORMAT, string_index, \
                                 first_to_check)))
# else
#  define GTEST_ATTRIBUTE_PRINTF_(string_index, first_to_check) \
       __attribute__((__format__(__printf__, string_index, first_to_check)))
# endif
#else
# define GTEST_ATTRIBUTE_PRINTF_(string_index, first_to_check)
#endif

// A macro to disallow operator=
// This should be used in the private: declarations for a class.
// Note: the expansion has no trailing semicolon; the use site supplies it.
#define GTEST_DISALLOW_ASSIGN_(type) \
  void operator=(type const &) = delete

// A macro to disallow copy constructor and operator=
// This should be used in the private: declarations for a class.
#define GTEST_DISALLOW_COPY_AND_ASSIGN_(type) \
  type(type const &) = delete; \
  GTEST_DISALLOW_ASSIGN_(type)

// Tell the compiler to warn about unused return values for functions declared
// with this macro.  The macro should be used on function declarations
// following the argument list:
//
//   Sprocket* AllocateSprocket() GTEST_MUST_USE_RESULT_;
#if defined(__GNUC__) && !defined(COMPILER_ICC)
# define GTEST_MUST_USE_RESULT_ __attribute__ ((warn_unused_result))
#else
# define GTEST_MUST_USE_RESULT_
#endif  // __GNUC__ && !COMPILER_ICC

// MS C++ compiler emits warning when a conditional expression is compile time
// constant. In some contexts this warning is false positive and needs to be
// suppressed. Use the following two macros in such cases:
//
// GTEST_INTENTIONAL_CONST_COND_PUSH_()
// while (true) {
// GTEST_INTENTIONAL_CONST_COND_POP_()
// }
# define GTEST_INTENTIONAL_CONST_COND_PUSH_() \
    GTEST_DISABLE_MSC_WARNINGS_PUSH_(4127)
# define GTEST_INTENTIONAL_CONST_COND_POP_() \
    GTEST_DISABLE_MSC_WARNINGS_POP_()

// Determine whether the compiler supports Microsoft's Structured Exception
// Handling.  This is supported by several Windows compilers but generally
// does not exist on any other system.
#ifndef GTEST_HAS_SEH
// The user didn't tell us, so we need to figure it out.

# if defined(_MSC_VER) || defined(__BORLANDC__)
// These two compilers are known to support SEH.
#  define GTEST_HAS_SEH 1
# else
// Assume no SEH.
#  define GTEST_HAS_SEH 0
# endif

#endif  // GTEST_HAS_SEH

#ifndef GTEST_IS_THREADSAFE

#define GTEST_IS_THREADSAFE \
  (GTEST_HAS_MUTEX_AND_THREAD_LOCAL_ || \
   (GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_PHONE && !GTEST_OS_WINDOWS_RT) || \
   GTEST_HAS_PTHREAD)

#endif  // GTEST_IS_THREADSAFE

// GTEST_API_ qualifies all symbols that must be exported. The definitions below
// are guarded by #ifndef to give embedders a chance to define GTEST_API_ in
// gtest/internal/custom/gtest-port.h
#ifndef GTEST_API_

#ifdef _MSC_VER
# if GTEST_LINKED_AS_SHARED_LIBRARY
#  define GTEST_API_ __declspec(dllimport)
# elif GTEST_CREATE_SHARED_LIBRARY
#  define GTEST_API_ __declspec(dllexport)
# endif
#elif __GNUC__ >= 4 || defined(__clang__)
# define GTEST_API_ __attribute__((visibility ("default")))
#endif  // _MSC_VER

#endif  // GTEST_API_

// Fallback: GTEST_API_ expands to nothing when no branch above defined it.
#ifndef GTEST_API_
# define GTEST_API_
#endif  // GTEST_API_

#ifndef GTEST_DEFAULT_DEATH_TEST_STYLE
# define GTEST_DEFAULT_DEATH_TEST_STYLE  "fast"
#endif  // GTEST_DEFAULT_DEATH_TEST_STYLE

#ifdef __GNUC__
// Ask the compiler to never inline a given function.
# define GTEST_NO_INLINE_ __attribute__((noinline))
#else
# define GTEST_NO_INLINE_
#endif

// _LIBCPP_VERSION is defined by the libc++ library from the LLVM project.
#if !defined(GTEST_HAS_CXXABI_H_)
# if defined(__GLIBCXX__) || (defined(_LIBCPP_VERSION) && !defined(_MSC_VER))
#  define GTEST_HAS_CXXABI_H_ 1
# else
#  define GTEST_HAS_CXXABI_H_ 0
# endif
#endif

// A function level attribute to disable checking for use of uninitialized
// memory when built with MemorySanitizer.
#if defined(__clang__)
# if __has_feature(memory_sanitizer)
#  define GTEST_ATTRIBUTE_NO_SANITIZE_MEMORY_ \
       __attribute__((no_sanitize_memory))
# else
#  define GTEST_ATTRIBUTE_NO_SANITIZE_MEMORY_
# endif  // __has_feature(memory_sanitizer)
#else
# define GTEST_ATTRIBUTE_NO_SANITIZE_MEMORY_
#endif  // __clang__

// A function level attribute to disable AddressSanitizer instrumentation.
#if defined(__clang__)
# if __has_feature(address_sanitizer)
#  define GTEST_ATTRIBUTE_NO_SANITIZE_ADDRESS_ \
       __attribute__((no_sanitize_address))
# else
#  define GTEST_ATTRIBUTE_NO_SANITIZE_ADDRESS_
# endif  // __has_feature(address_sanitizer)
#else
# define GTEST_ATTRIBUTE_NO_SANITIZE_ADDRESS_
#endif  // __clang__

// A function level attribute to disable HWAddressSanitizer instrumentation.
#if defined(__clang__)
# if __has_feature(hwaddress_sanitizer)
#  define GTEST_ATTRIBUTE_NO_SANITIZE_HWADDRESS_ \
       __attribute__((no_sanitize("hwaddress")))
# else
#  define GTEST_ATTRIBUTE_NO_SANITIZE_HWADDRESS_
# endif  // __has_feature(hwaddress_sanitizer)
#else
# define GTEST_ATTRIBUTE_NO_SANITIZE_HWADDRESS_
#endif  // __clang__

// A function level attribute to disable ThreadSanitizer instrumentation.
#if defined(__clang__)
# if __has_feature(thread_sanitizer)
#  define GTEST_ATTRIBUTE_NO_SANITIZE_THREAD_ \
       __attribute__((no_sanitize_thread))
# else
#  define GTEST_ATTRIBUTE_NO_SANITIZE_THREAD_
# endif  // __has_feature(thread_sanitizer)
#else
# define GTEST_ATTRIBUTE_NO_SANITIZE_THREAD_
#endif  // __clang__

namespace testing {

// Forward declaration only; Message is defined elsewhere.
class Message;

// Legacy imports for backwards compatibility.
// New code should use std:: names directly.
using std::get;
using std::make_tuple;
using std::tuple;
using std::tuple_element;
using std::tuple_size;

namespace internal {

// A secret type that Google Test users don't know about.  It has no
// definition on purpose.  Therefore it's impossible to create a
// Secret object, which is what we want.
class Secret;

// The GTEST_COMPILE_ASSERT_ is a legacy macro used to verify that a compile
// time expression is true (in new code, use static_assert instead). For
// example, you could use it to verify the size of a static array:
//
//   GTEST_COMPILE_ASSERT_(GTEST_ARRAY_SIZE_(names) == NUM_NAMES,
//                         names_incorrect_size);
//
// The second argument to the macro must be a valid C++ identifier. If the
// expression is false, compiler will issue an error containing this identifier.
#define GTEST_COMPILE_ASSERT_(expr, msg) static_assert(expr, #msg)

// Evaluates to the number of elements in 'array'.
// 'array' must be an actual array, not a pointer: the expansion divides
// sizeof(array) by the size of its first element.
#define GTEST_ARRAY_SIZE_(array) (sizeof(array) / sizeof(array[0]))

// A helper for suppressing warnings on constant condition.  It just
// returns 'condition'.
GTEST_API_ bool IsTrue(bool condition);

// Defines RE.
#if GTEST_USES_PCRE
// if used, PCRE is injected by custom/gtest-port.h
#elif GTEST_USES_POSIX_RE || GTEST_USES_SIMPLE_RE

// A simple C++ wrapper for <regex.h>.  It uses the POSIX Extended
// Regular Expression syntax.
class GTEST_API_ RE {
 public:
  // A copy constructor is required by the Standard to initialize object
  // references from r-values.
  // Copies by re-running Init() on the other object's pattern string.
  RE(const RE& other) { Init(other.pattern()); }

  // Constructs an RE from a string.
  RE(const ::std::string& regex) { Init(regex.c_str()); }  // NOLINT

  RE(const char* regex) { Init(regex); }  // NOLINT
  ~RE();

  // Returns the string representation of the regex.
  const char* pattern() const { return pattern_; }

  // FullMatch(str, re) returns true if and only if regular expression re
  // matches the entire str.
  // PartialMatch(str, re) returns true if and only if regular expression re
  // matches a substring of str (including str itself).
  // The ::std::string overloads forward to the const char* overloads.
  static bool FullMatch(const ::std::string& str, const RE& re) {
    return FullMatch(str.c_str(), re);
  }
  static bool PartialMatch(const ::std::string& str, const RE& re) {
    return PartialMatch(str.c_str(), re);
  }

  static bool FullMatch(const char* str, const RE& re);
  static bool PartialMatch(const char* str, const RE& re);

 private:
  // Shared initialisation used by all three constructors; defined elsewhere.
  void Init(const char* regex);
  const char* pattern_;
  bool is_valid_;

# if GTEST_USES_POSIX_RE

  regex_t full_regex_;     // For FullMatch().
  regex_t partial_regex_;  // For PartialMatch().

# else  // GTEST_USES_SIMPLE_RE

  const char* full_pattern_;  // For FullMatch();

# endif

  GTEST_DISALLOW_ASSIGN_(RE);
};

#endif  // GTEST_USES_PCRE

// Formats a source file path and a line number as they would appear
// in an error message from the compiler used to compile this code.
GTEST_API_ ::std::string FormatFileLocation(const char* file, int line);

// Formats a file location for compiler-independent XML output.
// Although this function is not platform dependent, we put it next to
// FormatFileLocation in order to contrast the two functions.
GTEST_API_ ::std::string FormatCompilerIndependentFileLocation(const char* file, int line); // Defines logging utilities: // GTEST_LOG_(severity) - logs messages at the specified severity level. The // message itself is streamed into the macro. // LogToStderr() - directs all log messages to stderr. // FlushInfoLog() - flushes informational log messages. enum GTestLogSeverity { GTEST_INFO, GTEST_WARNING, GTEST_ERROR, GTEST_FATAL }; // Formats log entry severity, provides a stream object for streaming the // log message, and terminates the message with a newline when going out of // scope. class GTEST_API_ GTestLog { public: GTestLog(GTestLogSeverity severity, const char* file, int line); // Flushes the buffers and, if severity is GTEST_FATAL, aborts the program. ~GTestLog(); ::std::ostream& GetStream() { return ::std::cerr; } private: const GTestLogSeverity severity_; GTEST_DISALLOW_COPY_AND_ASSIGN_(GTestLog); }; #if !defined(GTEST_LOG_) # define GTEST_LOG_(severity) \ ::testing::internal::GTestLog(::testing::internal::GTEST_##severity, \ __FILE__, __LINE__).GetStream() inline void LogToStderr() {} inline void FlushInfoLog() { fflush(nullptr); } #endif // !defined(GTEST_LOG_) #if !defined(GTEST_CHECK_) // INTERNAL IMPLEMENTATION - DO NOT USE. // // GTEST_CHECK_ is an all-mode assert. It aborts the program if the condition // is not satisfied. // Synopsys: // GTEST_CHECK_(boolean_condition); // or // GTEST_CHECK_(boolean_condition) << "Additional message"; // // This checks the condition and if the condition is not satisfied // it prints message about the condition violation, including the // condition itself, plus additional message streamed into it, if any, // and then it aborts the program. It aborts the program irrespective of // whether it is built in the debug mode or not. # define GTEST_CHECK_(condition) \ GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ if (::testing::internal::IsTrue(condition)) \ ; \ else \ GTEST_LOG_(FATAL) << "Condition " #condition " failed. 
" #endif // !defined(GTEST_CHECK_) // An all-mode assert to verify that the given POSIX-style function // call returns 0 (indicating success). Known limitation: this // doesn't expand to a balanced 'if' statement, so enclose the macro // in {} if you need to use it as the only statement in an 'if' // branch. #define GTEST_CHECK_POSIX_SUCCESS_(posix_call) \ if (const int gtest_error = (posix_call)) \ GTEST_LOG_(FATAL) << #posix_call << "failed with error " \ << gtest_error // Transforms "T" into "const T&" according to standard reference collapsing // rules (this is only needed as a backport for C++98 compilers that do not // support reference collapsing). Specifically, it transforms: // // char ==> const char& // const char ==> const char& // char& ==> char& // const char& ==> const char& // // Note that the non-const reference will not have "const" added. This is // standard, and necessary so that "T" can always bind to "const T&". template struct ConstRef { typedef const T& type; }; template struct ConstRef { typedef T& type; }; // The argument T must depend on some template parameters. #define GTEST_REFERENCE_TO_CONST_(T) \ typename ::testing::internal::ConstRef::type // INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. // // Use ImplicitCast_ as a safe version of static_cast for upcasting in // the type hierarchy (e.g. casting a Foo* to a SuperclassOfFoo* or a // const Foo*). When you use ImplicitCast_, the compiler checks that // the cast is safe. Such explicit ImplicitCast_s are necessary in // surprisingly many situations where C++ demands an exact type match // instead of an argument type convertable to a target type. // // The syntax for using ImplicitCast_ is the same as for static_cast: // // ImplicitCast_(expr) // // ImplicitCast_ would have been part of the C++ standard library, // but the proposal was submitted too late. It will probably make // its way into the language in the future. // // This relatively ugly name is intentional. 
It prevents clashes with // similar functions users may have (e.g., implicit_cast). The internal // namespace alone is not enough because the function can be found by ADL. template inline To ImplicitCast_(To x) { return x; } // When you upcast (that is, cast a pointer from type Foo to type // SuperclassOfFoo), it's fine to use ImplicitCast_<>, since upcasts // always succeed. When you downcast (that is, cast a pointer from // type Foo to type SubclassOfFoo), static_cast<> isn't safe, because // how do you know the pointer is really of type SubclassOfFoo? It // could be a bare Foo, or of type DifferentSubclassOfFoo. Thus, // when you downcast, you should use this macro. In debug mode, we // use dynamic_cast<> to double-check the downcast is legal (we die // if it's not). In normal mode, we do the efficient static_cast<> // instead. Thus, it's important to test in debug mode to make sure // the cast is legal! // This is the only place in the code we should use dynamic_cast<>. // In particular, you SHOULDN'T be using dynamic_cast<> in order to // do RTTI (eg code like this: // if (dynamic_cast(foo)) HandleASubclass1Object(foo); // if (dynamic_cast(foo)) HandleASubclass2Object(foo); // You should design the code some other way not to need this. // // This relatively ugly name is intentional. It prevents clashes with // similar functions users may have (e.g., down_cast). The internal // namespace alone is not enough because the function can be found by ADL. template // use like this: DownCast_(foo); inline To DownCast_(From* f) { // so we only accept pointers // Ensures that To is a sub-type of From *. This test is here only // for compile-time type checking, and has no overhead in an // optimized build at run-time, as it will be optimized away // completely. GTEST_INTENTIONAL_CONST_COND_PUSH_() if (false) { GTEST_INTENTIONAL_CONST_COND_POP_() const To to = nullptr; ::testing::internal::ImplicitCast_(to); } #if GTEST_HAS_RTTI // RTTI: debug mode only! 
GTEST_CHECK_(f == nullptr || dynamic_cast(f) != nullptr); #endif return static_cast(f); } // Downcasts the pointer of type Base to Derived. // Derived must be a subclass of Base. The parameter MUST // point to a class of type Derived, not any subclass of it. // When RTTI is available, the function performs a runtime // check to enforce this. template Derived* CheckedDowncastToActualType(Base* base) { #if GTEST_HAS_RTTI GTEST_CHECK_(typeid(*base) == typeid(Derived)); #endif #if GTEST_HAS_DOWNCAST_ return ::down_cast(base); #elif GTEST_HAS_RTTI return dynamic_cast(base); // NOLINT #else return static_cast(base); // Poor man's downcast. #endif } #if GTEST_HAS_STREAM_REDIRECTION // Defines the stderr capturer: // CaptureStdout - starts capturing stdout. // GetCapturedStdout - stops capturing stdout and returns the captured string. // CaptureStderr - starts capturing stderr. // GetCapturedStderr - stops capturing stderr and returns the captured string. // GTEST_API_ void CaptureStdout(); GTEST_API_ std::string GetCapturedStdout(); GTEST_API_ void CaptureStderr(); GTEST_API_ std::string GetCapturedStderr(); #endif // GTEST_HAS_STREAM_REDIRECTION // Returns the size (in bytes) of a file. GTEST_API_ size_t GetFileSize(FILE* file); // Reads the entire content of a file as a string. GTEST_API_ std::string ReadEntireFile(FILE* file); // All command line arguments. GTEST_API_ std::vector GetArgvs(); #if GTEST_HAS_DEATH_TEST std::vector GetInjectableArgvs(); // Deprecated: pass the args vector by value instead. void SetInjectableArgvs(const std::vector* new_argvs); void SetInjectableArgvs(const std::vector& new_argvs); void ClearInjectableArgvs(); #endif // GTEST_HAS_DEATH_TEST // Defines synchronization primitives. #if GTEST_IS_THREADSAFE # if GTEST_HAS_PTHREAD // Sleeps for (roughly) n milliseconds. This function is only for testing // Google Test's own constructs. Don't use it in user tests, either // directly or indirectly. 
inline void SleepMilliseconds(int n) { const timespec time = { 0, // 0 seconds. n * 1000L * 1000L, // And n ms. }; nanosleep(&time, nullptr); } # endif // GTEST_HAS_PTHREAD # if GTEST_HAS_NOTIFICATION_ // Notification has already been imported into the namespace. // Nothing to do here. # elif GTEST_HAS_PTHREAD // Allows a controller thread to pause execution of newly created // threads until notified. Instances of this class must be created // and destroyed in the controller thread. // // This class is only for testing Google Test's own constructs. Do not // use it in user tests, either directly or indirectly. class Notification { public: Notification() : notified_(false) { GTEST_CHECK_POSIX_SUCCESS_(pthread_mutex_init(&mutex_, nullptr)); } ~Notification() { pthread_mutex_destroy(&mutex_); } // Notifies all threads created with this notification to start. Must // be called from the controller thread. void Notify() { pthread_mutex_lock(&mutex_); notified_ = true; pthread_mutex_unlock(&mutex_); } // Blocks until the controller thread notifies. Must be called from a test // thread. void WaitForNotification() { for (;;) { pthread_mutex_lock(&mutex_); const bool notified = notified_; pthread_mutex_unlock(&mutex_); if (notified) break; SleepMilliseconds(10); } } private: pthread_mutex_t mutex_; bool notified_; GTEST_DISALLOW_COPY_AND_ASSIGN_(Notification); }; # elif GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_PHONE && !GTEST_OS_WINDOWS_RT GTEST_API_ void SleepMilliseconds(int n); // Provides leak-safe Windows kernel handle ownership. // Used in death tests and in threading support. class GTEST_API_ AutoHandle { public: // Assume that Win32 HANDLE type is equivalent to void*. Doing so allows us to // avoid including in this header file. Including is // undesirable because it defines a lot of symbols and macros that tend to // conflict with client code. This assumption is verified by // WindowsTypesTest.HANDLEIsVoidStar. 
typedef void* Handle; AutoHandle(); explicit AutoHandle(Handle handle); ~AutoHandle(); Handle Get() const; void Reset(); void Reset(Handle handle); private: // Returns true if and only if the handle is a valid handle object that can be // closed. bool IsCloseable() const; Handle handle_; GTEST_DISALLOW_COPY_AND_ASSIGN_(AutoHandle); }; // Allows a controller thread to pause execution of newly created // threads until notified. Instances of this class must be created // and destroyed in the controller thread. // // This class is only for testing Google Test's own constructs. Do not // use it in user tests, either directly or indirectly. class GTEST_API_ Notification { public: Notification(); void Notify(); void WaitForNotification(); private: AutoHandle event_; GTEST_DISALLOW_COPY_AND_ASSIGN_(Notification); }; # endif // GTEST_HAS_NOTIFICATION_ // On MinGW, we can have both GTEST_OS_WINDOWS and GTEST_HAS_PTHREAD // defined, but we don't want to use MinGW's pthreads implementation, which // has conformance problems with some versions of the POSIX standard. # if GTEST_HAS_PTHREAD && !GTEST_OS_WINDOWS_MINGW // As a C-function, ThreadFuncWithCLinkage cannot be templated itself. // Consequently, it cannot select a correct instantiation of ThreadWithParam // in order to call its Run(). Introducing ThreadWithParamBase as a // non-templated base class for ThreadWithParam allows us to bypass this // problem. class ThreadWithParamBase { public: virtual ~ThreadWithParamBase() {} virtual void Run() = 0; }; // pthread_create() accepts a pointer to a function type with the C linkage. // According to the Standard (7.5/1), function types with different linkages // are different even if they are otherwise identical. Some compilers (for // example, SunStudio) treat them as different types. Since class methods // cannot be defined with C-linkage we need to define a free C-function to // pass into pthread_create(). 
extern "C" inline void* ThreadFuncWithCLinkage(void* thread) { static_cast(thread)->Run(); return nullptr; } // Helper class for testing Google Test's multi-threading constructs. // To use it, write: // // void ThreadFunc(int param) { /* Do things with param */ } // Notification thread_can_start; // ... // // The thread_can_start parameter is optional; you can supply NULL. // ThreadWithParam thread(&ThreadFunc, 5, &thread_can_start); // thread_can_start.Notify(); // // These classes are only for testing Google Test's own constructs. Do // not use them in user tests, either directly or indirectly. template class ThreadWithParam : public ThreadWithParamBase { public: typedef void UserThreadFunc(T); ThreadWithParam(UserThreadFunc* func, T param, Notification* thread_can_start) : func_(func), param_(param), thread_can_start_(thread_can_start), finished_(false) { ThreadWithParamBase* const base = this; // The thread can be created only after all fields except thread_ // have been initialized. GTEST_CHECK_POSIX_SUCCESS_( pthread_create(&thread_, nullptr, &ThreadFuncWithCLinkage, base)); } ~ThreadWithParam() override { Join(); } void Join() { if (!finished_) { GTEST_CHECK_POSIX_SUCCESS_(pthread_join(thread_, nullptr)); finished_ = true; } } void Run() override { if (thread_can_start_ != nullptr) thread_can_start_->WaitForNotification(); func_(param_); } private: UserThreadFunc* const func_; // User-supplied thread function. const T param_; // User-supplied parameter to the thread function. // When non-NULL, used to block execution until the controller thread // notifies. Notification* const thread_can_start_; bool finished_; // true if and only if we know that the thread function has // finished. pthread_t thread_; // The native thread object. 
GTEST_DISALLOW_COPY_AND_ASSIGN_(ThreadWithParam); }; # endif // !GTEST_OS_WINDOWS && GTEST_HAS_PTHREAD || // GTEST_HAS_MUTEX_AND_THREAD_LOCAL_ # if GTEST_HAS_MUTEX_AND_THREAD_LOCAL_ // Mutex and ThreadLocal have already been imported into the namespace. // Nothing to do here. # elif GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_PHONE && !GTEST_OS_WINDOWS_RT // Mutex implements mutex on Windows platforms. It is used in conjunction // with class MutexLock: // // Mutex mutex; // ... // MutexLock lock(&mutex); // Acquires the mutex and releases it at the // // end of the current scope. // // A static Mutex *must* be defined or declared using one of the following // macros: // GTEST_DEFINE_STATIC_MUTEX_(g_some_mutex); // GTEST_DECLARE_STATIC_MUTEX_(g_some_mutex); // // (A non-static Mutex is defined/declared in the usual way). class GTEST_API_ Mutex { public: enum MutexType { kStatic = 0, kDynamic = 1 }; // We rely on kStaticMutex being 0 as it is to what the linker initializes // type_ in static mutexes. critical_section_ will be initialized lazily // in ThreadSafeLazyInit(). enum StaticConstructorSelector { kStaticMutex = 0 }; // This constructor intentionally does nothing. It relies on type_ being // statically initialized to 0 (effectively setting it to kStatic) and on // ThreadSafeLazyInit() to lazily initialize the rest of the members. explicit Mutex(StaticConstructorSelector /*dummy*/) {} Mutex(); ~Mutex(); void Lock(); void Unlock(); // Does nothing if the current thread holds the mutex. Otherwise, crashes // with high probability. void AssertHeld(); private: // Initializes owner_thread_id_ and critical_section_ in static mutexes. void ThreadSafeLazyInit(); // Per https://blogs.msdn.microsoft.com/oldnewthing/20040223-00/?p=40503, // we assume that 0 is an invalid value for thread IDs. unsigned int owner_thread_id_; // For static mutexes, we rely on these members being initialized to zeros // by the linker. 
MutexType type_; long critical_section_init_phase_; // NOLINT GTEST_CRITICAL_SECTION* critical_section_; GTEST_DISALLOW_COPY_AND_ASSIGN_(Mutex); }; # define GTEST_DECLARE_STATIC_MUTEX_(mutex) \ extern ::testing::internal::Mutex mutex # define GTEST_DEFINE_STATIC_MUTEX_(mutex) \ ::testing::internal::Mutex mutex(::testing::internal::Mutex::kStaticMutex) // We cannot name this class MutexLock because the ctor declaration would // conflict with a macro named MutexLock, which is defined on some // platforms. That macro is used as a defensive measure to prevent against // inadvertent misuses of MutexLock like "MutexLock(&mu)" rather than // "MutexLock l(&mu)". Hence the typedef trick below. class GTestMutexLock { public: explicit GTestMutexLock(Mutex* mutex) : mutex_(mutex) { mutex_->Lock(); } ~GTestMutexLock() { mutex_->Unlock(); } private: Mutex* const mutex_; GTEST_DISALLOW_COPY_AND_ASSIGN_(GTestMutexLock); }; typedef GTestMutexLock MutexLock; // Base class for ValueHolder. Allows a caller to hold and delete a value // without knowing its type. class ThreadLocalValueHolderBase { public: virtual ~ThreadLocalValueHolderBase() {} }; // Provides a way for a thread to send notifications to a ThreadLocal // regardless of its parameter type. class ThreadLocalBase { public: // Creates a new ValueHolder object holding a default value passed to // this ThreadLocal's constructor and returns it. It is the caller's // responsibility not to call this when the ThreadLocal instance already // has a value on the current thread. virtual ThreadLocalValueHolderBase* NewValueForCurrentThread() const = 0; protected: ThreadLocalBase() {} virtual ~ThreadLocalBase() {} private: GTEST_DISALLOW_COPY_AND_ASSIGN_(ThreadLocalBase); }; // Maps a thread to a set of ThreadLocals that have values instantiated on that // thread and notifies them when the thread exits. A ThreadLocal instance is // expected to persist until all threads it has values on have terminated. 
class GTEST_API_ ThreadLocalRegistry { public: // Registers thread_local_instance as having value on the current thread. // Returns a value that can be used to identify the thread from other threads. static ThreadLocalValueHolderBase* GetValueOnCurrentThread( const ThreadLocalBase* thread_local_instance); // Invoked when a ThreadLocal instance is destroyed. static void OnThreadLocalDestroyed( const ThreadLocalBase* thread_local_instance); }; class GTEST_API_ ThreadWithParamBase { public: void Join(); protected: class Runnable { public: virtual ~Runnable() {} virtual void Run() = 0; }; ThreadWithParamBase(Runnable *runnable, Notification* thread_can_start); virtual ~ThreadWithParamBase(); private: AutoHandle thread_; }; // Helper class for testing Google Test's multi-threading constructs. template class ThreadWithParam : public ThreadWithParamBase { public: typedef void UserThreadFunc(T); ThreadWithParam(UserThreadFunc* func, T param, Notification* thread_can_start) : ThreadWithParamBase(new RunnableImpl(func, param), thread_can_start) { } virtual ~ThreadWithParam() {} private: class RunnableImpl : public Runnable { public: RunnableImpl(UserThreadFunc* func, T param) : func_(func), param_(param) { } virtual ~RunnableImpl() {} virtual void Run() { func_(param_); } private: UserThreadFunc* const func_; const T param_; GTEST_DISALLOW_COPY_AND_ASSIGN_(RunnableImpl); }; GTEST_DISALLOW_COPY_AND_ASSIGN_(ThreadWithParam); }; // Implements thread-local storage on Windows systems. // // // Thread 1 // ThreadLocal tl(100); // 100 is the default value for each thread. // // // Thread 2 // tl.set(150); // Changes the value for thread 2 only. // EXPECT_EQ(150, tl.get()); // // // Thread 1 // EXPECT_EQ(100, tl.get()); // In thread 1, tl has the original value. // tl.set(200); // EXPECT_EQ(200, tl.get()); // // The template type argument T must have a public copy constructor. // In addition, the default ThreadLocal constructor requires T to have // a public default constructor. 
// // The users of a TheadLocal instance have to make sure that all but one // threads (including the main one) using that instance have exited before // destroying it. Otherwise, the per-thread objects managed for them by the // ThreadLocal instance are not guaranteed to be destroyed on all platforms. // // Google Test only uses global ThreadLocal objects. That means they // will die after main() has returned. Therefore, no per-thread // object managed by Google Test will be leaked as long as all threads // using Google Test have exited when main() returns. template class ThreadLocal : public ThreadLocalBase { public: ThreadLocal() : default_factory_(new DefaultValueHolderFactory()) {} explicit ThreadLocal(const T& value) : default_factory_(new InstanceValueHolderFactory(value)) {} ~ThreadLocal() { ThreadLocalRegistry::OnThreadLocalDestroyed(this); } T* pointer() { return GetOrCreateValue(); } const T* pointer() const { return GetOrCreateValue(); } const T& get() const { return *pointer(); } void set(const T& value) { *pointer() = value; } private: // Holds a value of T. Can be deleted via its base class without the caller // knowing the type of T. 
class ValueHolder : public ThreadLocalValueHolderBase { public: ValueHolder() : value_() {} explicit ValueHolder(const T& value) : value_(value) {} T* pointer() { return &value_; } private: T value_; GTEST_DISALLOW_COPY_AND_ASSIGN_(ValueHolder); }; T* GetOrCreateValue() const { return static_cast( ThreadLocalRegistry::GetValueOnCurrentThread(this))->pointer(); } virtual ThreadLocalValueHolderBase* NewValueForCurrentThread() const { return default_factory_->MakeNewHolder(); } class ValueHolderFactory { public: ValueHolderFactory() {} virtual ~ValueHolderFactory() {} virtual ValueHolder* MakeNewHolder() const = 0; private: GTEST_DISALLOW_COPY_AND_ASSIGN_(ValueHolderFactory); }; class DefaultValueHolderFactory : public ValueHolderFactory { public: DefaultValueHolderFactory() {} virtual ValueHolder* MakeNewHolder() const { return new ValueHolder(); } private: GTEST_DISALLOW_COPY_AND_ASSIGN_(DefaultValueHolderFactory); }; class InstanceValueHolderFactory : public ValueHolderFactory { public: explicit InstanceValueHolderFactory(const T& value) : value_(value) {} virtual ValueHolder* MakeNewHolder() const { return new ValueHolder(value_); } private: const T value_; // The value for each thread. GTEST_DISALLOW_COPY_AND_ASSIGN_(InstanceValueHolderFactory); }; std::unique_ptr default_factory_; GTEST_DISALLOW_COPY_AND_ASSIGN_(ThreadLocal); }; # elif GTEST_HAS_PTHREAD // MutexBase and Mutex implement mutex on pthreads-based platforms. class MutexBase { public: // Acquires this mutex. void Lock() { GTEST_CHECK_POSIX_SUCCESS_(pthread_mutex_lock(&mutex_)); owner_ = pthread_self(); has_owner_ = true; } // Releases this mutex. void Unlock() { // Since the lock is being released the owner_ field should no longer be // considered valid. We don't protect writing to has_owner_ here, as it's // the caller's responsibility to ensure that the current thread holds the // mutex when this is called. 
has_owner_ = false; GTEST_CHECK_POSIX_SUCCESS_(pthread_mutex_unlock(&mutex_)); } // Does nothing if the current thread holds the mutex. Otherwise, crashes // with high probability. void AssertHeld() const { GTEST_CHECK_(has_owner_ && pthread_equal(owner_, pthread_self())) << "The current thread is not holding the mutex @" << this; } // A static mutex may be used before main() is entered. It may even // be used before the dynamic initialization stage. Therefore we // must be able to initialize a static mutex object at link time. // This means MutexBase has to be a POD and its member variables // have to be public. public: pthread_mutex_t mutex_; // The underlying pthread mutex. // has_owner_ indicates whether the owner_ field below contains a valid thread // ID and is therefore safe to inspect (e.g., to use in pthread_equal()). All // accesses to the owner_ field should be protected by a check of this field. // An alternative might be to memset() owner_ to all zeros, but there's no // guarantee that a zero'd pthread_t is necessarily invalid or even different // from pthread_self(). bool has_owner_; pthread_t owner_; // The thread holding the mutex. }; // Forward-declares a static mutex. # define GTEST_DECLARE_STATIC_MUTEX_(mutex) \ extern ::testing::internal::MutexBase mutex // Defines and statically (i.e. at link time) initializes a static mutex. // The initialization list here does not explicitly initialize each field, // instead relying on default initialization for the unspecified fields. In // particular, the owner_ field (a pthread_t) is not explicitly initialized. // This allows initialization to work whether pthread_t is a scalar or struct. // The flag -Wmissing-field-initializers must not be specified for this to work. #define GTEST_DEFINE_STATIC_MUTEX_(mutex) \ ::testing::internal::MutexBase mutex = {PTHREAD_MUTEX_INITIALIZER, false, 0} // The Mutex class can only be used for mutexes created at runtime. It // shares its API with MutexBase otherwise. 
class Mutex : public MutexBase { public: Mutex() { GTEST_CHECK_POSIX_SUCCESS_(pthread_mutex_init(&mutex_, nullptr)); has_owner_ = false; } ~Mutex() { GTEST_CHECK_POSIX_SUCCESS_(pthread_mutex_destroy(&mutex_)); } private: GTEST_DISALLOW_COPY_AND_ASSIGN_(Mutex); }; // We cannot name this class MutexLock because the ctor declaration would // conflict with a macro named MutexLock, which is defined on some // platforms. That macro is used as a defensive measure to prevent against // inadvertent misuses of MutexLock like "MutexLock(&mu)" rather than // "MutexLock l(&mu)". Hence the typedef trick below. class GTestMutexLock { public: explicit GTestMutexLock(MutexBase* mutex) : mutex_(mutex) { mutex_->Lock(); } ~GTestMutexLock() { mutex_->Unlock(); } private: MutexBase* const mutex_; GTEST_DISALLOW_COPY_AND_ASSIGN_(GTestMutexLock); }; typedef GTestMutexLock MutexLock; // Helpers for ThreadLocal. // pthread_key_create() requires DeleteThreadLocalValue() to have // C-linkage. Therefore it cannot be templatized to access // ThreadLocal. Hence the need for class // ThreadLocalValueHolderBase. class ThreadLocalValueHolderBase { public: virtual ~ThreadLocalValueHolderBase() {} }; // Called by pthread to delete thread-local data stored by // pthread_setspecific(). extern "C" inline void DeleteThreadLocalValue(void* value_holder) { delete static_cast(value_holder); } // Implements thread-local storage on pthreads-based systems. template class GTEST_API_ ThreadLocal { public: ThreadLocal() : key_(CreateKey()), default_factory_(new DefaultValueHolderFactory()) {} explicit ThreadLocal(const T& value) : key_(CreateKey()), default_factory_(new InstanceValueHolderFactory(value)) {} ~ThreadLocal() { // Destroys the managed object for the current thread, if any. DeleteThreadLocalValue(pthread_getspecific(key_)); // Releases resources associated with the key. This will *not* // delete managed objects for other threads. 
GTEST_CHECK_POSIX_SUCCESS_(pthread_key_delete(key_)); } T* pointer() { return GetOrCreateValue(); } const T* pointer() const { return GetOrCreateValue(); } const T& get() const { return *pointer(); } void set(const T& value) { *pointer() = value; } private: // Holds a value of type T. class ValueHolder : public ThreadLocalValueHolderBase { public: ValueHolder() : value_() {} explicit ValueHolder(const T& value) : value_(value) {} T* pointer() { return &value_; } private: T value_; GTEST_DISALLOW_COPY_AND_ASSIGN_(ValueHolder); }; static pthread_key_t CreateKey() { pthread_key_t key; // When a thread exits, DeleteThreadLocalValue() will be called on // the object managed for that thread. GTEST_CHECK_POSIX_SUCCESS_( pthread_key_create(&key, &DeleteThreadLocalValue)); return key; } T* GetOrCreateValue() const { ThreadLocalValueHolderBase* const holder = static_cast(pthread_getspecific(key_)); if (holder != nullptr) { return CheckedDowncastToActualType(holder)->pointer(); } ValueHolder* const new_holder = default_factory_->MakeNewHolder(); ThreadLocalValueHolderBase* const holder_base = new_holder; GTEST_CHECK_POSIX_SUCCESS_(pthread_setspecific(key_, holder_base)); return new_holder->pointer(); } class ValueHolderFactory { public: ValueHolderFactory() {} virtual ~ValueHolderFactory() {} virtual ValueHolder* MakeNewHolder() const = 0; private: GTEST_DISALLOW_COPY_AND_ASSIGN_(ValueHolderFactory); }; class DefaultValueHolderFactory : public ValueHolderFactory { public: DefaultValueHolderFactory() {} virtual ValueHolder* MakeNewHolder() const { return new ValueHolder(); } private: GTEST_DISALLOW_COPY_AND_ASSIGN_(DefaultValueHolderFactory); }; class InstanceValueHolderFactory : public ValueHolderFactory { public: explicit InstanceValueHolderFactory(const T& value) : value_(value) {} virtual ValueHolder* MakeNewHolder() const { return new ValueHolder(value_); } private: const T value_; // The value for each thread. 
GTEST_DISALLOW_COPY_AND_ASSIGN_(InstanceValueHolderFactory); }; // A key pthreads uses for looking up per-thread values. const pthread_key_t key_; std::unique_ptr default_factory_; GTEST_DISALLOW_COPY_AND_ASSIGN_(ThreadLocal); }; # endif // GTEST_HAS_MUTEX_AND_THREAD_LOCAL_ #else // GTEST_IS_THREADSAFE // A dummy implementation of synchronization primitives (mutex, lock, // and thread-local variable). Necessary for compiling Google Test where // mutex is not supported - using Google Test in multiple threads is not // supported on such platforms. class Mutex { public: Mutex() {} void Lock() {} void Unlock() {} void AssertHeld() const {} }; # define GTEST_DECLARE_STATIC_MUTEX_(mutex) \ extern ::testing::internal::Mutex mutex # define GTEST_DEFINE_STATIC_MUTEX_(mutex) ::testing::internal::Mutex mutex // We cannot name this class MutexLock because the ctor declaration would // conflict with a macro named MutexLock, which is defined on some // platforms. That macro is used as a defensive measure to prevent against // inadvertent misuses of MutexLock like "MutexLock(&mu)" rather than // "MutexLock l(&mu)". Hence the typedef trick below. class GTestMutexLock { public: explicit GTestMutexLock(Mutex*) {} // NOLINT }; typedef GTestMutexLock MutexLock; template class GTEST_API_ ThreadLocal { public: ThreadLocal() : value_() {} explicit ThreadLocal(const T& value) : value_(value) {} T* pointer() { return &value_; } const T* pointer() const { return &value_; } const T& get() const { return value_; } void set(const T& value) { value_ = value; } private: T value_; }; #endif // GTEST_IS_THREADSAFE // Returns the number of threads running in the process, or 0 to indicate that // we cannot detect it. GTEST_API_ size_t GetThreadCount(); template using bool_constant = std::integral_constant; #if GTEST_OS_WINDOWS # define GTEST_PATH_SEP_ "\\" # define GTEST_HAS_ALT_PATH_SEP_ 1 // The biggest signed integer type the compiler supports. 
typedef __int64 BiggestInt; #else # define GTEST_PATH_SEP_ "/" # define GTEST_HAS_ALT_PATH_SEP_ 0 typedef long long BiggestInt; // NOLINT #endif // GTEST_OS_WINDOWS // Utilities for char. // isspace(int ch) and friends accept an unsigned char or EOF. char // may be signed, depending on the compiler (or compiler flags). // Therefore we need to cast a char to unsigned char before calling // isspace(), etc. inline bool IsAlpha(char ch) { return isalpha(static_cast(ch)) != 0; } inline bool IsAlNum(char ch) { return isalnum(static_cast(ch)) != 0; } inline bool IsDigit(char ch) { return isdigit(static_cast(ch)) != 0; } inline bool IsLower(char ch) { return islower(static_cast(ch)) != 0; } inline bool IsSpace(char ch) { return isspace(static_cast(ch)) != 0; } inline bool IsUpper(char ch) { return isupper(static_cast(ch)) != 0; } inline bool IsXDigit(char ch) { return isxdigit(static_cast(ch)) != 0; } inline bool IsXDigit(wchar_t ch) { const unsigned char low_byte = static_cast(ch); return ch == low_byte && isxdigit(low_byte) != 0; } inline char ToLower(char ch) { return static_cast(tolower(static_cast(ch))); } inline char ToUpper(char ch) { return static_cast(toupper(static_cast(ch))); } inline std::string StripTrailingSpaces(std::string str) { std::string::iterator it = str.end(); while (it != str.begin() && IsSpace(*--it)) it = str.erase(it); return str; } // The testing::internal::posix namespace holds wrappers for common // POSIX functions. These wrappers hide the differences between // Windows/MSVC and POSIX systems. Since some compilers define these // standard functions as macros, the wrapper cannot have the same name // as the wrapped function. namespace posix { // Functions with a different name on Windows. 
#if GTEST_OS_WINDOWS typedef struct _stat StatStruct; # ifdef __BORLANDC__ inline int IsATTY(int fd) { return isatty(fd); } inline int StrCaseCmp(const char* s1, const char* s2) { return stricmp(s1, s2); } inline char* StrDup(const char* src) { return strdup(src); } # else // !__BORLANDC__ # if GTEST_OS_WINDOWS_MOBILE inline int IsATTY(int /* fd */) { return 0; } # else inline int IsATTY(int fd) { return _isatty(fd); } # endif // GTEST_OS_WINDOWS_MOBILE inline int StrCaseCmp(const char* s1, const char* s2) { return _stricmp(s1, s2); } inline char* StrDup(const char* src) { return _strdup(src); } # endif // __BORLANDC__ # if GTEST_OS_WINDOWS_MOBILE inline int FileNo(FILE* file) { return reinterpret_cast(_fileno(file)); } // Stat(), RmDir(), and IsDir() are not needed on Windows CE at this // time and thus not defined there. # else inline int FileNo(FILE* file) { return _fileno(file); } inline int Stat(const char* path, StatStruct* buf) { return _stat(path, buf); } inline int RmDir(const char* dir) { return _rmdir(dir); } inline bool IsDir(const StatStruct& st) { return (_S_IFDIR & st.st_mode) != 0; } # endif // GTEST_OS_WINDOWS_MOBILE #else typedef struct stat StatStruct; inline int FileNo(FILE* file) { return fileno(file); } inline int IsATTY(int fd) { return isatty(fd); } inline int Stat(const char* path, StatStruct* buf) { return stat(path, buf); } inline int StrCaseCmp(const char* s1, const char* s2) { return strcasecmp(s1, s2); } inline char* StrDup(const char* src) { return strdup(src); } inline int RmDir(const char* dir) { return rmdir(dir); } inline bool IsDir(const StatStruct& st) { return S_ISDIR(st.st_mode); } #endif // GTEST_OS_WINDOWS // Functions deprecated by MSVC 8.0. GTEST_DISABLE_MSC_DEPRECATED_PUSH_() inline const char* StrNCpy(char* dest, const char* src, size_t n) { return strncpy(dest, src, n); } // ChDir(), FReopen(), FDOpen(), Read(), Write(), Close(), and // StrError() aren't needed on Windows CE at this time and thus not // defined there. 
#if !GTEST_OS_WINDOWS_MOBILE && !GTEST_OS_WINDOWS_PHONE && !GTEST_OS_WINDOWS_RT inline int ChDir(const char* dir) { return chdir(dir); } #endif inline FILE* FOpen(const char* path, const char* mode) { return fopen(path, mode); } #if !GTEST_OS_WINDOWS_MOBILE inline FILE *FReopen(const char* path, const char* mode, FILE* stream) { return freopen(path, mode, stream); } inline FILE* FDOpen(int fd, const char* mode) { return fdopen(fd, mode); } #endif inline int FClose(FILE* fp) { return fclose(fp); } #if !GTEST_OS_WINDOWS_MOBILE inline int Read(int fd, void* buf, unsigned int count) { return static_cast(read(fd, buf, count)); } inline int Write(int fd, const void* buf, unsigned int count) { return static_cast(write(fd, buf, count)); } inline int Close(int fd) { return close(fd); } inline const char* StrError(int errnum) { return strerror(errnum); } #endif inline const char* GetEnv(const char* name) { #if GTEST_OS_WINDOWS_MOBILE || GTEST_OS_WINDOWS_PHONE || GTEST_OS_WINDOWS_RT // We are on Windows CE, which has no environment variables. static_cast(name); // To prevent 'unused argument' warning. return nullptr; #elif defined(__BORLANDC__) || defined(__SunOS_5_8) || defined(__SunOS_5_9) // Environment variables which we programmatically clear will be set to the // empty string rather than unset (NULL). Handle that case. const char* const env = getenv(name); return (env != nullptr && env[0] != '\0') ? env : nullptr; #else return getenv(name); #endif } GTEST_DISABLE_MSC_DEPRECATED_POP_() #if GTEST_OS_WINDOWS_MOBILE // Windows CE has no C library. The abort() function is used in // several places in Google Test. This implementation provides a reasonable // imitation of standard behaviour. [[noreturn]] void Abort(); #else [[noreturn]] inline void Abort() { abort(); } #endif // GTEST_OS_WINDOWS_MOBILE } // namespace posix // MSVC "deprecates" snprintf and issues warnings wherever it is used. 
In // order to avoid these warnings, we need to use _snprintf or _snprintf_s on // MSVC-based platforms. We map the GTEST_SNPRINTF_ macro to the appropriate // function in order to achieve that. We use macro definition here because // snprintf is a variadic function. #if _MSC_VER && !GTEST_OS_WINDOWS_MOBILE // MSVC 2005 and above support variadic macros. # define GTEST_SNPRINTF_(buffer, size, format, ...) \ _snprintf_s(buffer, size, size, format, __VA_ARGS__) #elif defined(_MSC_VER) // Windows CE does not define _snprintf_s # define GTEST_SNPRINTF_ _snprintf #else # define GTEST_SNPRINTF_ snprintf #endif // The maximum number a BiggestInt can represent. This definition // works no matter BiggestInt is represented in one's complement or // two's complement. // // We cannot rely on numeric_limits in STL, as __int64 and long long // are not part of standard C++ and numeric_limits doesn't need to be // defined for them. const BiggestInt kMaxBiggestInt = ~(static_cast(1) << (8*sizeof(BiggestInt) - 1)); // This template class serves as a compile-time function from size to // type. It maps a size in bytes to a primitive type with that // size. e.g. // // TypeWithSize<4>::UInt // // is typedef-ed to be unsigned int (unsigned integer made up of 4 // bytes). // // Such functionality should belong to STL, but I cannot find it // there. // // Google Test uses this class in the implementation of floating-point // comparison. // // For now it only handles UInt (unsigned int) as that's all Google Test // needs. Other types can be easily added in the future if need // arises. template class TypeWithSize { public: // This prevents the user from using TypeWithSize with incorrect // values of N. typedef void UInt; }; // The specialization for size 4. template <> class TypeWithSize<4> { public: // unsigned int has size 4 in both gcc and MSVC. // // As base/basictypes.h doesn't compile on Windows, we cannot use // uint32, uint64, and etc here. 
typedef int Int; typedef unsigned int UInt; }; // The specialization for size 8. template <> class TypeWithSize<8> { public: #if GTEST_OS_WINDOWS typedef __int64 Int; typedef unsigned __int64 UInt; #else typedef long long Int; // NOLINT typedef unsigned long long UInt; // NOLINT #endif // GTEST_OS_WINDOWS }; // Integer types of known sizes. typedef TypeWithSize<4>::Int Int32; typedef TypeWithSize<4>::UInt UInt32; typedef TypeWithSize<8>::Int Int64; typedef TypeWithSize<8>::UInt UInt64; typedef TypeWithSize<8>::Int TimeInMillis; // Represents time in milliseconds. // Utilities for command line flags and environment variables. // Macro for referencing flags. #if !defined(GTEST_FLAG) # define GTEST_FLAG(name) FLAGS_gtest_##name #endif // !defined(GTEST_FLAG) #if !defined(GTEST_USE_OWN_FLAGFILE_FLAG_) # define GTEST_USE_OWN_FLAGFILE_FLAG_ 1 #endif // !defined(GTEST_USE_OWN_FLAGFILE_FLAG_) #if !defined(GTEST_DECLARE_bool_) # define GTEST_FLAG_SAVER_ ::testing::internal::GTestFlagSaver // Macros for declaring flags. # define GTEST_DECLARE_bool_(name) GTEST_API_ extern bool GTEST_FLAG(name) # define GTEST_DECLARE_int32_(name) \ GTEST_API_ extern ::testing::internal::Int32 GTEST_FLAG(name) # define GTEST_DECLARE_string_(name) \ GTEST_API_ extern ::std::string GTEST_FLAG(name) // Macros for defining flags. # define GTEST_DEFINE_bool_(name, default_val, doc) \ GTEST_API_ bool GTEST_FLAG(name) = (default_val) # define GTEST_DEFINE_int32_(name, default_val, doc) \ GTEST_API_ ::testing::internal::Int32 GTEST_FLAG(name) = (default_val) # define GTEST_DEFINE_string_(name, default_val, doc) \ GTEST_API_ ::std::string GTEST_FLAG(name) = (default_val) #endif // !defined(GTEST_DECLARE_bool_) // Thread annotations #if !defined(GTEST_EXCLUSIVE_LOCK_REQUIRED_) # define GTEST_EXCLUSIVE_LOCK_REQUIRED_(locks) # define GTEST_LOCK_EXCLUDED_(locks) #endif // !defined(GTEST_EXCLUSIVE_LOCK_REQUIRED_) // Parses 'str' for a 32-bit signed integer. 
If successful, writes the result // to *value and returns true; otherwise leaves *value unchanged and returns // false. bool ParseInt32(const Message& src_text, const char* str, Int32* value); // Parses a bool/Int32/string from the environment variable // corresponding to the given Google Test flag. bool BoolFromGTestEnv(const char* flag, bool default_val); GTEST_API_ Int32 Int32FromGTestEnv(const char* flag, Int32 default_val); std::string OutputFlagAlsoCheckEnvVar(); const char* StringFromGTestEnv(const char* flag, const char* default_val); } // namespace internal } // namespace testing #if !defined(GTEST_INTERNAL_DEPRECATED) // Internal Macro to mark an API deprecated, for googletest usage only // Usage: class GTEST_INTERNAL_DEPRECATED(message) MyClass or // GTEST_INTERNAL_DEPRECATED(message) myFunction(); Every usage of // a deprecated entity will trigger a warning when compiled with // `-Wdeprecated-declarations` option (clang, gcc, any __GNUC__ compiler). // For msvc /W3 option will need to be used // Note that for 'other' compilers this macro evaluates to nothing to prevent // compilations errors. #if defined(_MSC_VER) #define GTEST_INTERNAL_DEPRECATED(message) __declspec(deprecated(message)) #elif defined(__GNUC__) #define GTEST_INTERNAL_DEPRECATED(message) __attribute__((deprecated(message))) #else #define GTEST_INTERNAL_DEPRECATED(message) #endif #endif // !defined(GTEST_INTERNAL_DEPRECATED) #endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PORT_H_ #if GTEST_OS_LINUX # include # include # include # include #endif // GTEST_OS_LINUX #if GTEST_HAS_EXCEPTIONS # include #endif #include #include #include #include #include #include #include #include #include #include // Copyright 2005, Google Inc. // All rights reserved. 
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // // The Google C++ Testing and Mocking Framework (Google Test) // // This header file defines the Message class. // // IMPORTANT NOTE: Due to limitation of the C++ language, we have to // leave some internal implementation details in this header file. // They are clearly marked by comments like this: // // // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. // // Such code is NOT meant to be used by a user directly, and is subject // to CHANGE WITHOUT NOTICE. 
Therefore DO NOT DEPEND ON IT in a user // program! // GOOGLETEST_CM0001 DO NOT DELETE #ifndef GTEST_INCLUDE_GTEST_GTEST_MESSAGE_H_ #define GTEST_INCLUDE_GTEST_GTEST_MESSAGE_H_ #include #include GTEST_DISABLE_MSC_WARNINGS_PUSH_(4251 \ /* class A needs to have dll-interface to be used by clients of class B */) // Ensures that there is at least one operator<< in the global namespace. // See Message& operator<<(...) below for why. void operator<<(const testing::internal::Secret&, int); namespace testing { // The Message class works like an ostream repeater. // // Typical usage: // // 1. You stream a bunch of values to a Message object. // It will remember the text in a stringstream. // 2. Then you stream the Message object to an ostream. // This causes the text in the Message to be streamed // to the ostream. // // For example; // // testing::Message foo; // foo << 1 << " != " << 2; // std::cout << foo; // // will print "1 != 2". // // Message is not intended to be inherited from. In particular, its // destructor is not virtual. // // Note that stringstream behaves differently in gcc and in MSVC. You // can stream a NULL char pointer to it in the former, but not in the // latter (it causes an access violation if you do). The Message // class hides this difference by treating a NULL char pointer as // "(null)". class GTEST_API_ Message { private: // The type of basic IO manipulators (endl, ends, and flush) for // narrow streams. typedef std::ostream& (*BasicNarrowIoManip)(std::ostream&); public: // Constructs an empty Message. Message(); // Copy constructor. Message(const Message& msg) : ss_(new ::std::stringstream) { // NOLINT *ss_ << msg.GetString(); } // Constructs a Message from a C-string. explicit Message(const char* str) : ss_(new ::std::stringstream) { *ss_ << str; } // Streams a non-pointer value to this object. template inline Message& operator <<(const T& val) { // Some libraries overload << for STL containers. 
These // overloads are defined in the global namespace instead of ::std. // // C++'s symbol lookup rule (i.e. Koenig lookup) says that these // overloads are visible in either the std namespace or the global // namespace, but not other namespaces, including the testing // namespace which Google Test's Message class is in. // // To allow STL containers (and other types that has a << operator // defined in the global namespace) to be used in Google Test // assertions, testing::Message must access the custom << operator // from the global namespace. With this using declaration, // overloads of << defined in the global namespace and those // visible via Koenig lookup are both exposed in this function. using ::operator <<; *ss_ << val; return *this; } // Streams a pointer value to this object. // // This function is an overload of the previous one. When you // stream a pointer to a Message, this definition will be used as it // is more specialized. (The C++ Standard, section // [temp.func.order].) If you stream a non-pointer, then the // previous definition will be used. // // The reason for this overload is that streaming a NULL pointer to // ostream is undefined behavior. Depending on the compiler, you // may get "0", "(nil)", "(null)", or an access violation. To // ensure consistent result across compilers, we always treat NULL // as "(null)". template inline Message& operator <<(T* const& pointer) { // NOLINT if (pointer == nullptr) { *ss_ << "(null)"; } else { *ss_ << pointer; } return *this; } // Since the basic IO manipulators are overloaded for both narrow // and wide streams, we have to provide this specialized definition // of operator <<, even though its body is the same as the // templatized version above. Without this definition, streaming // endl or other basic IO manipulators to Message will confuse the // compiler. Message& operator <<(BasicNarrowIoManip val) { *ss_ << val; return *this; } // Instead of 1/0, we want to see true/false for bool values. 
Message& operator <<(bool b) { return *this << (b ? "true" : "false"); } // These two overloads allow streaming a wide C string to a Message // using the UTF-8 encoding. Message& operator <<(const wchar_t* wide_c_str); Message& operator <<(wchar_t* wide_c_str); #if GTEST_HAS_STD_WSTRING // Converts the given wide string to a narrow string using the UTF-8 // encoding, and streams the result to this Message object. Message& operator <<(const ::std::wstring& wstr); #endif // GTEST_HAS_STD_WSTRING // Gets the text streamed to this object so far as an std::string. // Each '\0' character in the buffer is replaced with "\\0". // // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. std::string GetString() const; private: // We'll hold the text streamed to this object here. const std::unique_ptr< ::std::stringstream> ss_; // We declare (but don't implement) this to prevent the compiler // from implementing the assignment operator. void operator=(const Message&); }; // Streams a Message to an ostream. inline std::ostream& operator <<(std::ostream& os, const Message& sb) { return os << sb.GetString(); } namespace internal { // Converts a streamable value to an std::string. A NULL pointer is // converted to "(null)". When the input value is a ::string, // ::std::string, ::wstring, or ::std::wstring object, each NUL // character in it is replaced with "\\0". template std::string StreamableToString(const T& streamable) { return (Message() << streamable).GetString(); } } // namespace internal } // namespace testing GTEST_DISABLE_MSC_WARNINGS_POP_() // 4251 #endif // GTEST_INCLUDE_GTEST_GTEST_MESSAGE_H_ // Copyright 2008, Google Inc. // All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. 
// * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // // Google Test filepath utilities // // This header file declares classes and functions used internally by // Google Test. They are subject to change without notice. // // This file is #included in gtest/internal/gtest-internal.h. // Do not include this header file separately! // GOOGLETEST_CM0001 DO NOT DELETE #ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_FILEPATH_H_ #define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_FILEPATH_H_ // Copyright 2005, Google Inc. // All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. 
// * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // // The Google C++ Testing and Mocking Framework (Google Test) // // This header file declares the String class and functions used internally by // Google Test. They are subject to change without notice. They should not used // by code external to Google Test. // // This header file is #included by gtest-internal.h. // It should not be #included by other files. // GOOGLETEST_CM0001 DO NOT DELETE #ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_STRING_H_ #define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_STRING_H_ #ifdef __BORLANDC__ // string.h is not guaranteed to provide strcpy on C++ Builder. # include #endif #include #include namespace testing { namespace internal { // String - an abstract class holding static string utilities. 
class GTEST_API_ String { public: // Static utility methods // Clones a 0-terminated C string, allocating memory using new. The // caller is responsible for deleting the return value using // delete[]. Returns the cloned string, or NULL if the input is // NULL. // // This is different from strdup() in string.h, which allocates // memory using malloc(). static const char* CloneCString(const char* c_str); #if GTEST_OS_WINDOWS_MOBILE // Windows CE does not have the 'ANSI' versions of Win32 APIs. To be // able to pass strings to Win32 APIs on CE we need to convert them // to 'Unicode', UTF-16. // Creates a UTF-16 wide string from the given ANSI string, allocating // memory using new. The caller is responsible for deleting the return // value using delete[]. Returns the wide string, or NULL if the // input is NULL. // // The wide string is created using the ANSI codepage (CP_ACP) to // match the behaviour of the ANSI versions of Win32 calls and the // C runtime. static LPCWSTR AnsiToUtf16(const char* c_str); // Creates an ANSI string from the given wide string, allocating // memory using new. The caller is responsible for deleting the return // value using delete[]. Returns the ANSI string, or NULL if the // input is NULL. // // The returned string is created using the ANSI codepage (CP_ACP) to // match the behaviour of the ANSI versions of Win32 calls and the // C runtime. static const char* Utf16ToAnsi(LPCWSTR utf16_str); #endif // Compares two C strings. Returns true if and only if they have the same // content. // // Unlike strcmp(), this function can handle NULL argument(s). A // NULL C string is considered different to any non-NULL C string, // including the empty string. static bool CStringEquals(const char* lhs, const char* rhs); // Converts a wide C string to a String using the UTF-8 encoding. // NULL will be converted to "(null)". If an error occurred during // the conversion, "(failed to convert from wide string)" is // returned. 
static std::string ShowWideCString(const wchar_t* wide_c_str); // Compares two wide C strings. Returns true if and only if they have the // same content. // // Unlike wcscmp(), this function can handle NULL argument(s). A // NULL C string is considered different to any non-NULL C string, // including the empty string. static bool WideCStringEquals(const wchar_t* lhs, const wchar_t* rhs); // Compares two C strings, ignoring case. Returns true if and only if // they have the same content. // // Unlike strcasecmp(), this function can handle NULL argument(s). // A NULL C string is considered different to any non-NULL C string, // including the empty string. static bool CaseInsensitiveCStringEquals(const char* lhs, const char* rhs); // Compares two wide C strings, ignoring case. Returns true if and only if // they have the same content. // // Unlike wcscasecmp(), this function can handle NULL argument(s). // A NULL C string is considered different to any non-NULL wide C string, // including the empty string. // NB: The implementations on different platforms slightly differ. // On windows, this method uses _wcsicmp which compares according to LC_CTYPE // environment variable. On GNU platform this method uses wcscasecmp // which compares according to LC_CTYPE category of the current locale. // On MacOS X, it uses towlower, which also uses LC_CTYPE category of the // current locale. static bool CaseInsensitiveWideCStringEquals(const wchar_t* lhs, const wchar_t* rhs); // Returns true if and only if the given string ends with the given suffix, // ignoring case. Any string is considered to end with an empty suffix. static bool EndsWithCaseInsensitive( const std::string& str, const std::string& suffix); // Formats an int value as "%02d". static std::string FormatIntWidth2(int value); // "%02d" for width == 2 // Formats an int value as "%X". static std::string FormatHexInt(int value); // Formats an int value as "%X". 
static std::string FormatHexUInt32(UInt32 value); // Formats a byte as "%02X". static std::string FormatByte(unsigned char value); private: String(); // Not meant to be instantiated. }; // class String // Gets the content of the stringstream's buffer as an std::string. Each '\0' // character in the buffer is replaced with "\\0". GTEST_API_ std::string StringStreamToString(::std::stringstream* stream); } // namespace internal } // namespace testing #endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_STRING_H_ GTEST_DISABLE_MSC_WARNINGS_PUSH_(4251 \ /* class A needs to have dll-interface to be used by clients of class B */) namespace testing { namespace internal { // FilePath - a class for file and directory pathname manipulation which // handles platform-specific conventions (like the pathname separator). // Used for helper functions for naming files in a directory for xml output. // Except for Set methods, all methods are const or static, which provides an // "immutable value object" -- useful for peace of mind. // A FilePath with a value ending in a path separator ("like/this/") represents // a directory, otherwise it is assumed to represent a file. In either case, // it may or may not represent an actual file or directory in the file system. // Names are NOT checked for syntax correctness -- no checking for illegal // characters, malformed paths, etc. class GTEST_API_ FilePath { public: FilePath() : pathname_("") { } FilePath(const FilePath& rhs) : pathname_(rhs.pathname_) { } explicit FilePath(const std::string& pathname) : pathname_(pathname) { Normalize(); } FilePath& operator=(const FilePath& rhs) { Set(rhs); return *this; } void Set(const FilePath& rhs) { pathname_ = rhs.pathname_; } const std::string& string() const { return pathname_; } const char* c_str() const { return pathname_.c_str(); } // Returns the current working directory, or "" if unsuccessful. 
static FilePath GetCurrentDir(); // Given directory = "dir", base_name = "test", number = 0, // extension = "xml", returns "dir/test.xml". If number is greater // than zero (e.g., 12), returns "dir/test_12.xml". // On Windows platform, uses \ as the separator rather than /. static FilePath MakeFileName(const FilePath& directory, const FilePath& base_name, int number, const char* extension); // Given directory = "dir", relative_path = "test.xml", // returns "dir/test.xml". // On Windows, uses \ as the separator rather than /. static FilePath ConcatPaths(const FilePath& directory, const FilePath& relative_path); // Returns a pathname for a file that does not currently exist. The pathname // will be directory/base_name.extension or // directory/base_name_.extension if directory/base_name.extension // already exists. The number will be incremented until a pathname is found // that does not already exist. // Examples: 'dir/foo_test.xml' or 'dir/foo_test_1.xml'. // There could be a race condition if two or more processes are calling this // function at the same time -- they could both pick the same filename. static FilePath GenerateUniqueFileName(const FilePath& directory, const FilePath& base_name, const char* extension); // Returns true if and only if the path is "". bool IsEmpty() const { return pathname_.empty(); } // If input name has a trailing separator character, removes it and returns // the name, otherwise return the name string unmodified. // On Windows platform, uses \ as the separator, other platforms use /. FilePath RemoveTrailingPathSeparator() const; // Returns a copy of the FilePath with the directory part removed. // Example: FilePath("path/to/file").RemoveDirectoryName() returns // FilePath("file"). If there is no directory part ("just_a_file"), it returns // the FilePath unmodified. If there is no file part ("just_a_dir/") it // returns an empty FilePath (""). // On Windows platform, '\' is the path separator, otherwise it is '/'. 
FilePath RemoveDirectoryName() const; // RemoveFileName returns the directory path with the filename removed. // Example: FilePath("path/to/file").RemoveFileName() returns "path/to/". // If the FilePath is "a_file" or "/a_file", RemoveFileName returns // FilePath("./") or, on Windows, FilePath(".\\"). If the filepath does // not have a file, like "just/a/dir/", it returns the FilePath unmodified. // On Windows platform, '\' is the path separator, otherwise it is '/'. FilePath RemoveFileName() const; // Returns a copy of the FilePath with the case-insensitive extension removed. // Example: FilePath("dir/file.exe").RemoveExtension("EXE") returns // FilePath("dir/file"). If a case-insensitive extension is not // found, returns a copy of the original FilePath. FilePath RemoveExtension(const char* extension) const; // Creates directories so that path exists. Returns true if successful or if // the directories already exist; returns false if unable to create // directories for any reason. Will also return false if the FilePath does // not represent a directory (that is, it doesn't end with a path separator). bool CreateDirectoriesRecursively() const; // Create the directory so that path exists. Returns true if successful or // if the directory already exists; returns false if unable to create the // directory for any reason, including if the parent directory does not // exist. Not named "CreateDirectory" because that's a macro on Windows. bool CreateFolder() const; // Returns true if FilePath describes something in the file-system, // either a file, directory, or whatever, and that something exists. bool FileOrDirectoryExists() const; // Returns true if pathname describes a directory in the file-system // that exists. bool DirectoryExists() const; // Returns true if FilePath ends with a path separator, which indicates that // it is intended to represent a directory. Returns false otherwise. // This does NOT check that a directory (or file) actually exists. 
bool IsDirectory() const; // Returns true if pathname describes a root directory. (Windows has one // root directory per disk drive.) bool IsRootDirectory() const; // Returns true if pathname describes an absolute path. bool IsAbsolutePath() const; private: // Replaces multiple consecutive separators with a single separator. // For example, "bar///foo" becomes "bar/foo". Does not eliminate other // redundancies that might be in a pathname involving "." or "..". // // A pathname with multiple consecutive separators may occur either through // user error or as a result of some scripts or APIs that generate a pathname // with a trailing separator. On other platforms the same API or script // may NOT generate a pathname with a trailing "/". Then elsewhere that // pathname may have another "/" and pathname components added to it, // without checking for the separator already being there. // The script language and operating system may allow paths like "foo//bar" // but some of the functions in FilePath will not handle that correctly. In // particular, RemoveTrailingPathSeparator() only removes one separator, and // it is called in CreateDirectoriesRecursively() assuming that it will change // a pathname from directory syntax (trailing separator) to filename syntax. // // On Windows this method also replaces the alternate path separator '/' with // the primary path separator '\\', so that for example "bar\\/\\foo" becomes // "bar\\foo". void Normalize(); // Returns a pointer to the last occurence of a valid path separator in // the FilePath. On Windows, for example, both '/' and '\' are valid path // separators. Returns NULL if no path separator was found. 
const char* FindLastPathSeparator() const; std::string pathname_; }; // class FilePath } // namespace internal } // namespace testing GTEST_DISABLE_MSC_WARNINGS_POP_() // 4251 #endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_FILEPATH_H_ // This file was GENERATED by command: // pump.py gtest-type-util.h.pump // DO NOT EDIT BY HAND!!! // Copyright 2008 Google Inc. // All Rights Reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // Type utilities needed for implementing typed and type-parameterized // tests. This file is generated by a SCRIPT. 
DO NOT EDIT BY HAND! // // Currently we support at most 50 types in a list, and at most 50 // type-parameterized tests in one type-parameterized test suite. // Please contact googletestframework@googlegroups.com if you need // more. // GOOGLETEST_CM0001 DO NOT DELETE #ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_TYPE_UTIL_H_ #define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_TYPE_UTIL_H_ // #ifdef __GNUC__ is too general here. It is possible to use gcc without using // libstdc++ (which is where cxxabi.h comes from). # if GTEST_HAS_CXXABI_H_ # include # elif defined(__HP_aCC) # include # endif // GTEST_HASH_CXXABI_H_ namespace testing { namespace internal { // Canonicalizes a given name with respect to the Standard C++ Library. // This handles removing the inline namespace within `std` that is // used by various standard libraries (e.g., `std::__1`). Names outside // of namespace std are returned unmodified. inline std::string CanonicalizeForStdLibVersioning(std::string s) { static const char prefix[] = "std::__"; if (s.compare(0, strlen(prefix), prefix) == 0) { std::string::size_type end = s.find("::", strlen(prefix)); if (end != s.npos) { // Erase everything between the initial `std` and the second `::`. s.erase(strlen("std"), end - strlen("std")); } } return s; } // GetTypeName() returns a human-readable name of type T. // NB: This function is also used in Google Mock, so don't move it inside of // the typed-test-only section below. template std::string GetTypeName() { # if GTEST_HAS_RTTI const char* const name = typeid(T).name(); # if GTEST_HAS_CXXABI_H_ || defined(__HP_aCC) int status = 0; // gcc's implementation of typeid(T).name() mangles the type name, // so we have to demangle it. # if GTEST_HAS_CXXABI_H_ using abi::__cxa_demangle; # endif // GTEST_HAS_CXXABI_H_ char* const readable_name = __cxa_demangle(name, nullptr, nullptr, &status); const std::string name_str(status == 0 ? 
readable_name : name); free(readable_name); return CanonicalizeForStdLibVersioning(name_str); # else return name; # endif // GTEST_HAS_CXXABI_H_ || __HP_aCC # else return ""; # endif // GTEST_HAS_RTTI } #if GTEST_HAS_TYPED_TEST || GTEST_HAS_TYPED_TEST_P // A unique type used as the default value for the arguments of class // template Types. This allows us to simulate variadic templates // (e.g. Types, Type, and etc), which C++ doesn't // support directly. struct None {}; // The following family of struct and struct templates are used to // represent type lists. In particular, TypesN // represents a type list with N types (T1, T2, ..., and TN) in it. // Except for Types0, every struct in the family has two member types: // Head for the first type in the list, and Tail for the rest of the // list. // The empty type list. struct Types0 {}; // Type lists of length 1, 2, 3, and so on. template struct Types1 { typedef T1 Head; typedef Types0 Tail; }; template struct Types2 { typedef T1 Head; typedef Types1 Tail; }; template struct Types3 { typedef T1 Head; typedef Types2 Tail; }; template struct Types4 { typedef T1 Head; typedef Types3 Tail; }; template struct Types5 { typedef T1 Head; typedef Types4 Tail; }; template struct Types6 { typedef T1 Head; typedef Types5 Tail; }; template struct Types7 { typedef T1 Head; typedef Types6 Tail; }; template struct Types8 { typedef T1 Head; typedef Types7 Tail; }; template struct Types9 { typedef T1 Head; typedef Types8 Tail; }; template struct Types10 { typedef T1 Head; typedef Types9 Tail; }; template struct Types11 { typedef T1 Head; typedef Types10 Tail; }; template struct Types12 { typedef T1 Head; typedef Types11 Tail; }; template struct Types13 { typedef T1 Head; typedef Types12 Tail; }; template struct Types14 { typedef T1 Head; typedef Types13 Tail; }; template struct Types15 { typedef T1 Head; typedef Types14 Tail; }; template struct Types16 { typedef T1 Head; typedef Types15 Tail; }; template struct Types17 { typedef T1 
Head; typedef Types16 Tail; }; template struct Types18 { typedef T1 Head; typedef Types17 Tail; }; template struct Types19 { typedef T1 Head; typedef Types18 Tail; }; template struct Types20 { typedef T1 Head; typedef Types19 Tail; }; template struct Types21 { typedef T1 Head; typedef Types20 Tail; }; template struct Types22 { typedef T1 Head; typedef Types21 Tail; }; template struct Types23 { typedef T1 Head; typedef Types22 Tail; }; template struct Types24 { typedef T1 Head; typedef Types23 Tail; }; template struct Types25 { typedef T1 Head; typedef Types24 Tail; }; template struct Types26 { typedef T1 Head; typedef Types25 Tail; }; template struct Types27 { typedef T1 Head; typedef Types26 Tail; }; template struct Types28 { typedef T1 Head; typedef Types27 Tail; }; template struct Types29 { typedef T1 Head; typedef Types28 Tail; }; template struct Types30 { typedef T1 Head; typedef Types29 Tail; }; template struct Types31 { typedef T1 Head; typedef Types30 Tail; }; template struct Types32 { typedef T1 Head; typedef Types31 Tail; }; template struct Types33 { typedef T1 Head; typedef Types32 Tail; }; template struct Types34 { typedef T1 Head; typedef Types33 Tail; }; template struct Types35 { typedef T1 Head; typedef Types34 Tail; }; template struct Types36 { typedef T1 Head; typedef Types35 Tail; }; template struct Types37 { typedef T1 Head; typedef Types36 Tail; }; template struct Types38 { typedef T1 Head; typedef Types37 Tail; }; template struct Types39 { typedef T1 Head; typedef Types38 Tail; }; template struct Types40 { typedef T1 Head; typedef Types39 Tail; }; template struct Types41 { typedef T1 Head; typedef Types40 Tail; }; template struct Types42 { typedef T1 Head; typedef Types41 Tail; }; template struct Types43 { typedef T1 Head; typedef Types42 Tail; }; template struct Types44 { typedef T1 Head; typedef Types43 Tail; }; template struct Types45 { typedef T1 Head; typedef Types44 Tail; }; template struct Types46 { typedef T1 Head; typedef Types45 Tail; 
}; template struct Types47 { typedef T1 Head; typedef Types46 Tail; }; template struct Types48 { typedef T1 Head; typedef Types47 Tail; }; template struct Types49 { typedef T1 Head; typedef Types48 Tail; }; template struct Types50 { typedef T1 Head; typedef Types49 Tail; }; } // namespace internal // We don't want to require the users to write TypesN<...> directly, // as that would require them to count the length. Types<...> is much // easier to write, but generates horrible messages when there is a // compiler error, as gcc insists on printing out each template // argument, even if it has the default value (this means Types // will appear as Types in the compiler // errors). // // Our solution is to combine the best part of the two approaches: a // user would write Types, and Google Test will translate // that to TypesN internally to make error messages // readable. The translation is done by the 'type' member of the // Types template. template struct Types { typedef internal::Types50 type; }; template <> struct Types { typedef internal::Types0 type; }; template struct Types { typedef internal::Types1 type; }; template struct Types { typedef internal::Types2 type; }; template struct Types { typedef internal::Types3 type; }; template struct Types { typedef internal::Types4 type; }; template struct Types { typedef internal::Types5 type; }; template struct Types { typedef internal::Types6 type; }; template struct Types { typedef internal::Types7 type; }; template struct Types { typedef internal::Types8 type; }; template struct Types { typedef internal::Types9 type; }; template struct Types { typedef internal::Types10 type; }; template struct Types { typedef internal::Types11 type; }; template struct Types { typedef internal::Types12 type; }; template struct Types { typedef internal::Types13 type; }; template struct Types { typedef internal::Types14 type; }; template struct Types { typedef internal::Types15 type; }; template struct Types { typedef internal::Types16 
type; }; template struct Types { typedef internal::Types17 type; }; template struct Types { typedef internal::Types18 type; }; template struct Types { typedef internal::Types19 type; }; template struct Types { typedef internal::Types20 type; }; template struct Types { typedef internal::Types21 type; }; template struct Types { typedef internal::Types22 type; }; template struct Types { typedef internal::Types23 type; }; template struct Types { typedef internal::Types24 type; }; template struct Types { typedef internal::Types25 type; }; template struct Types { typedef internal::Types26 type; }; template struct Types { typedef internal::Types27 type; }; template struct Types { typedef internal::Types28 type; }; template struct Types { typedef internal::Types29 type; }; template struct Types { typedef internal::Types30 type; }; template struct Types { typedef internal::Types31 type; }; template struct Types { typedef internal::Types32 type; }; template struct Types { typedef internal::Types33 type; }; template struct Types { typedef internal::Types34 type; }; template struct Types { typedef internal::Types35 type; }; template struct Types { typedef internal::Types36 type; }; template struct Types { typedef internal::Types37 type; }; template struct Types { typedef internal::Types38 type; }; template struct Types { typedef internal::Types39 type; }; template struct Types { typedef internal::Types40 type; }; template struct Types { typedef internal::Types41 type; }; template struct Types { typedef internal::Types42 type; }; template struct Types { typedef internal::Types43 type; }; template struct Types { typedef internal::Types44 type; }; template struct Types { typedef internal::Types45 type; }; template struct Types { typedef internal::Types46 type; }; template struct Types { typedef internal::Types47 type; }; template struct Types { typedef internal::Types48 type; }; template struct Types { typedef internal::Types49 type; }; namespace internal { # define 
GTEST_TEMPLATE_ template class // The template "selector" struct TemplateSel is used to // represent Tmpl, which must be a class template with one type // parameter, as a type. TemplateSel::Bind::type is defined // as the type Tmpl. This allows us to actually instantiate the // template "selected" by TemplateSel. // // This trick is necessary for simulating typedef for class templates, // which C++ doesn't support directly. template struct TemplateSel { template struct Bind { typedef Tmpl type; }; }; # define GTEST_BIND_(TmplSel, T) \ TmplSel::template Bind::type // A unique struct template used as the default value for the // arguments of class template Templates. This allows us to simulate // variadic templates (e.g. Templates, Templates, // and etc), which C++ doesn't support directly. template struct NoneT {}; // The following family of struct and struct templates are used to // represent template lists. In particular, TemplatesN represents a list of N templates (T1, T2, ..., and TN). Except // for Templates0, every struct in the family has two member types: // Head for the selector of the first template in the list, and Tail // for the rest of the list. // The empty template list. struct Templates0 {}; // Template lists of length 1, 2, 3, and so on. 
template struct Templates1 { typedef TemplateSel Head; typedef Templates0 Tail; }; template struct Templates2 { typedef TemplateSel Head; typedef Templates1 Tail; }; template struct Templates3 { typedef TemplateSel Head; typedef Templates2 Tail; }; template struct Templates4 { typedef TemplateSel Head; typedef Templates3 Tail; }; template struct Templates5 { typedef TemplateSel Head; typedef Templates4 Tail; }; template struct Templates6 { typedef TemplateSel Head; typedef Templates5 Tail; }; template struct Templates7 { typedef TemplateSel Head; typedef Templates6 Tail; }; template struct Templates8 { typedef TemplateSel Head; typedef Templates7 Tail; }; template struct Templates9 { typedef TemplateSel Head; typedef Templates8 Tail; }; template struct Templates10 { typedef TemplateSel Head; typedef Templates9 Tail; }; template struct Templates11 { typedef TemplateSel Head; typedef Templates10 Tail; }; template struct Templates12 { typedef TemplateSel Head; typedef Templates11 Tail; }; template struct Templates13 { typedef TemplateSel Head; typedef Templates12 Tail; }; template struct Templates14 { typedef TemplateSel Head; typedef Templates13 Tail; }; template struct Templates15 { typedef TemplateSel Head; typedef Templates14 Tail; }; template struct Templates16 { typedef TemplateSel Head; typedef Templates15 Tail; }; template struct Templates17 { typedef TemplateSel Head; typedef Templates16 Tail; }; template struct Templates18 { typedef TemplateSel Head; typedef Templates17 Tail; }; template struct Templates19 { typedef TemplateSel Head; typedef Templates18 Tail; }; template struct Templates20 { typedef TemplateSel Head; typedef Templates19 Tail; }; template struct Templates21 { typedef TemplateSel Head; typedef Templates20 Tail; }; template struct Templates22 { typedef TemplateSel Head; typedef Templates21 Tail; }; template struct Templates23 { typedef TemplateSel Head; typedef Templates22 Tail; }; template struct Templates24 { typedef TemplateSel Head; typedef 
Templates23 Tail; }; template struct Templates25 { typedef TemplateSel Head; typedef Templates24 Tail; }; template struct Templates26 { typedef TemplateSel Head; typedef Templates25 Tail; }; template struct Templates27 { typedef TemplateSel Head; typedef Templates26 Tail; }; template struct Templates28 { typedef TemplateSel Head; typedef Templates27 Tail; }; template struct Templates29 { typedef TemplateSel Head; typedef Templates28 Tail; }; template struct Templates30 { typedef TemplateSel Head; typedef Templates29 Tail; }; template struct Templates31 { typedef TemplateSel Head; typedef Templates30 Tail; }; template struct Templates32 { typedef TemplateSel Head; typedef Templates31 Tail; }; template struct Templates33 { typedef TemplateSel Head; typedef Templates32 Tail; }; template struct Templates34 { typedef TemplateSel Head; typedef Templates33 Tail; }; template struct Templates35 { typedef TemplateSel Head; typedef Templates34 Tail; }; template struct Templates36 { typedef TemplateSel Head; typedef Templates35 Tail; }; template struct Templates37 { typedef TemplateSel Head; typedef Templates36 Tail; }; template struct Templates38 { typedef TemplateSel Head; typedef Templates37 Tail; }; template struct Templates39 { typedef TemplateSel Head; typedef Templates38 Tail; }; template struct Templates40 { typedef TemplateSel Head; typedef Templates39 Tail; }; template struct Templates41 { typedef TemplateSel Head; typedef Templates40 Tail; }; template struct Templates42 { typedef TemplateSel Head; typedef Templates41 Tail; }; template struct Templates43 { typedef TemplateSel Head; typedef Templates42 Tail; }; template struct Templates44 { typedef TemplateSel Head; typedef Templates43 Tail; }; template struct Templates45 { typedef TemplateSel Head; typedef Templates44 Tail; }; template struct Templates46 { typedef TemplateSel Head; typedef Templates45 Tail; }; template struct Templates47 { typedef TemplateSel Head; typedef Templates46 Tail; }; template struct 
Templates48 { typedef TemplateSel Head; typedef Templates47 Tail; }; template struct Templates49 { typedef TemplateSel Head; typedef Templates48 Tail; }; template struct Templates50 { typedef TemplateSel Head; typedef Templates49 Tail; }; // We don't want to require the users to write TemplatesN<...> directly, // as that would require them to count the length. Templates<...> is much // easier to write, but generates horrible messages when there is a // compiler error, as gcc insists on printing out each template // argument, even if it has the default value (this means Templates // will appear as Templates in the compiler // errors). // // Our solution is to combine the best part of the two approaches: a // user would write Templates, and Google Test will translate // that to TemplatesN internally to make error messages // readable. The translation is done by the 'type' member of the // Templates template. template struct Templates { typedef Templates50 type; }; template <> struct Templates { typedef Templates0 type; }; template struct Templates { typedef Templates1 type; }; template struct Templates { typedef Templates2 type; }; template struct Templates { typedef Templates3 type; }; template struct Templates { typedef Templates4 type; }; template struct Templates { typedef Templates5 type; }; template struct Templates { typedef Templates6 type; }; template struct Templates { typedef Templates7 type; }; template struct Templates { typedef Templates8 type; }; template struct Templates { typedef Templates9 type; }; template struct Templates { typedef Templates10 type; }; template struct Templates { typedef Templates11 type; }; template struct Templates { typedef Templates12 type; }; template struct Templates { typedef Templates13 type; }; template struct Templates { typedef Templates14 type; }; template struct Templates { typedef Templates15 type; }; template struct Templates { typedef Templates16 type; }; template struct Templates { typedef Templates17 type; }; 
template struct Templates { typedef Templates18 type; }; template struct Templates { typedef Templates19 type; }; template struct Templates { typedef Templates20 type; }; template struct Templates { typedef Templates21 type; }; template struct Templates { typedef Templates22 type; }; template struct Templates { typedef Templates23 type; }; template struct Templates { typedef Templates24 type; }; template struct Templates { typedef Templates25 type; }; template struct Templates { typedef Templates26 type; }; template struct Templates { typedef Templates27 type; }; template struct Templates { typedef Templates28 type; }; template struct Templates { typedef Templates29 type; }; template struct Templates { typedef Templates30 type; }; template struct Templates { typedef Templates31 type; }; template struct Templates { typedef Templates32 type; }; template struct Templates { typedef Templates33 type; }; template struct Templates { typedef Templates34 type; }; template struct Templates { typedef Templates35 type; }; template struct Templates { typedef Templates36 type; }; template struct Templates { typedef Templates37 type; }; template struct Templates { typedef Templates38 type; }; template struct Templates { typedef Templates39 type; }; template struct Templates { typedef Templates40 type; }; template struct Templates { typedef Templates41 type; }; template struct Templates { typedef Templates42 type; }; template struct Templates { typedef Templates43 type; }; template struct Templates { typedef Templates44 type; }; template struct Templates { typedef Templates45 type; }; template struct Templates { typedef Templates46 type; }; template struct Templates { typedef Templates47 type; }; template struct Templates { typedef Templates48 type; }; template struct Templates { typedef Templates49 type; }; // The TypeList template makes it possible to use either a single type // or a Types<...> list in TYPED_TEST_SUITE() and // INSTANTIATE_TYPED_TEST_SUITE_P(). 
template struct TypeList { typedef Types1 type; }; template struct TypeList > { typedef typename Types::type type; }; #endif // GTEST_HAS_TYPED_TEST || GTEST_HAS_TYPED_TEST_P } // namespace internal } // namespace testing #endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_TYPE_UTIL_H_ // Due to C++ preprocessor weirdness, we need double indirection to // concatenate two tokens when one of them is __LINE__. Writing // // foo ## __LINE__ // // will result in the token foo__LINE__, instead of foo followed by // the current line number. For more details, see // http://www.parashift.com/c++-faq-lite/misc-technical-issues.html#faq-39.6 #define GTEST_CONCAT_TOKEN_(foo, bar) GTEST_CONCAT_TOKEN_IMPL_(foo, bar) #define GTEST_CONCAT_TOKEN_IMPL_(foo, bar) foo ## bar // Stringifies its argument. #define GTEST_STRINGIFY_(name) #name namespace proto2 { class Message; } namespace testing { // Forward declarations. class AssertionResult; // Result of an assertion. class Message; // Represents a failure message. class Test; // Represents a test. class TestInfo; // Information about a test. class TestPartResult; // Result of a test part. class UnitTest; // A collection of test suites. template ::std::string PrintToString(const T& value); namespace internal { struct TraceInfo; // Information about a trace point. class TestInfoImpl; // Opaque implementation of TestInfo class UnitTestImpl; // Opaque implementation of UnitTest // The text used in failure messages to indicate the start of the // stack trace. GTEST_API_ extern const char kStackTraceMarker[]; // An IgnoredValue object can be implicitly constructed from ANY value. class IgnoredValue { struct Sink {}; public: // This constructor template allows any value to be implicitly // converted to IgnoredValue. The object has no data member and // doesn't try to remember anything about the argument. We // deliberately omit the 'explicit' keyword in order to allow the // conversion to be implicit. 
// Disable the conversion if T already has a magical conversion operator. // Otherwise we get ambiguity. template ::value, int>::type = 0> IgnoredValue(const T& /* ignored */) {} // NOLINT(runtime/explicit) }; // Appends the user-supplied message to the Google-Test-generated message. GTEST_API_ std::string AppendUserMessage( const std::string& gtest_msg, const Message& user_msg); #if GTEST_HAS_EXCEPTIONS GTEST_DISABLE_MSC_WARNINGS_PUSH_(4275 \ /* an exported class was derived from a class that was not exported */) // This exception is thrown by (and only by) a failed Google Test // assertion when GTEST_FLAG(throw_on_failure) is true (if exceptions // are enabled). We derive it from std::runtime_error, which is for // errors presumably detectable only at run time. Since // std::runtime_error inherits from std::exception, many testing // frameworks know how to extract and print the message inside it. class GTEST_API_ GoogleTestFailureException : public ::std::runtime_error { public: explicit GoogleTestFailureException(const TestPartResult& failure); }; GTEST_DISABLE_MSC_WARNINGS_POP_() // 4275 #endif // GTEST_HAS_EXCEPTIONS namespace edit_distance { // Returns the optimal edits to go from 'left' to 'right'. // All edits cost the same, with replace having lower priority than // add/remove. // Simple implementation of the Wagner-Fischer algorithm. // See http://en.wikipedia.org/wiki/Wagner-Fischer_algorithm enum EditType { kMatch, kAdd, kRemove, kReplace }; GTEST_API_ std::vector CalculateOptimalEdits( const std::vector& left, const std::vector& right); // Same as above, but the input is represented as strings. GTEST_API_ std::vector CalculateOptimalEdits( const std::vector& left, const std::vector& right); // Create a diff of the input strings in Unified diff format. 
GTEST_API_ std::string CreateUnifiedDiff(const std::vector& left, const std::vector& right, size_t context = 2); } // namespace edit_distance // Calculate the diff between 'left' and 'right' and return it in unified diff // format. // If not null, stores in 'total_line_count' the total number of lines found // in left + right. GTEST_API_ std::string DiffStrings(const std::string& left, const std::string& right, size_t* total_line_count); // Constructs and returns the message for an equality assertion // (e.g. ASSERT_EQ, EXPECT_STREQ, etc) failure. // // The first four parameters are the expressions used in the assertion // and their values, as strings. For example, for ASSERT_EQ(foo, bar) // where foo is 5 and bar is 6, we have: // // expected_expression: "foo" // actual_expression: "bar" // expected_value: "5" // actual_value: "6" // // The ignoring_case parameter is true if and only if the assertion is a // *_STRCASEEQ*. When it's true, the string " (ignoring case)" will // be inserted into the message. GTEST_API_ AssertionResult EqFailure(const char* expected_expression, const char* actual_expression, const std::string& expected_value, const std::string& actual_value, bool ignoring_case); // Constructs a failure message for Boolean assertions such as EXPECT_TRUE. GTEST_API_ std::string GetBoolAssertionFailureMessage( const AssertionResult& assertion_result, const char* expression_text, const char* actual_predicate_value, const char* expected_predicate_value); // This template class represents an IEEE floating-point number // (either single-precision or double-precision, depending on the // template parameters). // // The purpose of this class is to do more sophisticated number // comparison. (Due to round-off error, etc, it's very unlikely that // two floating-points will be equal exactly. Hence a naive // comparison by the == operation often doesn't work.) 
// // Format of IEEE floating-point: // // The most-significant bit being the leftmost, an IEEE // floating-point looks like // // sign_bit exponent_bits fraction_bits // // Here, sign_bit is a single bit that designates the sign of the // number. // // For float, there are 8 exponent bits and 23 fraction bits. // // For double, there are 11 exponent bits and 52 fraction bits. // // More details can be found at // http://en.wikipedia.org/wiki/IEEE_floating-point_standard. // // Template parameter: // // RawType: the raw floating-point type (either float or double) template class FloatingPoint { public: // Defines the unsigned integer type that has the same size as the // floating point number. typedef typename TypeWithSize::UInt Bits; // Constants. // # of bits in a number. static const size_t kBitCount = 8*sizeof(RawType); // # of fraction bits in a number. static const size_t kFractionBitCount = std::numeric_limits::digits - 1; // # of exponent bits in a number. static const size_t kExponentBitCount = kBitCount - 1 - kFractionBitCount; // The mask for the sign bit. static const Bits kSignBitMask = static_cast(1) << (kBitCount - 1); // The mask for the fraction bits. static const Bits kFractionBitMask = ~static_cast(0) >> (kExponentBitCount + 1); // The mask for the exponent bits. static const Bits kExponentBitMask = ~(kSignBitMask | kFractionBitMask); // How many ULP's (Units in the Last Place) we want to tolerate when // comparing two numbers. The larger the value, the more error we // allow. A 0 value means that two numbers must be exactly the same // to be considered equal. // // The maximum error of a single floating-point operation is 0.5 // units in the last place. On Intel CPU's, all floating-point // calculations are done with 80-bit precision, while double has 64 // bits. Therefore, 4 should be enough for ordinary use. 
// // See the following article for more details on ULP: // http://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/ static const size_t kMaxUlps = 4; // Constructs a FloatingPoint from a raw floating-point number. // // On an Intel CPU, passing a non-normalized NAN (Not a Number) // around may change its bits, although the new value is guaranteed // to be also a NAN. Therefore, don't expect this constructor to // preserve the bits in x when x is a NAN. explicit FloatingPoint(const RawType& x) { u_.value_ = x; } // Static methods // Reinterprets a bit pattern as a floating-point number. // // This function is needed to test the AlmostEquals() method. static RawType ReinterpretBits(const Bits bits) { FloatingPoint fp(0); fp.u_.bits_ = bits; return fp.u_.value_; } // Returns the floating-point number that represent positive infinity. static RawType Infinity() { return ReinterpretBits(kExponentBitMask); } // Returns the maximum representable finite floating-point number. static RawType Max(); // Non-static methods // Returns the bits that represents this number. const Bits &bits() const { return u_.bits_; } // Returns the exponent bits of this number. Bits exponent_bits() const { return kExponentBitMask & u_.bits_; } // Returns the fraction bits of this number. Bits fraction_bits() const { return kFractionBitMask & u_.bits_; } // Returns the sign bit of this number. Bits sign_bit() const { return kSignBitMask & u_.bits_; } // Returns true if and only if this is NAN (not a number). bool is_nan() const { // It's a NAN if the exponent bits are all ones and the fraction // bits are not entirely zeros. return (exponent_bits() == kExponentBitMask) && (fraction_bits() != 0); } // Returns true if and only if this number is at most kMaxUlps ULP's away // from rhs. In particular, this function: // // - returns false if either number is (or both are) NAN. // - treats really large numbers as almost equal to infinity. 
// - thinks +0.0 and -0.0 are 0 ULP's apart. bool AlmostEquals(const FloatingPoint& rhs) const { // The IEEE standard says that any comparison operation involving // a NAN must return false. if (is_nan() || rhs.is_nan()) return false; return DistanceBetweenSignAndMagnitudeNumbers(u_.bits_, rhs.u_.bits_) <= kMaxUlps; } private: // The data type used to store the actual floating-point number. union FloatingPointUnion { RawType value_; // The raw floating-point number. Bits bits_; // The bits that represent the number. }; // Converts an integer from the sign-and-magnitude representation to // the biased representation. More precisely, let N be 2 to the // power of (kBitCount - 1), an integer x is represented by the // unsigned number x + N. // // For instance, // // -N + 1 (the most negative number representable using // sign-and-magnitude) is represented by 1; // 0 is represented by N; and // N - 1 (the biggest number representable using // sign-and-magnitude) is represented by 2N - 1. // // Read http://en.wikipedia.org/wiki/Signed_number_representations // for more details on signed number representations. static Bits SignAndMagnitudeToBiased(const Bits &sam) { if (kSignBitMask & sam) { // sam represents a negative number. return ~sam + 1; } else { // sam represents a positive number. return kSignBitMask | sam; } } // Given two numbers in the sign-and-magnitude representation, // returns the distance between them as an unsigned number. static Bits DistanceBetweenSignAndMagnitudeNumbers(const Bits &sam1, const Bits &sam2) { const Bits biased1 = SignAndMagnitudeToBiased(sam1); const Bits biased2 = SignAndMagnitudeToBiased(sam2); return (biased1 >= biased2) ? (biased1 - biased2) : (biased2 - biased1); } FloatingPointUnion u_; }; // We cannot use std::numeric_limits<T>::max() as it clashes with the max() // macro defined by <windows.h>. 
template <> inline float FloatingPoint::Max() { return FLT_MAX; } template <> inline double FloatingPoint::Max() { return DBL_MAX; } // Typedefs the instances of the FloatingPoint template class that we // care to use. typedef FloatingPoint Float; typedef FloatingPoint Double; // In order to catch the mistake of putting tests that use different // test fixture classes in the same test suite, we need to assign // unique IDs to fixture classes and compare them. The TypeId type is // used to hold such IDs. The user should treat TypeId as an opaque // type: the only operation allowed on TypeId values is to compare // them for equality using the == operator. typedef const void* TypeId; template class TypeIdHelper { public: // dummy_ must not have a const type. Otherwise an overly eager // compiler (e.g. MSVC 7.1 & 8.0) may try to merge // TypeIdHelper::dummy_ for different Ts as an "optimization". static bool dummy_; }; template bool TypeIdHelper::dummy_ = false; // GetTypeId() returns the ID of type T. Different values will be // returned for different types. Calling the function twice with the // same type argument is guaranteed to return the same ID. template TypeId GetTypeId() { // The compiler is required to allocate a different // TypeIdHelper::dummy_ variable for each T used to instantiate // the template. Therefore, the address of dummy_ is guaranteed to // be unique. return &(TypeIdHelper::dummy_); } // Returns the type ID of ::testing::Test. Always call this instead // of GetTypeId< ::testing::Test>() to get the type ID of // ::testing::Test, as the latter may give the wrong result due to a // suspected linker bug when compiling Google Test as a Mac OS X // framework. GTEST_API_ TypeId GetTestTypeId(); // Defines the abstract factory interface that creates instances // of a Test object. class TestFactoryBase { public: virtual ~TestFactoryBase() {} // Creates a test instance to run. 
The instance is both created and destroyed // within TestInfoImpl::Run() virtual Test* CreateTest() = 0; protected: TestFactoryBase() {} private: GTEST_DISALLOW_COPY_AND_ASSIGN_(TestFactoryBase); }; // This class provides implementation of TestFactoryBase interface. // It is used in TEST and TEST_F macros. template class TestFactoryImpl : public TestFactoryBase { public: Test* CreateTest() override { return new TestClass; } }; #if GTEST_OS_WINDOWS // Predicate-formatters for implementing the HRESULT checking macros // {ASSERT|EXPECT}_HRESULT_{SUCCEEDED|FAILED} // We pass a long instead of HRESULT to avoid causing an // include dependency for the HRESULT type. GTEST_API_ AssertionResult IsHRESULTSuccess(const char* expr, long hr); // NOLINT GTEST_API_ AssertionResult IsHRESULTFailure(const char* expr, long hr); // NOLINT #endif // GTEST_OS_WINDOWS // Types of SetUpTestSuite() and TearDownTestSuite() functions. using SetUpTestSuiteFunc = void (*)(); using TearDownTestSuiteFunc = void (*)(); struct CodeLocation { CodeLocation(const std::string& a_file, int a_line) : file(a_file), line(a_line) {} std::string file; int line; }; // Helper to identify which setup function for TestCase / TestSuite to call. // Only one function is allowed, either TestCase or TestSuite but not both. // Utility functions to help SuiteApiResolver using SetUpTearDownSuiteFuncType = void (*)(); inline SetUpTearDownSuiteFuncType GetNotDefaultOrNull( SetUpTearDownSuiteFuncType a, SetUpTearDownSuiteFuncType def) { return a == def ? nullptr : a; } template // Note that SuiteApiResolver inherits from T because // SetUpTestSuite()/TearDownTestSuite() could be protected. This way // SuiteApiResolver can access them. struct SuiteApiResolver : T { // testing::Test is only forward declared at this point. So we make it a // dependent class for the compiler to be OK with it. 
using Test = typename std::conditional::type; /* Returns T's SetUpTestCase or SetUpTestSuite, whichever differs from the Test default (nullptr if neither does); GTEST_CHECK_ reports filename:line_num when both differ. */ static SetUpTearDownSuiteFuncType GetSetUpCaseOrSuite(const char* filename, int line_num) { SetUpTearDownSuiteFuncType test_case_fp = GetNotDefaultOrNull(&T::SetUpTestCase, &Test::SetUpTestCase); SetUpTearDownSuiteFuncType test_suite_fp = GetNotDefaultOrNull(&T::SetUpTestSuite, &Test::SetUpTestSuite); GTEST_CHECK_(!test_case_fp || !test_suite_fp) << "Test can not provide both SetUpTestSuite and SetUpTestCase, please " "make sure there is only one present at " << filename << ":" << line_num; return test_case_fp != nullptr ? test_case_fp : test_suite_fp; } /* Tear-down counterpart of GetSetUpCaseOrSuite: same selection rule applied to TearDownTestCase / TearDownTestSuite. */ static SetUpTearDownSuiteFuncType GetTearDownCaseOrSuite(const char* filename, int line_num) { SetUpTearDownSuiteFuncType test_case_fp = GetNotDefaultOrNull(&T::TearDownTestCase, &Test::TearDownTestCase); SetUpTearDownSuiteFuncType test_suite_fp = GetNotDefaultOrNull(&T::TearDownTestSuite, &Test::TearDownTestSuite); GTEST_CHECK_(!test_case_fp || !test_suite_fp) << "Test can not provide both TearDownTestSuite and TearDownTestCase," /* trailing space keeps the message separated from the file name */ " please make sure there is only one present at " << filename << ":" << line_num; return test_case_fp != nullptr ? test_case_fp : test_suite_fp; } }; // Creates a new TestInfo object and registers it with Google Test; // returns the created object. // // Arguments: // // test_suite_name: name of the test suite // name: name of the test // type_param the name of the test's type parameter, or NULL if // this is not a typed or a type-parameterized test. // value_param text representation of the test's value parameter, // or NULL if this is not a type-parameterized test. // code_location: code location where the test is defined // fixture_class_id: ID of the test fixture class // set_up_tc: pointer to the function that sets up the test suite // tear_down_tc: pointer to the function that tears down the test suite // factory: pointer to the factory that creates a test object. 
// The newly created TestInfo instance will assume // ownership of the factory object. GTEST_API_ TestInfo* MakeAndRegisterTestInfo( const char* test_suite_name, const char* name, const char* type_param, const char* value_param, CodeLocation code_location, TypeId fixture_class_id, SetUpTestSuiteFunc set_up_tc, TearDownTestSuiteFunc tear_down_tc, TestFactoryBase* factory); // If *pstr starts with the given prefix, modifies *pstr to be right // past the prefix and returns true; otherwise leaves *pstr unchanged // and returns false. None of pstr, *pstr, and prefix can be NULL. GTEST_API_ bool SkipPrefix(const char* prefix, const char** pstr); #if GTEST_HAS_TYPED_TEST || GTEST_HAS_TYPED_TEST_P GTEST_DISABLE_MSC_WARNINGS_PUSH_(4251 \ /* class A needs to have dll-interface to be used by clients of class B */) // State of the definition of a type-parameterized test suite. class GTEST_API_ TypedTestSuitePState { public: TypedTestSuitePState() : registered_(false) {} // Adds the given test name to registered_tests_ and returns true // if the test suite hasn't been registered; otherwise aborts the // program. bool AddTestName(const char* file, int line, const char* case_name, const char* test_name) { if (registered_) { fprintf(stderr, "%s Test %s must be defined before " "REGISTER_TYPED_TEST_SUITE_P(%s, ...).\n", FormatFileLocation(file, line).c_str(), test_name, case_name); fflush(stderr); posix::Abort(); } registered_tests_.insert( ::std::make_pair(test_name, CodeLocation(file, line))); return true; } bool TestExists(const std::string& test_name) const { return registered_tests_.count(test_name) > 0; } const CodeLocation& GetCodeLocation(const std::string& test_name) const { RegisteredTestsMap::const_iterator it = registered_tests_.find(test_name); GTEST_CHECK_(it != registered_tests_.end()); return it->second; } // Verifies that registered_tests match the test names in // defined_test_names_; returns registered_tests if successful, or // aborts the program otherwise. 
const char* VerifyRegisteredTestNames( const char* file, int line, const char* registered_tests); private: typedef ::std::map RegisteredTestsMap; bool registered_; RegisteredTestsMap registered_tests_; }; // Legacy API is deprecated but still available #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ using TypedTestCasePState = TypedTestSuitePState; #endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ GTEST_DISABLE_MSC_WARNINGS_POP_() // 4251 // Skips to the first non-space char after the first comma in 'str'; // returns NULL if no comma is found in 'str'. inline const char* SkipComma(const char* str) { const char* comma = strchr(str, ','); if (comma == nullptr) { return nullptr; } while (IsSpace(*(++comma))) {} return comma; } // Returns the prefix of 'str' before the first comma in it; returns // the entire string if it contains no comma. inline std::string GetPrefixUntilComma(const char* str) { const char* comma = strchr(str, ','); return comma == nullptr ? str : std::string(str, comma); } // Splits a given string on a given delimiter, populating a given // vector with the fields. void SplitString(const ::std::string& str, char delimiter, ::std::vector< ::std::string>* dest); // The default argument to the template below for the case when the user does // not provide a name generator. struct DefaultNameGenerator { template static std::string GetName(int i) { return StreamableToString(i); } }; template struct NameGeneratorSelector { typedef Provided type; }; template void GenerateNamesRecursively(Types0, std::vector*, int) {} template void GenerateNamesRecursively(Types, std::vector* result, int i) { result->push_back(NameGenerator::template GetName(i)); GenerateNamesRecursively(typename Types::Tail(), result, i + 1); } template std::vector GenerateNames() { std::vector result; GenerateNamesRecursively(Types(), &result, 0); return result; } // TypeParameterizedTest::Register() // registers a list of type-parameterized tests with Google Test. 
The // return value is insignificant - we just need to return something // such that we can call this function in a namespace scope. // // Implementation note: The GTEST_TEMPLATE_ macro declares a template // template parameter. It's defined in gtest-type-util.h. template class TypeParameterizedTest { public: // 'index' is the index of the test in the type list 'Types' // specified in INSTANTIATE_TYPED_TEST_SUITE_P(Prefix, TestSuite, // Types). Valid values for 'index' are [0, N - 1] where N is the // length of Types. static bool Register(const char* prefix, const CodeLocation& code_location, const char* case_name, const char* test_names, int index, const std::vector& type_names = GenerateNames()) { typedef typename Types::Head Type; typedef Fixture FixtureClass; typedef typename GTEST_BIND_(TestSel, Type) TestClass; // First, registers the first type-parameterized test in the type // list. MakeAndRegisterTestInfo( (std::string(prefix) + (prefix[0] == '\0' ? "" : "/") + case_name + "/" + type_names[static_cast(index)]) .c_str(), StripTrailingSpaces(GetPrefixUntilComma(test_names)).c_str(), GetTypeName().c_str(), nullptr, // No value parameter. code_location, GetTypeId(), SuiteApiResolver::GetSetUpCaseOrSuite( code_location.file.c_str(), code_location.line), SuiteApiResolver::GetTearDownCaseOrSuite( code_location.file.c_str(), code_location.line), new TestFactoryImpl); // Next, recurses (at compile time) with the tail of the type list. return TypeParameterizedTest::Register(prefix, code_location, case_name, test_names, index + 1, type_names); } }; // The base case for the compile time recursion. template class TypeParameterizedTest { public: static bool Register(const char* /*prefix*/, const CodeLocation&, const char* /*case_name*/, const char* /*test_names*/, int /*index*/, const std::vector& = std::vector() /*type_names*/) { return true; } }; // TypeParameterizedTestSuite::Register() // registers *all combinations* of 'Tests' and 'Types' with Google // Test. 
The return value is insignificant - we just need to return // something such that we can call this function in a namespace scope. template class TypeParameterizedTestSuite { public: static bool Register(const char* prefix, CodeLocation code_location, const TypedTestSuitePState* state, const char* case_name, const char* test_names, const std::vector& type_names = GenerateNames()) { std::string test_name = StripTrailingSpaces( GetPrefixUntilComma(test_names)); if (!state->TestExists(test_name)) { fprintf(stderr, "Failed to get code location for test %s.%s at %s.", case_name, test_name.c_str(), FormatFileLocation(code_location.file.c_str(), code_location.line).c_str()); fflush(stderr); posix::Abort(); } const CodeLocation& test_location = state->GetCodeLocation(test_name); typedef typename Tests::Head Head; // First, register the first test in 'Test' for each type in 'Types'. TypeParameterizedTest::Register( prefix, test_location, case_name, test_names, 0, type_names); // Next, recurses (at compile time) with the tail of the test list. return TypeParameterizedTestSuite::Register(prefix, code_location, state, case_name, SkipComma(test_names), type_names); } }; // The base case for the compile time recursion. template class TypeParameterizedTestSuite { public: static bool Register(const char* /*prefix*/, const CodeLocation&, const TypedTestSuitePState* /*state*/, const char* /*case_name*/, const char* /*test_names*/, const std::vector& = std::vector() /*type_names*/) { return true; } }; #endif // GTEST_HAS_TYPED_TEST || GTEST_HAS_TYPED_TEST_P // Returns the current OS stack trace as an std::string. // // The maximum number of stack frames to be included is specified by // the gtest_stack_trace_depth flag. The skip_count parameter // specifies the number of top frames to be skipped, which doesn't // count against the number of frames to be included. 
// // For example, if Foo() calls Bar(), which in turn calls // GetCurrentOsStackTraceExceptTop(..., 1), Foo() will be included in // the trace but Bar() and GetCurrentOsStackTraceExceptTop() won't. GTEST_API_ std::string GetCurrentOsStackTraceExceptTop( UnitTest* unit_test, int skip_count); // Helpers for suppressing warnings on unreachable code or constant // condition. // Always returns true. GTEST_API_ bool AlwaysTrue(); // Always returns false. inline bool AlwaysFalse() { return !AlwaysTrue(); } // Helper for suppressing false warning from Clang on a const char* // variable declared in a conditional expression always being NULL in // the else branch. struct GTEST_API_ ConstCharPtr { ConstCharPtr(const char* str) : value(str) {} operator bool() const { return true; } const char* value; }; // A simple Linear Congruential Generator for generating random // numbers with a uniform distribution. Unlike rand() and srand(), it // doesn't use global state (and therefore can't interfere with user // code). Unlike rand_r(), it's portable. An LCG isn't very random, // but it's good enough for our purposes. class GTEST_API_ Random { public: static const UInt32 kMaxRange = 1u << 31; explicit Random(UInt32 seed) : state_(seed) {} void Reseed(UInt32 seed) { state_ = seed; } // Generates a random number from [0, range). Crashes if 'range' is // 0 or greater than kMaxRange. UInt32 Generate(UInt32 range); private: UInt32 state_; GTEST_DISALLOW_COPY_AND_ASSIGN_(Random); }; // Turns const U&, U&, const U, and U all into U. #define GTEST_REMOVE_REFERENCE_AND_CONST_(T) \ typename std::remove_const::type>::type // IsAProtocolMessage::value is a compile-time bool constant that's // true if and only if T is type proto2::Message or a subclass of it. 
template <typename T>
struct IsAProtocolMessage
    : public bool_constant<
          std::is_convertible<const T*, const ::proto2::Message*>::value> {};

// When the compiler sees expression IsContainerTest<C>(0), if C is an
// STL-style container class, the first overload of IsContainerTest
// will be viable (since both C::iterator* and C::const_iterator* are
// valid types and NULL can be implicitly converted to them). It will
// be picked over the second overload as 'int' is a perfect match for
// the type of argument 0. If C::iterator or C::const_iterator is not
// a valid type, the first overload is not viable, and the second
// overload will be picked. Therefore, we can determine whether C is
// a container class by checking the type of IsContainerTest<C>(0).
// The value of the expression is insignificant.
//
// In C++11 mode we check the existence of a const_iterator and that an
// iterator is properly implemented for the container.
//
// For pre-C++11 we look for both C::iterator and C::const_iterator.
// The reason is that C++ injects the name of a class as a member of the
// class itself (e.g. you can refer to class iterator as either
// 'iterator' or 'iterator::iterator'). If we look for C::iterator
// only, for example, we would mistakenly think that a class named
// iterator is an STL container.
//
// Also note that the simpler approach of overloading
// IsContainerTest(typename C::const_iterator*) and
// IsContainerTest(...) doesn't work with Visual Age C++ and Sun C++.
typedef int IsContainer;
// Viable only when a const C has begin()/end(), the iterator they yield
// supports pre-increment and dereference, and C::const_iterator names a type.
template <class C,
          class Iterator = decltype(::std::declval<const C&>().begin()),
          class = decltype(::std::declval<const C&>().end()),
          class = decltype(++::std::declval<Iterator&>()),
          class = decltype(*::std::declval<Iterator>()),
          class = typename C::const_iterator>
IsContainer IsContainerTest(int /* dummy */) {
  return 0;
}

typedef char IsNotContainer;
// Fallback overload: chosen when the overload above is not viable.
template <class C>
IsNotContainer IsContainerTest(long /* dummy */) { return '\0'; }

// Trait to detect whether a type T is a hash table.
// The heuristic used is that the type contains an inner type `hasher` and does
// not contain an inner type `reverse_iterator`.
// If the container is iterable in reverse, then order might actually matter. template struct IsHashTable { private: template static char test(typename U::hasher*, typename U::reverse_iterator*); template static int test(typename U::hasher*, ...); template static char test(...); public: static const bool value = sizeof(test(nullptr, nullptr)) == sizeof(int); }; template const bool IsHashTable::value; template (0)) == sizeof(IsContainer)> struct IsRecursiveContainerImpl; template struct IsRecursiveContainerImpl : public std::false_type {}; // Since the IsRecursiveContainerImpl depends on the IsContainerTest we need to // obey the same inconsistencies as the IsContainerTest, namely check if // something is a container is relying on only const_iterator in C++11 and // is relying on both const_iterator and iterator otherwise template struct IsRecursiveContainerImpl { using value_type = decltype(*std::declval()); using type = std::is_same::type>::type, C>; }; // IsRecursiveContainer is a unary compile-time predicate that // evaluates whether C is a recursive container type. A recursive container // type is a container type whose value_type is equal to the container type // itself. An example for a recursive container type is // boost::filesystem::path, whose iterator has a value_type that is equal to // boost::filesystem::path. template struct IsRecursiveContainer : public IsRecursiveContainerImpl::type {}; // Utilities for native arrays. // ArrayEq() compares two k-dimensional native arrays using the // elements' operator==, where k can be any integer >= 0. When k is // 0, ArrayEq() degenerates into comparing a single pair of values. template bool ArrayEq(const T* lhs, size_t size, const U* rhs); // This generic version is used when k is 0. template inline bool ArrayEq(const T& lhs, const U& rhs) { return lhs == rhs; } // This overload is used when k >= 1. 
template inline bool ArrayEq(const T(&lhs)[N], const U(&rhs)[N]) { return internal::ArrayEq(lhs, N, rhs); } // This helper reduces code bloat. If we instead put its logic inside // the previous ArrayEq() function, arrays with different sizes would // lead to different copies of the template code. template bool ArrayEq(const T* lhs, size_t size, const U* rhs) { for (size_t i = 0; i != size; i++) { if (!internal::ArrayEq(lhs[i], rhs[i])) return false; } return true; } // Finds the first element in the iterator range [begin, end) that // equals elem. Element may be a native array type itself. template Iter ArrayAwareFind(Iter begin, Iter end, const Element& elem) { for (Iter it = begin; it != end; ++it) { if (internal::ArrayEq(*it, elem)) return it; } return end; } // CopyArray() copies a k-dimensional native array using the elements' // operator=, where k can be any integer >= 0. When k is 0, // CopyArray() degenerates into copying a single value. template void CopyArray(const T* from, size_t size, U* to); // This generic version is used when k is 0. template inline void CopyArray(const T& from, U* to) { *to = from; } // This overload is used when k >= 1. template inline void CopyArray(const T(&from)[N], U(*to)[N]) { internal::CopyArray(from, N, *to); } // This helper reduces code bloat. If we instead put its logic inside // the previous CopyArray() function, arrays with different sizes // would lead to different copies of the template code. template void CopyArray(const T* from, size_t size, U* to) { for (size_t i = 0; i != size; i++) { internal::CopyArray(from[i], to + i); } } // The relation between an NativeArray object (see below) and the // native array it represents. // We use 2 different structs to allow non-copyable types to be used, as long // as RelationToSourceReference() is passed. struct RelationToSourceReference {}; struct RelationToSourceCopy {}; // Adapts a native array to a read-only STL-style container. 
Instead // of the complete STL container concept, this adaptor only implements // members useful for Google Mock's container matchers. New members // should be added as needed. To simplify the implementation, we only // support Element being a raw type (i.e. having no top-level const or // reference modifier). It's the client's responsibility to satisfy // this requirement. Element can be an array type itself (hence // multi-dimensional arrays are supported). template class NativeArray { public: // STL-style container typedefs. typedef Element value_type; typedef Element* iterator; typedef const Element* const_iterator; // Constructs from a native array. References the source. NativeArray(const Element* array, size_t count, RelationToSourceReference) { InitRef(array, count); } // Constructs from a native array. Copies the source. NativeArray(const Element* array, size_t count, RelationToSourceCopy) { InitCopy(array, count); } // Copy constructor. NativeArray(const NativeArray& rhs) { (this->*rhs.clone_)(rhs.array_, rhs.size_); } ~NativeArray() { if (clone_ != &NativeArray::InitRef) delete[] array_; } // STL-style container methods. size_t size() const { return size_; } const_iterator begin() const { return array_; } const_iterator end() const { return array_ + size_; } bool operator==(const NativeArray& rhs) const { return size() == rhs.size() && ArrayEq(begin(), size(), rhs.begin()); } private: static_assert(!std::is_const::value, "Type must not be const"); static_assert(!std::is_reference::value, "Type must not be a reference"); // Initializes this object with a copy of the input. void InitCopy(const Element* array, size_t a_size) { Element* const copy = new Element[a_size]; CopyArray(array, a_size, copy); array_ = copy; size_ = a_size; clone_ = &NativeArray::InitCopy; } // Initializes this object with a reference of the input. 
void InitRef(const Element* array, size_t a_size) { array_ = array; size_ = a_size; clone_ = &NativeArray::InitRef; } const Element* array_; size_t size_; void (NativeArray::*clone_)(const Element*, size_t); GTEST_DISALLOW_ASSIGN_(NativeArray); }; // Backport of std::index_sequence. template struct IndexSequence { using type = IndexSequence; }; // Double the IndexSequence, and one if plus_one is true. template struct DoubleSequence; template struct DoubleSequence, sizeofT> { using type = IndexSequence; }; template struct DoubleSequence, sizeofT> { using type = IndexSequence; }; // Backport of std::make_index_sequence. // It uses O(ln(N)) instantiation depth. template struct MakeIndexSequence : DoubleSequence::type, N / 2>::type {}; template <> struct MakeIndexSequence<0> : IndexSequence<> {}; // FIXME: This implementation of ElemFromList is O(1) in instantiation depth, // but it is O(N^2) in total instantiations. Not sure if this is the best // tradeoff, as it will make it somewhat slow to compile. template struct ElemFromListImpl {}; template struct ElemFromListImpl { using type = T; }; // Get the Nth element from T... // It uses O(1) instantiation depth. template struct ElemFromList; template struct ElemFromList, T...> : ElemFromListImpl... {}; template class FlatTuple; template struct FlatTupleElemBase; template struct FlatTupleElemBase, I> { using value_type = typename ElemFromList::type, T...>::type; FlatTupleElemBase() = default; explicit FlatTupleElemBase(value_type t) : value(std::move(t)) {} value_type value; }; template struct FlatTupleBase; template struct FlatTupleBase, IndexSequence> : FlatTupleElemBase, Idx>... { using Indices = IndexSequence; FlatTupleBase() = default; explicit FlatTupleBase(T... t) : FlatTupleElemBase, Idx>(std::move(t))... {} }; // Analog to std::tuple but with different tradeoffs. // This class minimizes the template instantiation depth, thus allowing more // elements that std::tuple would. 
std::tuple has been seen to require an // instantiation depth of more than 10x the number of elements in some // implementations. // FlatTuple and ElemFromList are not recursive and have a fixed depth // regardless of T... // MakeIndexSequence, on the other hand, it is recursive but with an // instantiation depth of O(ln(N)). template class FlatTuple : private FlatTupleBase, typename MakeIndexSequence::type> { using Indices = typename FlatTuple::FlatTupleBase::Indices; public: FlatTuple() = default; explicit FlatTuple(T... t) : FlatTuple::FlatTupleBase(std::move(t)...) {} template const typename ElemFromList::type& Get() const { return static_cast*>(this)->value; } template typename ElemFromList::type& Get() { return static_cast*>(this)->value; } }; // Utility functions to be called with static_assert to induce deprecation // warnings. GTEST_INTERNAL_DEPRECATED( "INSTANTIATE_TEST_CASE_P is deprecated, please use " "INSTANTIATE_TEST_SUITE_P") constexpr bool InstantiateTestCase_P_IsDeprecated() { return true; } GTEST_INTERNAL_DEPRECATED( "TYPED_TEST_CASE_P is deprecated, please use " "TYPED_TEST_SUITE_P") constexpr bool TypedTestCase_P_IsDeprecated() { return true; } GTEST_INTERNAL_DEPRECATED( "TYPED_TEST_CASE is deprecated, please use " "TYPED_TEST_SUITE") constexpr bool TypedTestCaseIsDeprecated() { return true; } GTEST_INTERNAL_DEPRECATED( "REGISTER_TYPED_TEST_CASE_P is deprecated, please use " "REGISTER_TYPED_TEST_SUITE_P") constexpr bool RegisterTypedTestCase_P_IsDeprecated() { return true; } GTEST_INTERNAL_DEPRECATED( "INSTANTIATE_TYPED_TEST_CASE_P is deprecated, please use " "INSTANTIATE_TYPED_TEST_SUITE_P") constexpr bool InstantiateTypedTestCase_P_IsDeprecated() { return true; } } // namespace internal } // namespace testing #define GTEST_MESSAGE_AT_(file, line, message, result_type) \ ::testing::internal::AssertHelper(result_type, file, line, message) \ = ::testing::Message() #define GTEST_MESSAGE_(message, result_type) \ GTEST_MESSAGE_AT_(__FILE__, 
__LINE__, message, result_type) #define GTEST_FATAL_FAILURE_(message) \ return GTEST_MESSAGE_(message, ::testing::TestPartResult::kFatalFailure) #define GTEST_NONFATAL_FAILURE_(message) \ GTEST_MESSAGE_(message, ::testing::TestPartResult::kNonFatalFailure) #define GTEST_SUCCESS_(message) \ GTEST_MESSAGE_(message, ::testing::TestPartResult::kSuccess) #define GTEST_SKIP_(message) \ return GTEST_MESSAGE_(message, ::testing::TestPartResult::kSkip) // Suppress MSVC warning 4072 (unreachable code) for the code following // statement if it returns or throws (or doesn't return or throw in some // situations). #define GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement) \ if (::testing::internal::AlwaysTrue()) { statement; } #define GTEST_TEST_THROW_(statement, expected_exception, fail) \ GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ if (::testing::internal::ConstCharPtr gtest_msg = "") { \ bool gtest_caught_expected = false; \ try { \ GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ } \ catch (expected_exception const&) { \ gtest_caught_expected = true; \ } \ catch (...) { \ gtest_msg.value = \ "Expected: " #statement " throws an exception of type " \ #expected_exception ".\n Actual: it throws a different type."; \ goto GTEST_CONCAT_TOKEN_(gtest_label_testthrow_, __LINE__); \ } \ if (!gtest_caught_expected) { \ gtest_msg.value = \ "Expected: " #statement " throws an exception of type " \ #expected_exception ".\n Actual: it throws nothing."; \ goto GTEST_CONCAT_TOKEN_(gtest_label_testthrow_, __LINE__); \ } \ } else \ GTEST_CONCAT_TOKEN_(gtest_label_testthrow_, __LINE__): \ fail(gtest_msg.value) #define GTEST_TEST_NO_THROW_(statement, fail) \ GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ if (::testing::internal::AlwaysTrue()) { \ try { \ GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ } \ catch (...) 
{ \ goto GTEST_CONCAT_TOKEN_(gtest_label_testnothrow_, __LINE__); \ } \ } else \ GTEST_CONCAT_TOKEN_(gtest_label_testnothrow_, __LINE__): \ fail("Expected: " #statement " doesn't throw an exception.\n" \ " Actual: it throws.") #define GTEST_TEST_ANY_THROW_(statement, fail) \ GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ if (::testing::internal::AlwaysTrue()) { \ bool gtest_caught_any = false; \ try { \ GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ } \ catch (...) { \ gtest_caught_any = true; \ } \ if (!gtest_caught_any) { \ goto GTEST_CONCAT_TOKEN_(gtest_label_testanythrow_, __LINE__); \ } \ } else \ GTEST_CONCAT_TOKEN_(gtest_label_testanythrow_, __LINE__): \ fail("Expected: " #statement " throws an exception.\n" \ " Actual: it doesn't.") // Implements Boolean test assertions such as EXPECT_TRUE. expression can be // either a boolean expression or an AssertionResult. text is a textual // represenation of expression as it was passed into the EXPECT_TRUE. #define GTEST_TEST_BOOLEAN_(expression, text, actual, expected, fail) \ GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ if (const ::testing::AssertionResult gtest_ar_ = \ ::testing::AssertionResult(expression)) \ ; \ else \ fail(::testing::internal::GetBoolAssertionFailureMessage(\ gtest_ar_, text, #actual, #expected).c_str()) #define GTEST_TEST_NO_FATAL_FAILURE_(statement, fail) \ GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ if (::testing::internal::AlwaysTrue()) { \ ::testing::internal::HasNewFatalFailureHelper gtest_fatal_failure_checker; \ GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ if (gtest_fatal_failure_checker.has_new_fatal_failure()) { \ goto GTEST_CONCAT_TOKEN_(gtest_label_testnofatal_, __LINE__); \ } \ } else \ GTEST_CONCAT_TOKEN_(gtest_label_testnofatal_, __LINE__): \ fail("Expected: " #statement " doesn't generate new fatal " \ "failures in the current thread.\n" \ " Actual: it does.") // Expands to the name of the class that implements the given test. 
#define GTEST_TEST_CLASS_NAME_(test_suite_name, test_name) \ test_suite_name##_##test_name##_Test // Helper macro for defining tests. #define GTEST_TEST_(test_suite_name, test_name, parent_class, parent_id) \ static_assert(sizeof(GTEST_STRINGIFY_(test_suite_name)) > 1, \ "test_suite_name must not be empty"); \ static_assert(sizeof(GTEST_STRINGIFY_(test_name)) > 1, \ "test_name must not be empty"); \ class GTEST_TEST_CLASS_NAME_(test_suite_name, test_name) \ : public parent_class { \ public: \ GTEST_TEST_CLASS_NAME_(test_suite_name, test_name)() {} \ \ private: \ virtual void TestBody(); \ static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \ GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_suite_name, \ test_name)); \ }; \ \ ::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_suite_name, \ test_name)::test_info_ = \ ::testing::internal::MakeAndRegisterTestInfo( \ #test_suite_name, #test_name, nullptr, nullptr, \ ::testing::internal::CodeLocation(__FILE__, __LINE__), (parent_id), \ ::testing::internal::SuiteApiResolver< \ parent_class>::GetSetUpCaseOrSuite(__FILE__, __LINE__), \ ::testing::internal::SuiteApiResolver< \ parent_class>::GetTearDownCaseOrSuite(__FILE__, __LINE__), \ new ::testing::internal::TestFactoryImpl); \ void GTEST_TEST_CLASS_NAME_(test_suite_name, test_name)::TestBody() #endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_INTERNAL_H_ // Copyright 2005, Google Inc. // All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. 
nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // // The Google C++ Testing and Mocking Framework (Google Test) // // This header file defines the public API for death tests. It is // #included by gtest.h so a user doesn't need to include this // directly. // GOOGLETEST_CM0001 DO NOT DELETE #ifndef GTEST_INCLUDE_GTEST_GTEST_DEATH_TEST_H_ #define GTEST_INCLUDE_GTEST_GTEST_DEATH_TEST_H_ // Copyright 2005, Google Inc. // All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. 
nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // // The Google C++ Testing and Mocking Framework (Google Test) // // This header file defines internal utilities needed for implementing // death tests. They are subject to change without notice. // GOOGLETEST_CM0001 DO NOT DELETE #ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_DEATH_TEST_INTERNAL_H_ #define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_DEATH_TEST_INTERNAL_H_ // Copyright 2007, Google Inc. // All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. 
nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // The Google C++ Testing and Mocking Framework (Google Test) // // This file implements just enough of the matcher interface to allow // EXPECT_DEATH and friends to accept a matcher argument. // IWYU pragma: private, include "testing/base/public/gunit.h" // IWYU pragma: friend third_party/googletest/googlemock/.* // IWYU pragma: friend third_party/googletest/googletest/.* #ifndef GTEST_INCLUDE_GTEST_GTEST_MATCHERS_H_ #define GTEST_INCLUDE_GTEST_GTEST_MATCHERS_H_ #include #include #include #include // Copyright 2007, Google Inc. // All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. 
// * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // Google Test - The Google C++ Testing and Mocking Framework // // This file implements a universal value printer that can print a // value of any type T: // // void ::testing::internal::UniversalPrinter::Print(value, ostream_ptr); // // A user can teach this function how to print a class type T by // defining either operator<<() or PrintTo() in the namespace that // defines T. More specifically, the FIRST defined function in the // following list will be used (assuming T is defined in namespace // foo): // // 1. foo::PrintTo(const T&, ostream*) // 2. operator<<(ostream&, const T&) defined in either foo or the // global namespace. // // However if T is an STL-style container then it is printed element-wise // unless foo::PrintTo(const T&, ostream*) is defined. 
Note that // operator<<() is ignored for container types. // // If none of the above is defined, it will print the debug string of // the value if it is a protocol buffer, or print the raw bytes in the // value otherwise. // // To aid debugging: when T is a reference type, the address of the // value is also printed; when T is a (const) char pointer, both the // pointer value and the NUL-terminated string it points to are // printed. // // We also provide some convenient wrappers: // // // Prints a value to a string. For a (const or not) char // // pointer, the NUL-terminated string (but not the pointer) is // // printed. // std::string ::testing::PrintToString(const T& value); // // // Prints a value tersely: for a reference type, the referenced // // value (but not the address) is printed; for a (const or not) char // // pointer, the NUL-terminated string (but not the pointer) is // // printed. // void ::testing::internal::UniversalTersePrint(const T& value, ostream*); // // // Prints value using the type inferred by the compiler. The difference // // from UniversalTersePrint() is that this function prints both the // // pointer and the NUL-terminated string for a (const or not) char pointer. // void ::testing::internal::UniversalPrint(const T& value, ostream*); // // // Prints the fields of a tuple tersely to a string vector, one // // element for each field. Tuple support must be enabled in // // gtest-port.h. // std::vector UniversalTersePrintTupleFieldsToStrings( // const Tuple& value); // // Known limitation: // // The print primitives print the elements of an STL-style container // using the compiler-inferred type of *iter where iter is a // const_iterator of the container. When const_iterator is an input // iterator but not a forward iterator, this inferred type may not // match value_type, and the print output may be incorrect. In // practice, this is rarely a problem as for most containers // const_iterator is a forward iterator. 
We'll fix this if there's an // actual need for it. Note that this fix cannot rely on value_type // being defined as many user-defined container types don't have // value_type. // GOOGLETEST_CM0001 DO NOT DELETE #ifndef GTEST_INCLUDE_GTEST_GTEST_PRINTERS_H_ #define GTEST_INCLUDE_GTEST_GTEST_PRINTERS_H_ #include #include // NOLINT #include #include #include #include #include #include #if GTEST_HAS_ABSL #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/variant.h" #endif // GTEST_HAS_ABSL namespace testing { // Definitions in the 'internal' and 'internal2' name spaces are // subject to change without notice. DO NOT USE THEM IN USER CODE! namespace internal2 { // Prints the given number of bytes in the given object to the given // ostream. GTEST_API_ void PrintBytesInObjectTo(const unsigned char* obj_bytes, size_t count, ::std::ostream* os); // For selecting which printer to use when a given type has neither << // nor PrintTo(). enum TypeKind { kProtobuf, // a protobuf type kConvertibleToInteger, // a type implicitly convertible to BiggestInt // (e.g. a named or unnamed enum type) #if GTEST_HAS_ABSL kConvertibleToStringView, // a type implicitly convertible to // absl::string_view #endif kOtherType // anything else }; // TypeWithoutFormatter::PrintValue(value, os) is called // by the universal printer to print a value of type T when neither // operator<< nor PrintTo() is defined for T, where kTypeKind is the // "kind" of T as defined by enum TypeKind. template class TypeWithoutFormatter { public: // This default version is called when kTypeKind is kOtherType. static void PrintValue(const T& value, ::std::ostream* os) { PrintBytesInObjectTo( static_cast( reinterpret_cast(std::addressof(value))), sizeof(value), os); } }; // We print a protobuf using its ShortDebugString() when the string // doesn't exceed this many characters; otherwise we print it using // DebugString() for better readability. 
const size_t kProtobufOneLinerMaxLength = 50; template class TypeWithoutFormatter { public: static void PrintValue(const T& value, ::std::ostream* os) { std::string pretty_str = value.ShortDebugString(); if (pretty_str.length() > kProtobufOneLinerMaxLength) { pretty_str = "\n" + value.DebugString(); } *os << ("<" + pretty_str + ">"); } }; template class TypeWithoutFormatter { public: // Since T has no << operator or PrintTo() but can be implicitly // converted to BiggestInt, we print it as a BiggestInt. // // Most likely T is an enum type (either named or unnamed), in which // case printing it as an integer is the desired behavior. In case // T is not an enum, printing it as an integer is the best we can do // given that it has no user-defined printer. static void PrintValue(const T& value, ::std::ostream* os) { const internal::BiggestInt kBigInt = value; *os << kBigInt; } }; #if GTEST_HAS_ABSL template class TypeWithoutFormatter { public: // Since T has neither operator<< nor PrintTo() but can be implicitly // converted to absl::string_view, we print it as a absl::string_view. // // Note: the implementation is further below, as it depends on // internal::PrintTo symbol which is defined later in the file. static void PrintValue(const T& value, ::std::ostream* os); }; #endif // Prints the given value to the given ostream. If the value is a // protocol message, its debug string is printed; if it's an enum or // of a type implicitly convertible to BiggestInt, it's printed as an // integer; otherwise the bytes in the value are printed. This is // what UniversalPrinter::Print() does when it knows nothing about // type T and T has neither << operator nor PrintTo(). // // A user can override this behavior for a class type Foo by defining // a << operator in the namespace where Foo is defined. 
// // We put this operator in namespace 'internal2' instead of 'internal' // to simplify the implementation, as much code in 'internal' needs to // use << in STL, which would conflict with our own << were it defined // in 'internal'. // // Note that this operator<< takes a generic std::basic_ostream type instead of the more restricted std::ostream. If // we define it to take an std::ostream instead, we'll get an // "ambiguous overloads" compiler error when trying to print a type // Foo that supports streaming to std::basic_ostream, as the compiler cannot tell whether // operator<<(std::ostream&, const T&) or // operator<<(std::basic_stream, const Foo&) is more // specific. template ::std::basic_ostream& operator<<( ::std::basic_ostream& os, const T& x) { TypeWithoutFormatter::value ? kProtobuf : std::is_convertible< const T&, internal::BiggestInt>::value ? kConvertibleToInteger : #if GTEST_HAS_ABSL std::is_convertible< const T&, absl::string_view>::value ? kConvertibleToStringView : #endif kOtherType)>::PrintValue(x, &os); return os; } } // namespace internal2 } // namespace testing // This namespace MUST NOT BE NESTED IN ::testing, or the name look-up // magic needed for implementing UniversalPrinter won't work. namespace testing_internal { // Used to print a value that is not an STL-style container when the // user doesn't define PrintTo() for it. template void DefaultPrintNonContainerTo(const T& value, ::std::ostream* os) { // With the following statement, during unqualified name lookup, // testing::internal2::operator<< appears as if it was declared in // the nearest enclosing namespace that contains both // ::testing_internal and ::testing::internal2, i.e. the global // namespace. For more details, refer to the C++ Standard section // 7.3.4-1 [namespace.udir]. This allows us to fall back onto // testing::internal2::operator<< in case T doesn't come with a << // operator. 
// // We cannot write 'using ::testing::internal2::operator<<;', which // gcc 3.3 fails to compile due to a compiler bug. using namespace ::testing::internal2; // NOLINT // Assuming T is defined in namespace foo, in the next statement, // the compiler will consider all of: // // 1. foo::operator<< (thanks to Koenig look-up), // 2. ::operator<< (as the current namespace is enclosed in ::), // 3. testing::internal2::operator<< (thanks to the using statement above). // // The operator<< whose type matches T best will be picked. // // We deliberately allow #2 to be a candidate, as sometimes it's // impossible to define #1 (e.g. when foo is ::std, defining // anything in it is undefined behavior unless you are a compiler // vendor.). *os << value; } } // namespace testing_internal namespace testing { namespace internal { // FormatForComparison::Format(value) formats a // value of type ToPrint that is an operand of a comparison assertion // (e.g. ASSERT_EQ). OtherOperand is the type of the other operand in // the comparison, and is used to help determine the best way to // format the value. In particular, when the value is a C string // (char pointer) and the other operand is an STL string object, we // want to format the C string as a string, since we know it is // compared by value with the string object. If the value is a char // pointer but the other operand is not an STL string object, we don't // know whether the pointer is supposed to point to a NUL-terminated // string, and thus want to print it as a pointer to be safe. // // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. // The default case. template class FormatForComparison { public: static ::std::string Format(const ToPrint& value) { return ::testing::PrintToString(value); } }; // Array. 
template class FormatForComparison { public: static ::std::string Format(const ToPrint* value) { return FormatForComparison::Format(value); } }; // By default, print C string as pointers to be safe, as we don't know // whether they actually point to a NUL-terminated string. #define GTEST_IMPL_FORMAT_C_STRING_AS_POINTER_(CharType) \ template \ class FormatForComparison { \ public: \ static ::std::string Format(CharType* value) { \ return ::testing::PrintToString(static_cast(value)); \ } \ } GTEST_IMPL_FORMAT_C_STRING_AS_POINTER_(char); GTEST_IMPL_FORMAT_C_STRING_AS_POINTER_(const char); GTEST_IMPL_FORMAT_C_STRING_AS_POINTER_(wchar_t); GTEST_IMPL_FORMAT_C_STRING_AS_POINTER_(const wchar_t); #undef GTEST_IMPL_FORMAT_C_STRING_AS_POINTER_ // If a C string is compared with an STL string object, we know it's meant // to point to a NUL-terminated string, and thus can print it as a string. #define GTEST_IMPL_FORMAT_C_STRING_AS_STRING_(CharType, OtherStringType) \ template <> \ class FormatForComparison { \ public: \ static ::std::string Format(CharType* value) { \ return ::testing::PrintToString(value); \ } \ } GTEST_IMPL_FORMAT_C_STRING_AS_STRING_(char, ::std::string); GTEST_IMPL_FORMAT_C_STRING_AS_STRING_(const char, ::std::string); #if GTEST_HAS_STD_WSTRING GTEST_IMPL_FORMAT_C_STRING_AS_STRING_(wchar_t, ::std::wstring); GTEST_IMPL_FORMAT_C_STRING_AS_STRING_(const wchar_t, ::std::wstring); #endif #undef GTEST_IMPL_FORMAT_C_STRING_AS_STRING_ // Formats a comparison assertion (e.g. ASSERT_EQ, EXPECT_LT, and etc) // operand to be used in a failure message. The type (but not value) // of the other operand may affect the format. This allows us to // print a char* as a raw pointer when it is compared against another // char* or void*, and print it as a C string when it is compared // against an std::string object, for example. // // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. 
template std::string FormatForComparisonFailureMessage( const T1& value, const T2& /* other_operand */) { return FormatForComparison::Format(value); } // UniversalPrinter::Print(value, ostream_ptr) prints the given // value to the given ostream. The caller must ensure that // 'ostream_ptr' is not NULL, or the behavior is undefined. // // We define UniversalPrinter as a class template (as opposed to a // function template), as we need to partially specialize it for // reference types, which cannot be done with function templates. template class UniversalPrinter; template void UniversalPrint(const T& value, ::std::ostream* os); enum DefaultPrinterType { kPrintContainer, kPrintPointer, kPrintFunctionPointer, kPrintOther, }; template struct WrapPrinterType {}; // Used to print an STL-style container when the user doesn't define // a PrintTo() for it. template void DefaultPrintTo(WrapPrinterType /* dummy */, const C& container, ::std::ostream* os) { const size_t kMaxCount = 32; // The maximum number of elements to print. *os << '{'; size_t count = 0; for (typename C::const_iterator it = container.begin(); it != container.end(); ++it, ++count) { if (count > 0) { *os << ','; if (count == kMaxCount) { // Enough has been printed. *os << " ..."; break; } } *os << ' '; // We cannot call PrintTo(*it, os) here as PrintTo() doesn't // handle *it being a native array. internal::UniversalPrint(*it, os); } if (count > 0) { *os << ' '; } *os << '}'; } // Used to print a pointer that is neither a char pointer nor a member // pointer, when the user doesn't define PrintTo() for it. (A member // variable pointer or member function pointer doesn't really point to // a location in the address space. Their representation is // implementation-defined. Therefore they will be printed as raw // bytes.) template void DefaultPrintTo(WrapPrinterType /* dummy */, T* p, ::std::ostream* os) { if (p == nullptr) { *os << "NULL"; } else { // T is not a function type. 
We just call << to print p, // relying on ADL to pick up user-defined << for their pointer // types, if any. *os << p; } } template void DefaultPrintTo(WrapPrinterType /* dummy */, T* p, ::std::ostream* os) { if (p == nullptr) { *os << "NULL"; } else { // T is a function type, so '*os << p' doesn't do what we want // (it just prints p as bool). We want to print p as a const // void*. *os << reinterpret_cast(p); } } // Used to print a non-container, non-pointer value when the user // doesn't define PrintTo() for it. template void DefaultPrintTo(WrapPrinterType /* dummy */, const T& value, ::std::ostream* os) { ::testing_internal::DefaultPrintNonContainerTo(value, os); } // Prints the given value using the << operator if it has one; // otherwise prints the bytes in it. This is what // UniversalPrinter::Print() does when PrintTo() is not specialized // or overloaded for type T. // // A user can override this behavior for a class type Foo by defining // an overload of PrintTo() in the namespace where Foo is defined. We // give the user this option as sometimes defining a << operator for // Foo is not desirable (e.g. the coding style may prevent doing it, // or there is already a << operator but it doesn't do what the user // wants). template void PrintTo(const T& value, ::std::ostream* os) { // DefaultPrintTo() is overloaded. The type of its first argument // determines which version will be picked. // // Note that we check for container types here, prior to we check // for protocol message types in our operator<<. The rationale is: // // For protocol messages, we want to give people a chance to // override Google Mock's format by defining a PrintTo() or // operator<<. For STL containers, other formats can be // incompatible with Google Mock's format for the container // elements; therefore we check for container types here to ensure // that our format is used. 
// // Note that MSVC and clang-cl do allow an implicit conversion from // pointer-to-function to pointer-to-object, but clang-cl warns on it. // So don't use ImplicitlyConvertible if it can be helped since it will // cause this warning, and use a separate overload of DefaultPrintTo for // function pointers so that the `*os << p` in the object pointer overload // doesn't cause that warning either. DefaultPrintTo( WrapPrinterType < (sizeof(IsContainerTest(0)) == sizeof(IsContainer)) && !IsRecursiveContainer::value ? kPrintContainer : !std::is_pointer::value ? kPrintOther : std::is_function::type>::value ? kPrintFunctionPointer : kPrintPointer > (), value, os); } // The following list of PrintTo() overloads tells // UniversalPrinter::Print() how to print standard types (built-in // types, strings, plain arrays, and pointers). // Overloads for various char types. GTEST_API_ void PrintTo(unsigned char c, ::std::ostream* os); GTEST_API_ void PrintTo(signed char c, ::std::ostream* os); inline void PrintTo(char c, ::std::ostream* os) { // When printing a plain char, we always treat it as unsigned. This // way, the output won't be affected by whether the compiler thinks // char is signed or not. PrintTo(static_cast(c), os); } // Overloads for other simple built-in types. inline void PrintTo(bool x, ::std::ostream* os) { *os << (x ? "true" : "false"); } // Overload for wchar_t type. // Prints a wchar_t as a symbol if it is printable or as its internal // code otherwise and also as its decimal code (except for L'\0'). // The L'\0' char is printed as "L'\\0'". The decimal code is printed // as signed integer when wchar_t is implemented by the compiler // as a signed type and is printed as an unsigned integer when wchar_t // is implemented as an unsigned type. GTEST_API_ void PrintTo(wchar_t wc, ::std::ostream* os); // Overloads for C strings. 
GTEST_API_ void PrintTo(const char* s, ::std::ostream* os); inline void PrintTo(char* s, ::std::ostream* os) { PrintTo(ImplicitCast_(s), os); } // signed/unsigned char is often used for representing binary data, so // we print pointers to it as void* to be safe. inline void PrintTo(const signed char* s, ::std::ostream* os) { PrintTo(ImplicitCast_(s), os); } inline void PrintTo(signed char* s, ::std::ostream* os) { PrintTo(ImplicitCast_(s), os); } inline void PrintTo(const unsigned char* s, ::std::ostream* os) { PrintTo(ImplicitCast_(s), os); } inline void PrintTo(unsigned char* s, ::std::ostream* os) { PrintTo(ImplicitCast_(s), os); } // MSVC can be configured to define wchar_t as a typedef of unsigned // short. It defines _NATIVE_WCHAR_T_DEFINED when wchar_t is a native // type. When wchar_t is a typedef, defining an overload for const // wchar_t* would cause unsigned short* be printed as a wide string, // possibly causing invalid memory accesses. #if !defined(_MSC_VER) || defined(_NATIVE_WCHAR_T_DEFINED) // Overloads for wide C strings GTEST_API_ void PrintTo(const wchar_t* s, ::std::ostream* os); inline void PrintTo(wchar_t* s, ::std::ostream* os) { PrintTo(ImplicitCast_(s), os); } #endif // Overload for C arrays. Multi-dimensional arrays are printed // properly. // Prints the given number of elements in an array, without printing // the curly braces. template void PrintRawArrayTo(const T a[], size_t count, ::std::ostream* os) { UniversalPrint(a[0], os); for (size_t i = 1; i != count; i++) { *os << ", "; UniversalPrint(a[i], os); } } // Overloads for ::std::string. GTEST_API_ void PrintStringTo(const ::std::string&s, ::std::ostream* os); inline void PrintTo(const ::std::string& s, ::std::ostream* os) { PrintStringTo(s, os); } // Overloads for ::std::wstring. 
#if GTEST_HAS_STD_WSTRING GTEST_API_ void PrintWideStringTo(const ::std::wstring&s, ::std::ostream* os); inline void PrintTo(const ::std::wstring& s, ::std::ostream* os) { PrintWideStringTo(s, os); } #endif // GTEST_HAS_STD_WSTRING #if GTEST_HAS_ABSL // Overload for absl::string_view. inline void PrintTo(absl::string_view sp, ::std::ostream* os) { PrintTo(::std::string(sp), os); } #endif // GTEST_HAS_ABSL inline void PrintTo(std::nullptr_t, ::std::ostream* os) { *os << "(nullptr)"; } template void PrintTo(std::reference_wrapper ref, ::std::ostream* os) { UniversalPrinter::Print(ref.get(), os); } // Helper function for printing a tuple. T must be instantiated with // a tuple type. template void PrintTupleTo(const T&, std::integral_constant, ::std::ostream*) {} template void PrintTupleTo(const T& t, std::integral_constant, ::std::ostream* os) { PrintTupleTo(t, std::integral_constant(), os); GTEST_INTENTIONAL_CONST_COND_PUSH_() if (I > 1) { GTEST_INTENTIONAL_CONST_COND_POP_() *os << ", "; } UniversalPrinter::type>::Print( std::get(t), os); } template void PrintTo(const ::std::tuple& t, ::std::ostream* os) { *os << "("; PrintTupleTo(t, std::integral_constant(), os); *os << ")"; } // Overload for std::pair. template void PrintTo(const ::std::pair& value, ::std::ostream* os) { *os << '('; // We cannot use UniversalPrint(value.first, os) here, as T1 may be // a reference type. The same for printing value.second. UniversalPrinter::Print(value.first, os); *os << ", "; UniversalPrinter::Print(value.second, os); *os << ')'; } // Implements printing a non-reference type T by letting the compiler // pick the right overload of PrintTo() for T. template class UniversalPrinter { public: // MSVC warns about adding const to a function type, so we want to // disable the warning. GTEST_DISABLE_MSC_WARNINGS_PUSH_(4180) // Note: we deliberately don't call this PrintTo(), as that name // conflicts with ::testing::internal::PrintTo in the body of the // function. 
static void Print(const T& value, ::std::ostream* os) { // By default, ::testing::internal::PrintTo() is used for printing // the value. // // Thanks to Koenig look-up, if T is a class and has its own // PrintTo() function defined in its namespace, that function will // be visible here. Since it is more specific than the generic ones // in ::testing::internal, it will be picked by the compiler in the // following statement - exactly what we want. PrintTo(value, os); } GTEST_DISABLE_MSC_WARNINGS_POP_() }; #if GTEST_HAS_ABSL // Printer for absl::optional template class UniversalPrinter<::absl::optional> { public: static void Print(const ::absl::optional& value, ::std::ostream* os) { *os << '('; if (!value) { *os << "nullopt"; } else { UniversalPrint(*value, os); } *os << ')'; } }; // Printer for absl::variant template class UniversalPrinter<::absl::variant> { public: static void Print(const ::absl::variant& value, ::std::ostream* os) { *os << '('; absl::visit(Visitor{os}, value); *os << ')'; } private: struct Visitor { template void operator()(const U& u) const { *os << "'" << GetTypeName() << "' with value "; UniversalPrint(u, os); } ::std::ostream* os; }; }; #endif // GTEST_HAS_ABSL // UniversalPrintArray(begin, len, os) prints an array of 'len' // elements, starting at address 'begin'. template void UniversalPrintArray(const T* begin, size_t len, ::std::ostream* os) { if (len == 0) { *os << "{}"; } else { *os << "{ "; const size_t kThreshold = 18; const size_t kChunkSize = 8; // If the array has more than kThreshold elements, we'll have to // omit some details by printing only the first and the last // kChunkSize elements. if (len <= kThreshold) { PrintRawArrayTo(begin, len, os); } else { PrintRawArrayTo(begin, kChunkSize, os); *os << ", ..., "; PrintRawArrayTo(begin + len - kChunkSize, kChunkSize, os); } *os << " }"; } } // This overload prints a (const) char array compactly. 
GTEST_API_ void UniversalPrintArray( const char* begin, size_t len, ::std::ostream* os); // This overload prints a (const) wchar_t array compactly. GTEST_API_ void UniversalPrintArray( const wchar_t* begin, size_t len, ::std::ostream* os); // Implements printing an array type T[N]. template class UniversalPrinter { public: // Prints the given array, omitting some elements when there are too // many. static void Print(const T (&a)[N], ::std::ostream* os) { UniversalPrintArray(a, N, os); } }; // Implements printing a reference type T&. template class UniversalPrinter { public: // MSVC warns about adding const to a function type, so we want to // disable the warning. GTEST_DISABLE_MSC_WARNINGS_PUSH_(4180) static void Print(const T& value, ::std::ostream* os) { // Prints the address of the value. We use reinterpret_cast here // as static_cast doesn't compile when T is a function type. *os << "@" << reinterpret_cast(&value) << " "; // Then prints the value itself. UniversalPrint(value, os); } GTEST_DISABLE_MSC_WARNINGS_POP_() }; // Prints a value tersely: for a reference type, the referenced value // (but not the address) is printed; for a (const) char pointer, the // NUL-terminated string (but not the pointer) is printed. 
template class UniversalTersePrinter { public: static void Print(const T& value, ::std::ostream* os) { UniversalPrint(value, os); } }; template class UniversalTersePrinter { public: static void Print(const T& value, ::std::ostream* os) { UniversalPrint(value, os); } }; template class UniversalTersePrinter { public: static void Print(const T (&value)[N], ::std::ostream* os) { UniversalPrinter::Print(value, os); } }; template <> class UniversalTersePrinter { public: static void Print(const char* str, ::std::ostream* os) { if (str == nullptr) { *os << "NULL"; } else { UniversalPrint(std::string(str), os); } } }; template <> class UniversalTersePrinter { public: static void Print(char* str, ::std::ostream* os) { UniversalTersePrinter::Print(str, os); } }; #if GTEST_HAS_STD_WSTRING template <> class UniversalTersePrinter { public: static void Print(const wchar_t* str, ::std::ostream* os) { if (str == nullptr) { *os << "NULL"; } else { UniversalPrint(::std::wstring(str), os); } } }; #endif template <> class UniversalTersePrinter { public: static void Print(wchar_t* str, ::std::ostream* os) { UniversalTersePrinter::Print(str, os); } }; template void UniversalTersePrint(const T& value, ::std::ostream* os) { UniversalTersePrinter::Print(value, os); } // Prints a value using the type inferred by the compiler. The // difference between this and UniversalTersePrint() is that for a // (const) char pointer, this prints both the pointer and the // NUL-terminated string. template void UniversalPrint(const T& value, ::std::ostream* os) { // A workarond for the bug in VC++ 7.1 that prevents us from instantiating // UniversalPrinter with T directly. typedef T T1; UniversalPrinter::Print(value, os); } typedef ::std::vector< ::std::string> Strings; // Tersely prints the first N fields of a tuple to a string vector, // one element for each field. 
template void TersePrintPrefixToStrings(const Tuple&, std::integral_constant, Strings*) {} template void TersePrintPrefixToStrings(const Tuple& t, std::integral_constant, Strings* strings) { TersePrintPrefixToStrings(t, std::integral_constant(), strings); ::std::stringstream ss; UniversalTersePrint(std::get(t), &ss); strings->push_back(ss.str()); } // Prints the fields of a tuple tersely to a string vector, one // element for each field. See the comment before // UniversalTersePrint() for how we define "tersely". template Strings UniversalTersePrintTupleFieldsToStrings(const Tuple& value) { Strings result; TersePrintPrefixToStrings( value, std::integral_constant::value>(), &result); return result; } } // namespace internal #if GTEST_HAS_ABSL namespace internal2 { template void TypeWithoutFormatter::PrintValue( const T& value, ::std::ostream* os) { internal::PrintTo(absl::string_view(value), os); } } // namespace internal2 #endif template ::std::string PrintToString(const T& value) { ::std::stringstream ss; internal::UniversalTersePrinter::Print(value, &ss); return ss.str(); } } // namespace testing // Include any custom printer added by the local installation. // We must include this header at the end to make sure it can use the // declarations from this file. // Copyright 2015, Google Inc. // All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. 
nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // // This file provides an injection point for custom printers in a local // installation of gTest. // It will be included from gtest-printers.h and the overrides in this file // will be visible to everyone. // // Injection point for custom user configurations. See README for details // // ** Custom implementation starts here ** #ifndef GTEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_PRINTERS_H_ #define GTEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_PRINTERS_H_ #endif // GTEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_PRINTERS_H_ #endif // GTEST_INCLUDE_GTEST_GTEST_PRINTERS_H_ // MSVC warning C5046 is new as of VS2017 version 15.8. #if defined(_MSC_VER) && _MSC_VER >= 1915 #define GTEST_MAYBE_5046_ 5046 #else #define GTEST_MAYBE_5046_ #endif GTEST_DISABLE_MSC_WARNINGS_PUSH_( 4251 GTEST_MAYBE_5046_ /* class A needs to have dll-interface to be used by clients of class B */ /* Symbol involving type with internal linkage not defined */) namespace testing { // To implement a matcher Foo for type T, define: // 1. 
a class FooMatcherImpl that implements the // MatcherInterface interface, and // 2. a factory function that creates a Matcher object from a // FooMatcherImpl*. // // The two-level delegation design makes it possible to allow a user // to write "v" instead of "Eq(v)" where a Matcher is expected, which // is impossible if we pass matchers by pointers. It also eases // ownership management as Matcher objects can now be copied like // plain values. // MatchResultListener is an abstract class. Its << operator can be // used by a matcher to explain why a value matches or doesn't match. // class MatchResultListener { public: // Creates a listener object with the given underlying ostream. The // listener does not own the ostream, and does not dereference it // in the constructor or destructor. explicit MatchResultListener(::std::ostream* os) : stream_(os) {} virtual ~MatchResultListener() = 0; // Makes this class abstract. // Streams x to the underlying ostream; does nothing if the ostream // is NULL. template MatchResultListener& operator<<(const T& x) { if (stream_ != nullptr) *stream_ << x; return *this; } // Returns the underlying ostream. ::std::ostream* stream() { return stream_; } // Returns true if and only if the listener is interested in an explanation // of the match result. A matcher's MatchAndExplain() method can use // this information to avoid generating the explanation when no one // intends to hear it. bool IsInterested() const { return stream_ != nullptr; } private: ::std::ostream* const stream_; GTEST_DISALLOW_COPY_AND_ASSIGN_(MatchResultListener); }; inline MatchResultListener::~MatchResultListener() { } // An instance of a subclass of this knows how to describe itself as a // matcher. class MatcherDescriberInterface { public: virtual ~MatcherDescriberInterface() {} // Describes this matcher to an ostream. The function should print // a verb phrase that describes the property a value matching this // matcher should have. 
The subject of the verb phrase is the value // being matched. For example, the DescribeTo() method of the Gt(7) // matcher prints "is greater than 7". virtual void DescribeTo(::std::ostream* os) const = 0; // Describes the negation of this matcher to an ostream. For // example, if the description of this matcher is "is greater than // 7", the negated description could be "is not greater than 7". // You are not required to override this when implementing // MatcherInterface, but it is highly advised so that your matcher // can produce good error messages. virtual void DescribeNegationTo(::std::ostream* os) const { *os << "not ("; DescribeTo(os); *os << ")"; } }; // The implementation of a matcher. template class MatcherInterface : public MatcherDescriberInterface { public: // Returns true if and only if the matcher matches x; also explains the // match result to 'listener' if necessary (see the next paragraph), in // the form of a non-restrictive relative clause ("which ...", // "whose ...", etc) that describes x. For example, the // MatchAndExplain() method of the Pointee(...) matcher should // generate an explanation like "which points to ...". // // Implementations of MatchAndExplain() should add an explanation of // the match result *if and only if* they can provide additional // information that's not already present (or not obvious) in the // print-out of x and the matcher's description. Whether the match // succeeds is not a factor in deciding whether an explanation is // needed, as sometimes the caller needs to print a failure message // when the match succeeds (e.g. when the matcher is used inside // Not()). 
// // For example, a "has at least 10 elements" matcher should explain // what the actual element count is, regardless of the match result, // as it is useful information to the reader; on the other hand, an // "is empty" matcher probably only needs to explain what the actual // size is when the match fails, as it's redundant to say that the // size is 0 when the value is already known to be empty. // // You should override this method when defining a new matcher. // // It's the responsibility of the caller (Google Test) to guarantee // that 'listener' is not NULL. This helps to simplify a matcher's // implementation when it doesn't care about the performance, as it // can talk to 'listener' without checking its validity first. // However, in order to implement dummy listeners efficiently, // listener->stream() may be NULL. virtual bool MatchAndExplain(T x, MatchResultListener* listener) const = 0; // Inherits these methods from MatcherDescriberInterface: // virtual void DescribeTo(::std::ostream* os) const = 0; // virtual void DescribeNegationTo(::std::ostream* os) const; }; namespace internal { // Converts a MatcherInterface to a MatcherInterface. 
template class MatcherInterfaceAdapter : public MatcherInterface { public: explicit MatcherInterfaceAdapter(const MatcherInterface* impl) : impl_(impl) {} ~MatcherInterfaceAdapter() override { delete impl_; } void DescribeTo(::std::ostream* os) const override { impl_->DescribeTo(os); } void DescribeNegationTo(::std::ostream* os) const override { impl_->DescribeNegationTo(os); } bool MatchAndExplain(const T& x, MatchResultListener* listener) const override { return impl_->MatchAndExplain(x, listener); } private: const MatcherInterface* const impl_; GTEST_DISALLOW_COPY_AND_ASSIGN_(MatcherInterfaceAdapter); }; struct AnyEq { template bool operator()(const A& a, const B& b) const { return a == b; } }; struct AnyNe { template bool operator()(const A& a, const B& b) const { return a != b; } }; struct AnyLt { template bool operator()(const A& a, const B& b) const { return a < b; } }; struct AnyGt { template bool operator()(const A& a, const B& b) const { return a > b; } }; struct AnyLe { template bool operator()(const A& a, const B& b) const { return a <= b; } }; struct AnyGe { template bool operator()(const A& a, const B& b) const { return a >= b; } }; // A match result listener that ignores the explanation. class DummyMatchResultListener : public MatchResultListener { public: DummyMatchResultListener() : MatchResultListener(nullptr) {} private: GTEST_DISALLOW_COPY_AND_ASSIGN_(DummyMatchResultListener); }; // A match result listener that forwards the explanation to a given // ostream. The difference between this and MatchResultListener is // that the former is concrete. class StreamMatchResultListener : public MatchResultListener { public: explicit StreamMatchResultListener(::std::ostream* os) : MatchResultListener(os) {} private: GTEST_DISALLOW_COPY_AND_ASSIGN_(StreamMatchResultListener); }; // An internal class for implementing Matcher, which will derive // from it. We put functionalities common to all Matcher // specializations here to avoid code duplication. 
template class MatcherBase { public: // Returns true if and only if the matcher matches x; also explains the // match result to 'listener'. bool MatchAndExplain(const T& x, MatchResultListener* listener) const { return impl_->MatchAndExplain(x, listener); } // Returns true if and only if this matcher matches x. bool Matches(const T& x) const { DummyMatchResultListener dummy; return MatchAndExplain(x, &dummy); } // Describes this matcher to an ostream. void DescribeTo(::std::ostream* os) const { impl_->DescribeTo(os); } // Describes the negation of this matcher to an ostream. void DescribeNegationTo(::std::ostream* os) const { impl_->DescribeNegationTo(os); } // Explains why x matches, or doesn't match, the matcher. void ExplainMatchResultTo(const T& x, ::std::ostream* os) const { StreamMatchResultListener listener(os); MatchAndExplain(x, &listener); } // Returns the describer for this matcher object; retains ownership // of the describer, which is only guaranteed to be alive when // this matcher object is alive. const MatcherDescriberInterface* GetDescriber() const { return impl_.get(); } protected: MatcherBase() {} // Constructs a matcher from its implementation. explicit MatcherBase(const MatcherInterface* impl) : impl_(impl) {} template explicit MatcherBase( const MatcherInterface* impl, typename std::enable_if::value>::type* = nullptr) : impl_(new internal::MatcherInterfaceAdapter(impl)) {} MatcherBase(const MatcherBase&) = default; MatcherBase& operator=(const MatcherBase&) = default; MatcherBase(MatcherBase&&) = default; MatcherBase& operator=(MatcherBase&&) = default; virtual ~MatcherBase() {} private: std::shared_ptr> impl_; }; } // namespace internal // A Matcher is a copyable and IMMUTABLE (except by assignment) // object that can check whether a value of type T matches. The // implementation of Matcher is just a std::shared_ptr to const // MatcherInterface. Don't inherit from Matcher! 
template class Matcher : public internal::MatcherBase { public: // Constructs a null matcher. Needed for storing Matcher objects in STL // containers. A default-constructed matcher is not yet initialized. You // cannot use it until a valid value has been assigned to it. explicit Matcher() {} // NOLINT // Constructs a matcher from its implementation. explicit Matcher(const MatcherInterface* impl) : internal::MatcherBase(impl) {} template explicit Matcher( const MatcherInterface* impl, typename std::enable_if::value>::type* = nullptr) : internal::MatcherBase(impl) {} // Implicit constructor here allows people to write // EXPECT_CALL(foo, Bar(5)) instead of EXPECT_CALL(foo, Bar(Eq(5))) sometimes Matcher(T value); // NOLINT }; // The following two specializations allow the user to write str // instead of Eq(str) and "foo" instead of Eq("foo") when a std::string // matcher is expected. template <> class GTEST_API_ Matcher : public internal::MatcherBase { public: Matcher() {} explicit Matcher(const MatcherInterface* impl) : internal::MatcherBase(impl) {} // Allows the user to write str instead of Eq(str) sometimes, where // str is a std::string object. Matcher(const std::string& s); // NOLINT // Allows the user to write "foo" instead of Eq("foo") sometimes. Matcher(const char* s); // NOLINT }; template <> class GTEST_API_ Matcher : public internal::MatcherBase { public: Matcher() {} explicit Matcher(const MatcherInterface* impl) : internal::MatcherBase(impl) {} explicit Matcher(const MatcherInterface* impl) : internal::MatcherBase(impl) {} // Allows the user to write str instead of Eq(str) sometimes, where // str is a string object. Matcher(const std::string& s); // NOLINT // Allows the user to write "foo" instead of Eq("foo") sometimes. Matcher(const char* s); // NOLINT }; #if GTEST_HAS_ABSL // The following two specializations allow the user to write str // instead of Eq(str) and "foo" instead of Eq("foo") when a absl::string_view // matcher is expected. 
template <> class GTEST_API_ Matcher : public internal::MatcherBase { public: Matcher() {} explicit Matcher(const MatcherInterface* impl) : internal::MatcherBase(impl) {} // Allows the user to write str instead of Eq(str) sometimes, where // str is a std::string object. Matcher(const std::string& s); // NOLINT // Allows the user to write "foo" instead of Eq("foo") sometimes. Matcher(const char* s); // NOLINT // Allows the user to pass absl::string_views directly. Matcher(absl::string_view s); // NOLINT }; template <> class GTEST_API_ Matcher : public internal::MatcherBase { public: Matcher() {} explicit Matcher(const MatcherInterface* impl) : internal::MatcherBase(impl) {} explicit Matcher(const MatcherInterface* impl) : internal::MatcherBase(impl) {} // Allows the user to write str instead of Eq(str) sometimes, where // str is a std::string object. Matcher(const std::string& s); // NOLINT // Allows the user to write "foo" instead of Eq("foo") sometimes. Matcher(const char* s); // NOLINT // Allows the user to pass absl::string_views directly. Matcher(absl::string_view s); // NOLINT }; #endif // GTEST_HAS_ABSL // Prints a matcher in a human-readable format. template std::ostream& operator<<(std::ostream& os, const Matcher& matcher) { matcher.DescribeTo(&os); return os; } // The PolymorphicMatcher class template makes it easy to implement a // polymorphic matcher (i.e. a matcher that can match values of more // than one type, e.g. Eq(n) and NotNull()). // // To define a polymorphic matcher, a user should provide an Impl // class that has a DescribeTo() method and a DescribeNegationTo() // method, and define a member function (or member function template) // // bool MatchAndExplain(const Value& value, // MatchResultListener* listener) const; // // See the definition of NotNull() for a complete example. 
template class PolymorphicMatcher { public: explicit PolymorphicMatcher(const Impl& an_impl) : impl_(an_impl) {} // Returns a mutable reference to the underlying matcher // implementation object. Impl& mutable_impl() { return impl_; } // Returns an immutable reference to the underlying matcher // implementation object. const Impl& impl() const { return impl_; } template operator Matcher() const { return Matcher(new MonomorphicImpl(impl_)); } private: template class MonomorphicImpl : public MatcherInterface { public: explicit MonomorphicImpl(const Impl& impl) : impl_(impl) {} virtual void DescribeTo(::std::ostream* os) const { impl_.DescribeTo(os); } virtual void DescribeNegationTo(::std::ostream* os) const { impl_.DescribeNegationTo(os); } virtual bool MatchAndExplain(T x, MatchResultListener* listener) const { return impl_.MatchAndExplain(x, listener); } private: const Impl impl_; }; Impl impl_; }; // Creates a matcher from its implementation. // DEPRECATED: Especially in the generic code, prefer: // Matcher(new MyMatcherImpl(...)); // // MakeMatcher may create a Matcher that accepts its argument by value, which // leads to unnecessary copies & lack of support for non-copyable types. template inline Matcher MakeMatcher(const MatcherInterface* impl) { return Matcher(impl); } // Creates a polymorphic matcher from its implementation. This is // easier to use than the PolymorphicMatcher constructor as it // doesn't require you to explicitly write the template argument, e.g. // // MakePolymorphicMatcher(foo); // vs // PolymorphicMatcher(foo); template inline PolymorphicMatcher MakePolymorphicMatcher(const Impl& impl) { return PolymorphicMatcher(impl); } namespace internal { // Implements a matcher that compares a given value with a // pre-supplied value using one of the ==, <=, <, etc, operators. The // two values being compared don't have to have the same type. 
// // The matcher defined here is polymorphic (for example, Eq(5) can be // used to match an int, a short, a double, etc). Therefore we use // a template type conversion operator in the implementation. // // The following template definition assumes that the Rhs parameter is // a "bare" type (i.e. neither 'const T' nor 'T&'). template class ComparisonBase { public: explicit ComparisonBase(const Rhs& rhs) : rhs_(rhs) {} template operator Matcher() const { return Matcher(new Impl(rhs_)); } private: template static const T& Unwrap(const T& v) { return v; } template static const T& Unwrap(std::reference_wrapper v) { return v; } template class Impl : public MatcherInterface { public: explicit Impl(const Rhs& rhs) : rhs_(rhs) {} bool MatchAndExplain(Lhs lhs, MatchResultListener* /* listener */) const override { return Op()(lhs, Unwrap(rhs_)); } void DescribeTo(::std::ostream* os) const override { *os << D::Desc() << " "; UniversalPrint(Unwrap(rhs_), os); } void DescribeNegationTo(::std::ostream* os) const override { *os << D::NegatedDesc() << " "; UniversalPrint(Unwrap(rhs_), os); } private: Rhs rhs_; }; Rhs rhs_; }; template class EqMatcher : public ComparisonBase, Rhs, AnyEq> { public: explicit EqMatcher(const Rhs& rhs) : ComparisonBase, Rhs, AnyEq>(rhs) { } static const char* Desc() { return "is equal to"; } static const char* NegatedDesc() { return "isn't equal to"; } }; template class NeMatcher : public ComparisonBase, Rhs, AnyNe> { public: explicit NeMatcher(const Rhs& rhs) : ComparisonBase, Rhs, AnyNe>(rhs) { } static const char* Desc() { return "isn't equal to"; } static const char* NegatedDesc() { return "is equal to"; } }; template class LtMatcher : public ComparisonBase, Rhs, AnyLt> { public: explicit LtMatcher(const Rhs& rhs) : ComparisonBase, Rhs, AnyLt>(rhs) { } static const char* Desc() { return "is <"; } static const char* NegatedDesc() { return "isn't <"; } }; template class GtMatcher : public ComparisonBase, Rhs, AnyGt> { public: explicit 
GtMatcher(const Rhs& rhs) : ComparisonBase, Rhs, AnyGt>(rhs) { } static const char* Desc() { return "is >"; } static const char* NegatedDesc() { return "isn't >"; } }; template class LeMatcher : public ComparisonBase, Rhs, AnyLe> { public: explicit LeMatcher(const Rhs& rhs) : ComparisonBase, Rhs, AnyLe>(rhs) { } static const char* Desc() { return "is <="; } static const char* NegatedDesc() { return "isn't <="; } }; template class GeMatcher : public ComparisonBase, Rhs, AnyGe> { public: explicit GeMatcher(const Rhs& rhs) : ComparisonBase, Rhs, AnyGe>(rhs) { } static const char* Desc() { return "is >="; } static const char* NegatedDesc() { return "isn't >="; } }; // Implements polymorphic matchers MatchesRegex(regex) and // ContainsRegex(regex), which can be used as a Matcher as long as // T can be converted to a string. class MatchesRegexMatcher { public: MatchesRegexMatcher(const RE* regex, bool full_match) : regex_(regex), full_match_(full_match) {} #if GTEST_HAS_ABSL bool MatchAndExplain(const absl::string_view& s, MatchResultListener* listener) const { return MatchAndExplain(std::string(s), listener); } #endif // GTEST_HAS_ABSL // Accepts pointer types, particularly: // const char* // char* // const wchar_t* // wchar_t* template bool MatchAndExplain(CharType* s, MatchResultListener* listener) const { return s != nullptr && MatchAndExplain(std::string(s), listener); } // Matches anything that can convert to std::string. // // This is a template, not just a plain function with const std::string&, // because absl::string_view has some interfering non-explicit constructors. template bool MatchAndExplain(const MatcheeStringType& s, MatchResultListener* /* listener */) const { const std::string& s2(s); return full_match_ ? RE::FullMatch(s2, *regex_) : RE::PartialMatch(s2, *regex_); } void DescribeTo(::std::ostream* os) const { *os << (full_match_ ? 
"matches" : "contains") << " regular expression "; UniversalPrinter::Print(regex_->pattern(), os); } void DescribeNegationTo(::std::ostream* os) const { *os << "doesn't " << (full_match_ ? "match" : "contain") << " regular expression "; UniversalPrinter::Print(regex_->pattern(), os); } private: const std::shared_ptr regex_; const bool full_match_; }; } // namespace internal // Matches a string that fully matches regular expression 'regex'. // The matcher takes ownership of 'regex'. inline PolymorphicMatcher MatchesRegex( const internal::RE* regex) { return MakePolymorphicMatcher(internal::MatchesRegexMatcher(regex, true)); } inline PolymorphicMatcher MatchesRegex( const std::string& regex) { return MatchesRegex(new internal::RE(regex)); } // Matches a string that contains regular expression 'regex'. // The matcher takes ownership of 'regex'. inline PolymorphicMatcher ContainsRegex( const internal::RE* regex) { return MakePolymorphicMatcher(internal::MatchesRegexMatcher(regex, false)); } inline PolymorphicMatcher ContainsRegex( const std::string& regex) { return ContainsRegex(new internal::RE(regex)); } // Creates a polymorphic matcher that matches anything equal to x. // Note: if the parameter of Eq() were declared as const T&, Eq("foo") // wouldn't compile. template inline internal::EqMatcher Eq(T x) { return internal::EqMatcher(x); } // Constructs a Matcher from a 'value' of type T. The constructed // matcher matches any value that's equal to 'value'. template Matcher::Matcher(T value) { *this = Eq(value); } // Creates a monomorphic matcher that matches anything with type Lhs // and equal to rhs. A user may need to use this instead of Eq(...) // in order to resolve an overloading ambiguity. // // TypedEq(x) is just a convenient short-hand for Matcher(Eq(x)) // or Matcher(x), but more readable than the latter. // // We could define similar monomorphic matchers for other comparison // operations (e.g. 
TypedLt, TypedGe, and etc), but decided not to do // it yet as those are used much less than Eq() in practice. A user // can always write Matcher(Lt(5)) to be explicit about the type, // for example. template inline Matcher TypedEq(const Rhs& rhs) { return Eq(rhs); } // Creates a polymorphic matcher that matches anything >= x. template inline internal::GeMatcher Ge(Rhs x) { return internal::GeMatcher(x); } // Creates a polymorphic matcher that matches anything > x. template inline internal::GtMatcher Gt(Rhs x) { return internal::GtMatcher(x); } // Creates a polymorphic matcher that matches anything <= x. template inline internal::LeMatcher Le(Rhs x) { return internal::LeMatcher(x); } // Creates a polymorphic matcher that matches anything < x. template inline internal::LtMatcher Lt(Rhs x) { return internal::LtMatcher(x); } // Creates a polymorphic matcher that matches anything != x. template inline internal::NeMatcher Ne(Rhs x) { return internal::NeMatcher(x); } } // namespace testing GTEST_DISABLE_MSC_WARNINGS_POP_() // 4251 5046 #endif // GTEST_INCLUDE_GTEST_GTEST_MATCHERS_H_ #include #include namespace testing { namespace internal { GTEST_DECLARE_string_(internal_run_death_test); // Names of the flags (needed for parsing Google Test flags). const char kDeathTestStyleFlag[] = "death_test_style"; const char kDeathTestUseFork[] = "death_test_use_fork"; const char kInternalRunDeathTestFlag[] = "internal_run_death_test"; #if GTEST_HAS_DEATH_TEST GTEST_DISABLE_MSC_WARNINGS_PUSH_(4251 \ /* class A needs to have dll-interface to be used by clients of class B */) // DeathTest is a class that hides much of the complexity of the // GTEST_DEATH_TEST_ macro. It is abstract; its static Create method // returns a concrete class that depends on the prevailing death test // style, as defined by the --gtest_death_test_style and/or // --gtest_internal_run_death_test flags. 
// In describing the results of death tests, these terms are used with // the corresponding definitions: // // exit status: The integer exit information in the format specified // by wait(2) // exit code: The integer code passed to exit(3), _exit(2), or // returned from main() class GTEST_API_ DeathTest { public: // Create returns false if there was an error determining the // appropriate action to take for the current death test; for example, // if the gtest_death_test_style flag is set to an invalid value. // The LastMessage method will return a more detailed message in that // case. Otherwise, the DeathTest pointer pointed to by the "test" // argument is set. If the death test should be skipped, the pointer // is set to NULL; otherwise, it is set to the address of a new concrete // DeathTest object that controls the execution of the current test. static bool Create(const char* statement, Matcher matcher, const char* file, int line, DeathTest** test); DeathTest(); virtual ~DeathTest() { } // A helper class that aborts a death test when it's deleted. class ReturnSentinel { public: explicit ReturnSentinel(DeathTest* test) : test_(test) { } ~ReturnSentinel() { test_->Abort(TEST_ENCOUNTERED_RETURN_STATEMENT); } private: DeathTest* const test_; GTEST_DISALLOW_COPY_AND_ASSIGN_(ReturnSentinel); } GTEST_ATTRIBUTE_UNUSED_; // An enumeration of possible roles that may be taken when a death // test is encountered. EXECUTE means that the death test logic should // be executed immediately. OVERSEE means that the program should prepare // the appropriate environment for a child process to execute the death // test, then wait for it to complete. enum TestRole { OVERSEE_TEST, EXECUTE_TEST }; // An enumeration of the three reasons that a test might be aborted. enum AbortReason { TEST_ENCOUNTERED_RETURN_STATEMENT, TEST_THREW_EXCEPTION, TEST_DID_NOT_DIE }; // Assumes one of the above roles. 
virtual TestRole AssumeRole() = 0; // Waits for the death test to finish and returns its status. virtual int Wait() = 0; // Returns true if the death test passed; that is, the test process // exited during the test, its exit status matches a user-supplied // predicate, and its stderr output matches a user-supplied regular // expression. // The user-supplied predicate may be a macro expression rather // than a function pointer or functor, or else Wait and Passed could // be combined. virtual bool Passed(bool exit_status_ok) = 0; // Signals that the death test did not die as expected. virtual void Abort(AbortReason reason) = 0; // Returns a human-readable outcome message regarding the outcome of // the last death test. static const char* LastMessage(); static void set_last_death_test_message(const std::string& message); private: // A string containing a description of the outcome of the last death test. static std::string last_death_test_message_; GTEST_DISALLOW_COPY_AND_ASSIGN_(DeathTest); }; GTEST_DISABLE_MSC_WARNINGS_POP_() // 4251 // Factory interface for death tests. May be mocked out for testing. class DeathTestFactory { public: virtual ~DeathTestFactory() { } virtual bool Create(const char* statement, Matcher matcher, const char* file, int line, DeathTest** test) = 0; }; // A concrete DeathTestFactory implementation for normal use. class DefaultDeathTestFactory : public DeathTestFactory { public: bool Create(const char* statement, Matcher matcher, const char* file, int line, DeathTest** test) override; }; // Returns true if exit_status describes a process that was terminated // by a signal, or exited normally with a nonzero exit code. GTEST_API_ bool ExitedUnsuccessfully(int exit_status); // A string passed to EXPECT_DEATH (etc.) is caught by one of these overloads // and interpreted as a regex (rather than an Eq matcher) for legacy // compatibility. 
inline Matcher<const ::std::string&> MakeDeathTestMatcher(
    ::testing::internal::RE regex) {
  return ContainsRegex(regex.pattern());
}
inline Matcher<const ::std::string&> MakeDeathTestMatcher(const char* regex) {
  return ContainsRegex(regex);
}
inline Matcher<const ::std::string&> MakeDeathTestMatcher(
    const ::std::string& regex) {
  return ContainsRegex(regex);
}

// If a Matcher<const ::std::string&> is passed to EXPECT_DEATH (etc.), it's
// used directly.
inline Matcher<const ::std::string&> MakeDeathTestMatcher(
    Matcher<const ::std::string&> matcher) {
  return matcher;
}

// Traps C++ exceptions escaping statement and reports them as test
// failures. Note that trapping SEH exceptions is not implemented here.
# if GTEST_HAS_EXCEPTIONS
#  define GTEST_EXECUTE_DEATH_TEST_STATEMENT_(statement, death_test) \
  try { \
    GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \
  } catch (const ::std::exception& gtest_exception) { \
    fprintf(\
        stderr, \
        "\n%s: Caught std::exception-derived exception escaping the " \
        "death test statement. Exception message: %s\n", \
        ::testing::internal::FormatFileLocation(__FILE__, __LINE__).c_str(), \
        gtest_exception.what()); \
    fflush(stderr); \
    death_test->Abort(::testing::internal::DeathTest::TEST_THREW_EXCEPTION); \
  } catch (...) { \
    death_test->Abort(::testing::internal::DeathTest::TEST_THREW_EXCEPTION); \
  }

# else
#  define GTEST_EXECUTE_DEATH_TEST_STATEMENT_(statement, death_test) \
  GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement)

# endif

// This macro is for implementing ASSERT_DEATH*, EXPECT_DEATH*,
// ASSERT_EXIT*, and EXPECT_EXIT*.
// Expands to an if/else whose trailing 'fail(...)' accepts a streamed
// message.  Create() fills in gtest_dt; a null gtest_dt skips the death
// test, and a Create() failure or a failed Passed() check jumps to the
// label that reports DeathTest::LastMessage() through 'fail'.
#define GTEST_DEATH_TEST_(statement, predicate, regex_or_matcher, fail) \
  GTEST_AMBIGUOUS_ELSE_BLOCKER_ \
  if (::testing::internal::AlwaysTrue()) { \
    ::testing::internal::DeathTest* gtest_dt; \
    if (!::testing::internal::DeathTest::Create( \
            #statement, \
            ::testing::internal::MakeDeathTestMatcher(regex_or_matcher), \
            __FILE__, __LINE__, &gtest_dt)) { \
      goto GTEST_CONCAT_TOKEN_(gtest_label_, __LINE__); \
    } \
    if (gtest_dt != nullptr) { \
      std::unique_ptr< ::testing::internal::DeathTest> gtest_dt_ptr(gtest_dt); \
      switch (gtest_dt->AssumeRole()) { \
        case ::testing::internal::DeathTest::OVERSEE_TEST: \
          if (!gtest_dt->Passed(predicate(gtest_dt->Wait()))) { \
            goto GTEST_CONCAT_TOKEN_(gtest_label_, __LINE__); \
          } \
          break; \
        case ::testing::internal::DeathTest::EXECUTE_TEST: { \
          ::testing::internal::DeathTest::ReturnSentinel gtest_sentinel( \
              gtest_dt); \
          GTEST_EXECUTE_DEATH_TEST_STATEMENT_(statement, gtest_dt); \
          gtest_dt->Abort(::testing::internal::DeathTest::TEST_DID_NOT_DIE); \
          break; \
        } \
        default: \
          break; \
      } \
    } \
  } else \
    GTEST_CONCAT_TOKEN_(gtest_label_, __LINE__) \
        : fail(::testing::internal::DeathTest::LastMessage())
// The symbol "fail" here expands to something into which a message
// can be streamed.

// This macro is for implementing ASSERT/EXPECT_DEBUG_DEATH when compiled in
// NDEBUG mode. In this case we need the statements to be executed and the macro
// must accept a streamed message even though the message is never printed.
// The regex object is not evaluated, but it is used to prevent "unused"
// warnings and to avoid an expression that doesn't compile in debug mode.
#define GTEST_EXECUTE_STATEMENT_(statement, regex_or_matcher) \ GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ if (::testing::internal::AlwaysTrue()) { \ GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ } else if (!::testing::internal::AlwaysTrue()) { \ ::testing::internal::MakeDeathTestMatcher(regex_or_matcher); \ } else \ ::testing::Message() // A class representing the parsed contents of the // --gtest_internal_run_death_test flag, as it existed when // RUN_ALL_TESTS was called. class InternalRunDeathTestFlag { public: InternalRunDeathTestFlag(const std::string& a_file, int a_line, int an_index, int a_write_fd) : file_(a_file), line_(a_line), index_(an_index), write_fd_(a_write_fd) {} ~InternalRunDeathTestFlag() { if (write_fd_ >= 0) posix::Close(write_fd_); } const std::string& file() const { return file_; } int line() const { return line_; } int index() const { return index_; } int write_fd() const { return write_fd_; } private: std::string file_; int line_; int index_; int write_fd_; GTEST_DISALLOW_COPY_AND_ASSIGN_(InternalRunDeathTestFlag); }; // Returns a newly created InternalRunDeathTestFlag object with fields // initialized from the GTEST_FLAG(internal_run_death_test) flag if // the flag is specified; otherwise returns NULL. InternalRunDeathTestFlag* ParseInternalRunDeathTestFlag(); #endif // GTEST_HAS_DEATH_TEST } // namespace internal } // namespace testing #endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_DEATH_TEST_INTERNAL_H_ namespace testing { // This flag controls the style of death tests. Valid values are "threadsafe", // meaning that the death test child process will re-execute the test binary // from the start, running only a single death test, or "fast", // meaning that the child process will execute the test logic immediately // after forking. 
GTEST_DECLARE_string_(death_test_style); #if GTEST_HAS_DEATH_TEST namespace internal { // Returns a Boolean value indicating whether the caller is currently // executing in the context of the death test child process. Tools such as // Valgrind heap checkers may need this to modify their behavior in death // tests. IMPORTANT: This is an internal utility. Using it may break the // implementation of death tests. User code MUST NOT use it. GTEST_API_ bool InDeathTestChild(); } // namespace internal // The following macros are useful for writing death tests. // Here's what happens when an ASSERT_DEATH* or EXPECT_DEATH* is // executed: // // 1. It generates a warning if there is more than one active // thread. This is because it's safe to fork() or clone() only // when there is a single thread. // // 2. The parent process clone()s a sub-process and runs the death // test in it; the sub-process exits with code 0 at the end of the // death test, if it hasn't exited already. // // 3. The parent process waits for the sub-process to terminate. // // 4. The parent process checks the exit code and error message of // the sub-process. // // Examples: // // ASSERT_DEATH(server.SendMessage(56, "Hello"), "Invalid port number"); // for (int i = 0; i < 5; i++) { // EXPECT_DEATH(server.ProcessRequest(i), // "Invalid request .* in ProcessRequest()") // << "Failed to die on request " << i; // } // // ASSERT_EXIT(server.ExitNow(), ::testing::ExitedWithCode(0), "Exiting"); // // bool KilledBySIGHUP(int exit_code) { // return WIFSIGNALED(exit_code) && WTERMSIG(exit_code) == SIGHUP; // } // // ASSERT_EXIT(client.HangUpServer(), KilledBySIGHUP, "Hanging up!"); // // On the regular expressions used in death tests: // // GOOGLETEST_CM0005 DO NOT DELETE // On POSIX-compliant systems (*nix), we use the library, // which uses the POSIX extended regex syntax. // // On other platforms (e.g. Windows or Mac), we only support a simple regex // syntax implemented as part of Google Test. 
This limited // implementation should be enough most of the time when writing // death tests; though it lacks many features you can find in PCRE // or POSIX extended regex syntax. For example, we don't support // union ("x|y"), grouping ("(xy)"), brackets ("[xy]"), and // repetition count ("x{5,7}"), among others. // // Below is the syntax that we do support. We chose it to be a // subset of both PCRE and POSIX extended regex, so it's easy to // learn wherever you come from. In the following: 'A' denotes a // literal character, period (.), or a single \\ escape sequence; // 'x' and 'y' denote regular expressions; 'm' and 'n' are for // natural numbers. // // c matches any literal character c // \\d matches any decimal digit // \\D matches any character that's not a decimal digit // \\f matches \f // \\n matches \n // \\r matches \r // \\s matches any ASCII whitespace, including \n // \\S matches any character that's not a whitespace // \\t matches \t // \\v matches \v // \\w matches any letter, _, or decimal digit // \\W matches any character that \\w doesn't match // \\c matches any literal character c, which must be a punctuation // . matches any single character except \n // A? matches 0 or 1 occurrences of A // A* matches 0 or many occurrences of A // A+ matches 1 or many occurrences of A // ^ matches the beginning of a string (not that of each line) // $ matches the end of a string (not that of each line) // xy matches x followed by y // // If you accidentally use PCRE or POSIX extended regex features // not implemented by us, you will get a run-time failure. In that // case, please try to rewrite your regular expression within the // above syntax. // // This implementation is *not* meant to be as highly tuned or robust // as a compiled regex library, but should perform well enough for a // death test, which already incurs significant overhead by launching // a child process. 
// // Known caveats: // // A "threadsafe" style death test obtains the path to the test // program from argv[0] and re-executes it in the sub-process. For // simplicity, the current implementation doesn't search the PATH // when launching the sub-process. This means that the user must // invoke the test program via a path that contains at least one // path separator (e.g. path/to/foo_test and // /absolute/path/to/bar_test are fine, but foo_test is not). This // is rarely a problem as people usually don't put the test binary // directory in PATH. // // Asserts that a given statement causes the program to exit, with an // integer exit status that satisfies predicate, and emitting error output // that matches regex. # define ASSERT_EXIT(statement, predicate, regex) \ GTEST_DEATH_TEST_(statement, predicate, regex, GTEST_FATAL_FAILURE_) // Like ASSERT_EXIT, but continues on to successive tests in the // test suite, if any: # define EXPECT_EXIT(statement, predicate, regex) \ GTEST_DEATH_TEST_(statement, predicate, regex, GTEST_NONFATAL_FAILURE_) // Asserts that a given statement causes the program to exit, either by // explicitly exiting with a nonzero exit code or being killed by a // signal, and emitting error output that matches regex. # define ASSERT_DEATH(statement, regex) \ ASSERT_EXIT(statement, ::testing::internal::ExitedUnsuccessfully, regex) // Like ASSERT_DEATH, but continues on to successive tests in the // test suite, if any: # define EXPECT_DEATH(statement, regex) \ EXPECT_EXIT(statement, ::testing::internal::ExitedUnsuccessfully, regex) // Two predicate classes that can be used in {ASSERT,EXPECT}_EXIT*: // Tests that an exit code describes a normal exit with a given exit code. class GTEST_API_ ExitedWithCode { public: explicit ExitedWithCode(int exit_code); bool operator()(int exit_status) const; private: // No implementation - assignment is unsupported. 
void operator=(const ExitedWithCode& other); const int exit_code_; }; # if !GTEST_OS_WINDOWS && !GTEST_OS_FUCHSIA // Tests that an exit code describes an exit due to termination by a // given signal. // GOOGLETEST_CM0006 DO NOT DELETE class GTEST_API_ KilledBySignal { public: explicit KilledBySignal(int signum); bool operator()(int exit_status) const; private: const int signum_; }; # endif // !GTEST_OS_WINDOWS // EXPECT_DEBUG_DEATH asserts that the given statements die in debug mode. // The death testing framework causes this to have interesting semantics, // since the sideeffects of the call are only visible in opt mode, and not // in debug mode. // // In practice, this can be used to test functions that utilize the // LOG(DFATAL) macro using the following style: // // int DieInDebugOr12(int* sideeffect) { // if (sideeffect) { // *sideeffect = 12; // } // LOG(DFATAL) << "death"; // return 12; // } // // TEST(TestSuite, TestDieOr12WorksInDgbAndOpt) { // int sideeffect = 0; // // Only asserts in dbg. // EXPECT_DEBUG_DEATH(DieInDebugOr12(&sideeffect), "death"); // // #ifdef NDEBUG // // opt-mode has sideeffect visible. // EXPECT_EQ(12, sideeffect); // #else // // dbg-mode no visible sideeffect. // EXPECT_EQ(0, sideeffect); // #endif // } // // This will assert that DieInDebugReturn12InOpt() crashes in debug // mode, usually due to a DCHECK or LOG(DFATAL), but returns the // appropriate fallback value (12 in this case) in opt mode. If you // need to test that a function has appropriate side-effects in opt // mode, include assertions against the side-effects. A general // pattern for this is: // // EXPECT_DEBUG_DEATH({ // // Side-effects here will have an effect after this statement in // // opt mode, but none in debug mode. 
// EXPECT_EQ(12, DieInDebugOr12(&sideeffect)); // }, "death"); // # ifdef NDEBUG # define EXPECT_DEBUG_DEATH(statement, regex) \ GTEST_EXECUTE_STATEMENT_(statement, regex) # define ASSERT_DEBUG_DEATH(statement, regex) \ GTEST_EXECUTE_STATEMENT_(statement, regex) # else # define EXPECT_DEBUG_DEATH(statement, regex) \ EXPECT_DEATH(statement, regex) # define ASSERT_DEBUG_DEATH(statement, regex) \ ASSERT_DEATH(statement, regex) # endif // NDEBUG for EXPECT_DEBUG_DEATH #endif // GTEST_HAS_DEATH_TEST // This macro is used for implementing macros such as // EXPECT_DEATH_IF_SUPPORTED and ASSERT_DEATH_IF_SUPPORTED on systems where // death tests are not supported. Those macros must compile on such systems // if and only if EXPECT_DEATH and ASSERT_DEATH compile with the same parameters // on systems that support death tests. This allows one to write such a macro on // a system that does not support death tests and be sure that it will compile // on a death-test supporting system. It is exposed publicly so that systems // that have death-tests with stricter requirements than GTEST_HAS_DEATH_TEST // can write their own equivalent of EXPECT_DEATH_IF_SUPPORTED and // ASSERT_DEATH_IF_SUPPORTED. // // Parameters: // statement - A statement that a macro such as EXPECT_DEATH would test // for program termination. This macro has to make sure this // statement is compiled but not executed, to ensure that // EXPECT_DEATH_IF_SUPPORTED compiles with a certain // parameter if and only if EXPECT_DEATH compiles with it. // regex - A regex that a macro such as EXPECT_DEATH would use to test // the output of statement. This parameter has to be // compiled but not evaluated by this macro, to ensure that // this macro only accepts expressions that a macro such as // EXPECT_DEATH would accept. // terminator - Must be an empty statement for EXPECT_DEATH_IF_SUPPORTED // and a return statement for ASSERT_DEATH_IF_SUPPORTED. 
// This ensures that ASSERT_DEATH_IF_SUPPORTED will not // compile inside functions where ASSERT_DEATH doesn't // compile. // // The branch that has an always false condition is used to ensure that // statement and regex are compiled (and thus syntactically correct) but // never executed. The unreachable code macro protects the terminator // statement from generating an 'unreachable code' warning in case // statement unconditionally returns or throws. The Message constructor at // the end allows the syntax of streaming additional messages into the // macro, for compilational compatibility with EXPECT_DEATH/ASSERT_DEATH. # define GTEST_UNSUPPORTED_DEATH_TEST(statement, regex, terminator) \ GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ if (::testing::internal::AlwaysTrue()) { \ GTEST_LOG_(WARNING) \ << "Death tests are not supported on this platform.\n" \ << "Statement '" #statement "' cannot be verified."; \ } else if (::testing::internal::AlwaysFalse()) { \ ::testing::internal::RE::PartialMatch(".*", (regex)); \ GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ terminator; \ } else \ ::testing::Message() // EXPECT_DEATH_IF_SUPPORTED(statement, regex) and // ASSERT_DEATH_IF_SUPPORTED(statement, regex) expand to real death tests if // death tests are supported; otherwise they just issue a warning. This is // useful when you are combining death test assertions with normal test // assertions in one test. #if GTEST_HAS_DEATH_TEST # define EXPECT_DEATH_IF_SUPPORTED(statement, regex) \ EXPECT_DEATH(statement, regex) # define ASSERT_DEATH_IF_SUPPORTED(statement, regex) \ ASSERT_DEATH(statement, regex) #else # define EXPECT_DEATH_IF_SUPPORTED(statement, regex) \ GTEST_UNSUPPORTED_DEATH_TEST(statement, regex, ) # define ASSERT_DEATH_IF_SUPPORTED(statement, regex) \ GTEST_UNSUPPORTED_DEATH_TEST(statement, regex, return) #endif } // namespace testing #endif // GTEST_INCLUDE_GTEST_GTEST_DEATH_TEST_H_ // Copyright 2008, Google Inc. // All rights reserved. 
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // // Macros and functions for implementing parameterized tests // in Google C++ Testing and Mocking Framework (Google Test) // // This file is generated by a SCRIPT. DO NOT EDIT BY HAND! // // GOOGLETEST_CM0001 DO NOT DELETE #ifndef GTEST_INCLUDE_GTEST_GTEST_PARAM_TEST_H_ #define GTEST_INCLUDE_GTEST_GTEST_PARAM_TEST_H_ // Value-parameterized tests allow you to test your code with different // parameters without writing multiple copies of the same test. 
//
// Here is how you use value-parameterized tests:

#if 0

// To write value-parameterized tests, first you should define a fixture
// class. It is usually derived from testing::TestWithParam<T> (see below for
// another inheritance scheme that's sometimes useful in more complicated
// class hierarchies), where T is the type of your parameter values.
// TestWithParam<T> is itself derived from testing::Test. T can be any
// copyable type. If it's a raw pointer, you are responsible for managing the
// lifespan of the pointed values.

class FooTest : public ::testing::TestWithParam<const char*> {
  // You can implement all the usual class fixture members here.
};

// Then, use the TEST_P macro to define as many parameterized tests
// for this fixture as you want. The _P suffix is for "parameterized"
// or "pattern", whichever you prefer to think.

TEST_P(FooTest, DoesBlah) {
  // Inside a test, access the test parameter with the GetParam() method
  // of the TestWithParam<T> class:
  EXPECT_TRUE(foo.Blah(GetParam()));
  ...
}

TEST_P(FooTest, HasBlahBlah) {
  ...
}

// Finally, you can use INSTANTIATE_TEST_SUITE_P to instantiate the test
// case with any set of parameters you want. Google Test defines a number
// of functions for generating test parameters. They return what we call
// (surprise!) parameter generators. Here is a summary of them, which
// are all in the testing namespace:
//
//
//  Range(begin, end [, step]) - Yields values {begin, begin+step,
//                               begin+step+step, ...}. The values do not
//                               include end. step defaults to 1.
//  Values(v1, v2, ..., vN)    - Yields values {v1, v2, ..., vN}.
//  ValuesIn(container)        - Yields values from a C-style array, an STL
//  ValuesIn(begin,end)          container, or an iterator range [begin, end).
//  Bool()                     - Yields sequence {false, true}.
//  Combine(g1, g2, ..., gN)   - Yields all combinations (the Cartesian product
//                               for the math savvy) of the values generated
//                               by the N generators.
// // For more details, see comments at the definitions of these functions below // in this file. // // The following statement will instantiate tests from the FooTest test suite // each with parameter values "meeny", "miny", and "moe". INSTANTIATE_TEST_SUITE_P(InstantiationName, FooTest, Values("meeny", "miny", "moe")); // To distinguish different instances of the pattern, (yes, you // can instantiate it more than once) the first argument to the // INSTANTIATE_TEST_SUITE_P macro is a prefix that will be added to the // actual test suite name. Remember to pick unique prefixes for different // instantiations. The tests from the instantiation above will have // these names: // // * InstantiationName/FooTest.DoesBlah/0 for "meeny" // * InstantiationName/FooTest.DoesBlah/1 for "miny" // * InstantiationName/FooTest.DoesBlah/2 for "moe" // * InstantiationName/FooTest.HasBlahBlah/0 for "meeny" // * InstantiationName/FooTest.HasBlahBlah/1 for "miny" // * InstantiationName/FooTest.HasBlahBlah/2 for "moe" // // You can use these names in --gtest_filter. // // This statement will instantiate all tests from FooTest again, each // with parameter values "cat" and "dog": const char* pets[] = {"cat", "dog"}; INSTANTIATE_TEST_SUITE_P(AnotherInstantiationName, FooTest, ValuesIn(pets)); // The tests from the instantiation above will have these names: // // * AnotherInstantiationName/FooTest.DoesBlah/0 for "cat" // * AnotherInstantiationName/FooTest.DoesBlah/1 for "dog" // * AnotherInstantiationName/FooTest.HasBlahBlah/0 for "cat" // * AnotherInstantiationName/FooTest.HasBlahBlah/1 for "dog" // // Please note that INSTANTIATE_TEST_SUITE_P will instantiate all tests // in the given test suite, whether their definitions come before or // AFTER the INSTANTIATE_TEST_SUITE_P statement. // // Please also note that generator expressions (including parameters to the // generators) are evaluated in InitGoogleTest(), after main() has started. 
// This allows the user on one hand, to adjust generator parameters in order
// to dynamically determine a set of tests to run and on the other hand,
// give the user a chance to inspect the generated tests with Google Test
// reflection API before RUN_ALL_TESTS() is executed.
//
// You can see samples/sample7_unittest.cc and samples/sample8_unittest.cc
// for more examples.
//
// In the future, we plan to publish the API for defining new parameter
// generators. But for now this interface remains part of the internal
// implementation and is subject to change.
//
//
// A parameterized test fixture must be derived from testing::Test and from
// testing::WithParamInterface<T>, where T is the type of the parameter
// values. Inheriting from TestWithParam<T> satisfies that requirement because
// TestWithParam<T> inherits from both Test and WithParamInterface<T>. In more
// complicated hierarchies, however, it is occasionally useful to inherit
// separately from Test and WithParamInterface<T>. For example:

class BaseTest : public ::testing::Test {
  // You can inherit all the usual members for a non-parameterized test
  // fixture here.
};

class DerivedTest : public BaseTest, public ::testing::WithParamInterface<int> {
  // The usual test fixture members go here too.
};

TEST_F(BaseTest, HasFoo) {
  // This is an ordinary non-parameterized test.
}

TEST_P(DerivedTest, DoesBlah) {
  // GetParam works just the same here as if you inherit from TestWithParam.
  EXPECT_TRUE(foo.Blah(GetParam()));
}

#endif  // 0

#include <iterator>
#include <utility>

// Copyright 2008 Google Inc.
// All Rights Reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // Type and function utilities for implementing parameterized tests. // GOOGLETEST_CM0001 DO NOT DELETE #ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PARAM_UTIL_H_ #define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PARAM_UTIL_H_ #include #include #include #include #include #include #include #include namespace testing { // Input to a parameterized test name generator, describing a test parameter. // Consists of the parameter value and the integer parameter index. template struct TestParamInfo { TestParamInfo(const ParamType& a_param, size_t an_index) : param(a_param), index(an_index) {} ParamType param; size_t index; }; // A builtin parameterized test name generator which returns the result of // testing::PrintToString. 
struct PrintToStringParamName { template std::string operator()(const TestParamInfo& info) const { return PrintToString(info.param); } }; namespace internal { // INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. // Utility Functions // Outputs a message explaining invalid registration of different // fixture class for the same test suite. This may happen when // TEST_P macro is used to define two tests with the same name // but in different namespaces. GTEST_API_ void ReportInvalidTestSuiteType(const char* test_suite_name, CodeLocation code_location); template class ParamGeneratorInterface; template class ParamGenerator; // Interface for iterating over elements provided by an implementation // of ParamGeneratorInterface. template class ParamIteratorInterface { public: virtual ~ParamIteratorInterface() {} // A pointer to the base generator instance. // Used only for the purposes of iterator comparison // to make sure that two iterators belong to the same generator. virtual const ParamGeneratorInterface* BaseGenerator() const = 0; // Advances iterator to point to the next element // provided by the generator. The caller is responsible // for not calling Advance() on an iterator equal to // BaseGenerator()->End(). virtual void Advance() = 0; // Clones the iterator object. Used for implementing copy semantics // of ParamIterator. virtual ParamIteratorInterface* Clone() const = 0; // Dereferences the current iterator and provides (read-only) access // to the pointed value. It is the caller's responsibility not to call // Current() on an iterator equal to BaseGenerator()->End(). // Used for implementing ParamGenerator::operator*(). virtual const T* Current() const = 0; // Determines whether the given iterator and other point to the same // element in the sequence generated by the generator. // Used for implementing ParamGenerator::operator==(). 
virtual bool Equals(const ParamIteratorInterface& other) const = 0; }; // Class iterating over elements provided by an implementation of // ParamGeneratorInterface. It wraps ParamIteratorInterface // and implements the const forward iterator concept. template class ParamIterator { public: typedef T value_type; typedef const T& reference; typedef ptrdiff_t difference_type; // ParamIterator assumes ownership of the impl_ pointer. ParamIterator(const ParamIterator& other) : impl_(other.impl_->Clone()) {} ParamIterator& operator=(const ParamIterator& other) { if (this != &other) impl_.reset(other.impl_->Clone()); return *this; } const T& operator*() const { return *impl_->Current(); } const T* operator->() const { return impl_->Current(); } // Prefix version of operator++. ParamIterator& operator++() { impl_->Advance(); return *this; } // Postfix version of operator++. ParamIterator operator++(int /*unused*/) { ParamIteratorInterface* clone = impl_->Clone(); impl_->Advance(); return ParamIterator(clone); } bool operator==(const ParamIterator& other) const { return impl_.get() == other.impl_.get() || impl_->Equals(*other.impl_); } bool operator!=(const ParamIterator& other) const { return !(*this == other); } private: friend class ParamGenerator; explicit ParamIterator(ParamIteratorInterface* impl) : impl_(impl) {} std::unique_ptr > impl_; }; // ParamGeneratorInterface is the binary interface to access generators // defined in other translation units. template class ParamGeneratorInterface { public: typedef T ParamType; virtual ~ParamGeneratorInterface() {} // Generator interface definition virtual ParamIteratorInterface* Begin() const = 0; virtual ParamIteratorInterface* End() const = 0; }; // Wraps ParamGeneratorInterface and provides general generator syntax // compatible with the STL Container concept. // This class implements copy initialization semantics and the contained // ParamGeneratorInterface instance is shared among all copies // of the original object. 
This is possible because that instance is immutable. template class ParamGenerator { public: typedef ParamIterator iterator; explicit ParamGenerator(ParamGeneratorInterface* impl) : impl_(impl) {} ParamGenerator(const ParamGenerator& other) : impl_(other.impl_) {} ParamGenerator& operator=(const ParamGenerator& other) { impl_ = other.impl_; return *this; } iterator begin() const { return iterator(impl_->Begin()); } iterator end() const { return iterator(impl_->End()); } private: std::shared_ptr > impl_; }; // Generates values from a range of two comparable values. Can be used to // generate sequences of user-defined types that implement operator+() and // operator<(). // This class is used in the Range() function. template class RangeGenerator : public ParamGeneratorInterface { public: RangeGenerator(T begin, T end, IncrementT step) : begin_(begin), end_(end), step_(step), end_index_(CalculateEndIndex(begin, end, step)) {} ~RangeGenerator() override {} ParamIteratorInterface* Begin() const override { return new Iterator(this, begin_, 0, step_); } ParamIteratorInterface* End() const override { return new Iterator(this, end_, end_index_, step_); } private: class Iterator : public ParamIteratorInterface { public: Iterator(const ParamGeneratorInterface* base, T value, int index, IncrementT step) : base_(base), value_(value), index_(index), step_(step) {} ~Iterator() override {} const ParamGeneratorInterface* BaseGenerator() const override { return base_; } void Advance() override { value_ = static_cast(value_ + step_); index_++; } ParamIteratorInterface* Clone() const override { return new Iterator(*this); } const T* Current() const override { return &value_; } bool Equals(const ParamIteratorInterface& other) const override { // Having the same base generator guarantees that the other // iterator is of the same type and we can downcast. 
GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) << "The program attempted to compare iterators " << "from different generators." << std::endl; const int other_index = CheckedDowncastToActualType(&other)->index_; return index_ == other_index; } private: Iterator(const Iterator& other) : ParamIteratorInterface(), base_(other.base_), value_(other.value_), index_(other.index_), step_(other.step_) {} // No implementation - assignment is unsupported. void operator=(const Iterator& other); const ParamGeneratorInterface* const base_; T value_; int index_; const IncrementT step_; }; // class RangeGenerator::Iterator static int CalculateEndIndex(const T& begin, const T& end, const IncrementT& step) { int end_index = 0; for (T i = begin; i < end; i = static_cast(i + step)) end_index++; return end_index; } // No implementation - assignment is unsupported. void operator=(const RangeGenerator& other); const T begin_; const T end_; const IncrementT step_; // The index for the end() iterator. All the elements in the generated // sequence are indexed (0-based) to aid iterator comparison. const int end_index_; }; // class RangeGenerator // Generates values from a pair of STL-style iterators. Used in the // ValuesIn() function. The elements are copied from the source range // since the source can be located on the stack, and the generator // is likely to persist beyond that stack frame. 
template class ValuesInIteratorRangeGenerator : public ParamGeneratorInterface { public: template ValuesInIteratorRangeGenerator(ForwardIterator begin, ForwardIterator end) : container_(begin, end) {} ~ValuesInIteratorRangeGenerator() override {} ParamIteratorInterface* Begin() const override { return new Iterator(this, container_.begin()); } ParamIteratorInterface* End() const override { return new Iterator(this, container_.end()); } private: typedef typename ::std::vector ContainerType; class Iterator : public ParamIteratorInterface { public: Iterator(const ParamGeneratorInterface* base, typename ContainerType::const_iterator iterator) : base_(base), iterator_(iterator) {} ~Iterator() override {} const ParamGeneratorInterface* BaseGenerator() const override { return base_; } void Advance() override { ++iterator_; value_.reset(); } ParamIteratorInterface* Clone() const override { return new Iterator(*this); } // We need to use cached value referenced by iterator_ because *iterator_ // can return a temporary object (and of type other then T), so just // having "return &*iterator_;" doesn't work. // value_ is updated here and not in Advance() because Advance() // can advance iterator_ beyond the end of the range, and we cannot // detect that fact. The client code, on the other hand, is // responsible for not calling Current() on an out-of-range iterator. const T* Current() const override { if (value_.get() == nullptr) value_.reset(new T(*iterator_)); return value_.get(); } bool Equals(const ParamIteratorInterface& other) const override { // Having the same base generator guarantees that the other // iterator is of the same type and we can downcast. GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) << "The program attempted to compare iterators " << "from different generators." 
<< std::endl; return iterator_ == CheckedDowncastToActualType(&other)->iterator_; } private: Iterator(const Iterator& other) // The explicit constructor call suppresses a false warning // emitted by gcc when supplied with the -Wextra option. : ParamIteratorInterface(), base_(other.base_), iterator_(other.iterator_) {} const ParamGeneratorInterface* const base_; typename ContainerType::const_iterator iterator_; // A cached value of *iterator_. We keep it here to allow access by // pointer in the wrapping iterator's operator->(). // value_ needs to be mutable to be accessed in Current(). // Use of std::unique_ptr helps manage cached value's lifetime, // which is bound by the lifespan of the iterator itself. mutable std::unique_ptr value_; }; // class ValuesInIteratorRangeGenerator::Iterator // No implementation - assignment is unsupported. void operator=(const ValuesInIteratorRangeGenerator& other); const ContainerType container_; }; // class ValuesInIteratorRangeGenerator // INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. // // Default parameterized test name generator, returns a string containing the // integer test parameter index. template std::string DefaultParamName(const TestParamInfo& info) { Message name_stream; name_stream << info.index; return name_stream.GetString(); } template void TestNotEmpty() { static_assert(sizeof(T) == 0, "Empty arguments are not allowed."); } template void TestNotEmpty(const T&) {} // INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. // // Stores a parameter value and later creates tests parameterized with that // value. 
template class ParameterizedTestFactory : public TestFactoryBase { public: typedef typename TestClass::ParamType ParamType; explicit ParameterizedTestFactory(ParamType parameter) : parameter_(parameter) {} Test* CreateTest() override { TestClass::SetParam(¶meter_); return new TestClass(); } private: const ParamType parameter_; GTEST_DISALLOW_COPY_AND_ASSIGN_(ParameterizedTestFactory); }; // INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. // // TestMetaFactoryBase is a base class for meta-factories that create // test factories for passing into MakeAndRegisterTestInfo function. template class TestMetaFactoryBase { public: virtual ~TestMetaFactoryBase() {} virtual TestFactoryBase* CreateTestFactory(ParamType parameter) = 0; }; // INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. // // TestMetaFactory creates test factories for passing into // MakeAndRegisterTestInfo function. Since MakeAndRegisterTestInfo receives // ownership of test factory pointer, same factory object cannot be passed // into that method twice. But ParameterizedTestSuiteInfo is going to call // it for each Test/Parameter value combination. Thus it needs meta factory // creator class. template class TestMetaFactory : public TestMetaFactoryBase { public: using ParamType = typename TestSuite::ParamType; TestMetaFactory() {} TestFactoryBase* CreateTestFactory(ParamType parameter) override { return new ParameterizedTestFactory(parameter); } private: GTEST_DISALLOW_COPY_AND_ASSIGN_(TestMetaFactory); }; // INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. // // ParameterizedTestSuiteInfoBase is a generic interface // to ParameterizedTestSuiteInfo classes. ParameterizedTestSuiteInfoBase // accumulates test information provided by TEST_P macro invocations // and generators provided by INSTANTIATE_TEST_SUITE_P macro invocations // and uses that information to register all resulting test instances // in RegisterTests method. 
The ParameterizeTestSuiteRegistry class holds // a collection of pointers to the ParameterizedTestSuiteInfo objects // and calls RegisterTests() on each of them when asked. class ParameterizedTestSuiteInfoBase { public: virtual ~ParameterizedTestSuiteInfoBase() {} // Base part of test suite name for display purposes. virtual const std::string& GetTestSuiteName() const = 0; // Test case id to verify identity. virtual TypeId GetTestSuiteTypeId() const = 0; // UnitTest class invokes this method to register tests in this // test suite right before running them in RUN_ALL_TESTS macro. // This method should not be called more than once on any single // instance of a ParameterizedTestSuiteInfoBase derived class. virtual void RegisterTests() = 0; protected: ParameterizedTestSuiteInfoBase() {} private: GTEST_DISALLOW_COPY_AND_ASSIGN_(ParameterizedTestSuiteInfoBase); }; // INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. // // ParameterizedTestSuiteInfo accumulates tests obtained from TEST_P // macro invocations for a particular test suite and generators // obtained from INSTANTIATE_TEST_SUITE_P macro invocations for that // test suite. It registers tests with all values generated by all // generators when asked. template class ParameterizedTestSuiteInfo : public ParameterizedTestSuiteInfoBase { public: // ParamType and GeneratorCreationFunc are private types but are required // for declarations of public methods AddTestPattern() and // AddTestSuiteInstantiation(). using ParamType = typename TestSuite::ParamType; // A function that returns an instance of appropriate generator type. typedef ParamGenerator(GeneratorCreationFunc)(); using ParamNameGeneratorFunc = std::string(const TestParamInfo&); explicit ParameterizedTestSuiteInfo(const char* name, CodeLocation code_location) : test_suite_name_(name), code_location_(code_location) {} // Test case base name for display purposes. 
const std::string& GetTestSuiteName() const override { return test_suite_name_; } // Test case id to verify identity. TypeId GetTestSuiteTypeId() const override { return GetTypeId(); } // TEST_P macro uses AddTestPattern() to record information // about a single test in a LocalTestInfo structure. // test_suite_name is the base name of the test suite (without invocation // prefix). test_base_name is the name of an individual test without // parameter index. For the test SequenceA/FooTest.DoBar/1 FooTest is // test suite base name and DoBar is test base name. void AddTestPattern(const char* test_suite_name, const char* test_base_name, TestMetaFactoryBase* meta_factory) { tests_.push_back(std::shared_ptr( new TestInfo(test_suite_name, test_base_name, meta_factory))); } // INSTANTIATE_TEST_SUITE_P macro uses AddGenerator() to record information // about a generator. int AddTestSuiteInstantiation(const std::string& instantiation_name, GeneratorCreationFunc* func, ParamNameGeneratorFunc* name_func, const char* file, int line) { instantiations_.push_back( InstantiationInfo(instantiation_name, func, name_func, file, line)); return 0; // Return value used only to run this method in namespace scope. } // UnitTest class invokes this method to register tests in this test suite // test suites right before running tests in RUN_ALL_TESTS macro. // This method should not be called more than once on any single // instance of a ParameterizedTestSuiteInfoBase derived class. // UnitTest has a guard to prevent from calling this method more than once. 
void RegisterTests() override { for (typename TestInfoContainer::iterator test_it = tests_.begin(); test_it != tests_.end(); ++test_it) { std::shared_ptr test_info = *test_it; for (typename InstantiationContainer::iterator gen_it = instantiations_.begin(); gen_it != instantiations_.end(); ++gen_it) { const std::string& instantiation_name = gen_it->name; ParamGenerator generator((*gen_it->generator)()); ParamNameGeneratorFunc* name_func = gen_it->name_func; const char* file = gen_it->file; int line = gen_it->line; std::string test_suite_name; if ( !instantiation_name.empty() ) test_suite_name = instantiation_name + "/"; test_suite_name += test_info->test_suite_base_name; size_t i = 0; std::set test_param_names; for (typename ParamGenerator::iterator param_it = generator.begin(); param_it != generator.end(); ++param_it, ++i) { Message test_name_stream; std::string param_name = name_func( TestParamInfo(*param_it, i)); GTEST_CHECK_(IsValidParamName(param_name)) << "Parameterized test name '" << param_name << "' is invalid, in " << file << " line " << line << std::endl; GTEST_CHECK_(test_param_names.count(param_name) == 0) << "Duplicate parameterized test name '" << param_name << "', in " << file << " line " << line << std::endl; test_param_names.insert(param_name); if (!test_info->test_base_name.empty()) { test_name_stream << test_info->test_base_name << "/"; } test_name_stream << param_name; MakeAndRegisterTestInfo( test_suite_name.c_str(), test_name_stream.GetString().c_str(), nullptr, // No type parameter. PrintToString(*param_it).c_str(), code_location_, GetTestSuiteTypeId(), SuiteApiResolver::GetSetUpCaseOrSuite(file, line), SuiteApiResolver::GetTearDownCaseOrSuite(file, line), test_info->test_meta_factory->CreateTestFactory(*param_it)); } // for param_it } // for gen_it } // for test_it } // RegisterTests private: // LocalTestInfo structure keeps information about a single test registered // with TEST_P macro. 
struct TestInfo { TestInfo(const char* a_test_suite_base_name, const char* a_test_base_name, TestMetaFactoryBase* a_test_meta_factory) : test_suite_base_name(a_test_suite_base_name), test_base_name(a_test_base_name), test_meta_factory(a_test_meta_factory) {} const std::string test_suite_base_name; const std::string test_base_name; const std::unique_ptr > test_meta_factory; }; using TestInfoContainer = ::std::vector >; // Records data received from INSTANTIATE_TEST_SUITE_P macros: // struct InstantiationInfo { InstantiationInfo(const std::string &name_in, GeneratorCreationFunc* generator_in, ParamNameGeneratorFunc* name_func_in, const char* file_in, int line_in) : name(name_in), generator(generator_in), name_func(name_func_in), file(file_in), line(line_in) {} std::string name; GeneratorCreationFunc* generator; ParamNameGeneratorFunc* name_func; const char* file; int line; }; typedef ::std::vector InstantiationContainer; static bool IsValidParamName(const std::string& name) { // Check for empty string if (name.empty()) return false; // Check for invalid characters for (std::string::size_type index = 0; index < name.size(); ++index) { if (!isalnum(name[index]) && name[index] != '_') return false; } return true; } const std::string test_suite_name_; CodeLocation code_location_; TestInfoContainer tests_; InstantiationContainer instantiations_; GTEST_DISALLOW_COPY_AND_ASSIGN_(ParameterizedTestSuiteInfo); }; // class ParameterizedTestSuiteInfo // Legacy API is deprecated but still available #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ template using ParameterizedTestCaseInfo = ParameterizedTestSuiteInfo; #endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ // INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. // // ParameterizedTestSuiteRegistry contains a map of // ParameterizedTestSuiteInfoBase classes accessed by test suite names. TEST_P // and INSTANTIATE_TEST_SUITE_P macros use it to locate their corresponding // ParameterizedTestSuiteInfo descriptors. 
class ParameterizedTestSuiteRegistry { public: ParameterizedTestSuiteRegistry() {} ~ParameterizedTestSuiteRegistry() { for (auto& test_suite_info : test_suite_infos_) { delete test_suite_info; } } // Looks up or creates and returns a structure containing information about // tests and instantiations of a particular test suite. template ParameterizedTestSuiteInfo* GetTestSuitePatternHolder( const char* test_suite_name, CodeLocation code_location) { ParameterizedTestSuiteInfo* typed_test_info = nullptr; for (auto& test_suite_info : test_suite_infos_) { if (test_suite_info->GetTestSuiteName() == test_suite_name) { if (test_suite_info->GetTestSuiteTypeId() != GetTypeId()) { // Complain about incorrect usage of Google Test facilities // and terminate the program since we cannot guaranty correct // test suite setup and tear-down in this case. ReportInvalidTestSuiteType(test_suite_name, code_location); posix::Abort(); } else { // At this point we are sure that the object we found is of the same // type we are looking for, so we downcast it to that type // without further checks. 
typed_test_info = CheckedDowncastToActualType< ParameterizedTestSuiteInfo >(test_suite_info); } break; } } if (typed_test_info == nullptr) { typed_test_info = new ParameterizedTestSuiteInfo( test_suite_name, code_location); test_suite_infos_.push_back(typed_test_info); } return typed_test_info; } void RegisterTests() { for (auto& test_suite_info : test_suite_infos_) { test_suite_info->RegisterTests(); } } // Legacy API is deprecated but still available #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ template ParameterizedTestCaseInfo* GetTestCasePatternHolder( const char* test_case_name, CodeLocation code_location) { return GetTestSuitePatternHolder(test_case_name, code_location); } #endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ private: using TestSuiteInfoContainer = ::std::vector; TestSuiteInfoContainer test_suite_infos_; GTEST_DISALLOW_COPY_AND_ASSIGN_(ParameterizedTestSuiteRegistry); }; } // namespace internal // Forward declarations of ValuesIn(), which is implemented in // include/gtest/gtest-param-test.h. template internal::ParamGenerator ValuesIn( const Container& container); namespace internal { // Used in the Values() function to provide polymorphic capabilities. template class ValueArray { public: ValueArray(Ts... 
v) : v_{std::move(v)...} {} template operator ParamGenerator() const { // NOLINT return ValuesIn(MakeVector(MakeIndexSequence())); } private: template std::vector MakeVector(IndexSequence) const { return std::vector{static_cast(v_.template Get())...}; } FlatTuple v_; }; template class CartesianProductGenerator : public ParamGeneratorInterface<::std::tuple> { public: typedef ::std::tuple ParamType; CartesianProductGenerator(const std::tuple...>& g) : generators_(g) {} ~CartesianProductGenerator() override {} ParamIteratorInterface* Begin() const override { return new Iterator(this, generators_, false); } ParamIteratorInterface* End() const override { return new Iterator(this, generators_, true); } private: template class IteratorImpl; template class IteratorImpl> : public ParamIteratorInterface { public: IteratorImpl(const ParamGeneratorInterface* base, const std::tuple...>& generators, bool is_end) : base_(base), begin_(std::get(generators).begin()...), end_(std::get(generators).end()...), current_(is_end ? end_ : begin_) { ComputeCurrentValue(); } ~IteratorImpl() override {} const ParamGeneratorInterface* BaseGenerator() const override { return base_; } // Advance should not be called on beyond-of-range iterators // so no component iterators must be beyond end of range, either. void Advance() override { assert(!AtEnd()); // Advance the last iterator. ++std::get(current_); // if that reaches end, propagate that up. AdvanceIfEnd(); ComputeCurrentValue(); } ParamIteratorInterface* Clone() const override { return new IteratorImpl(*this); } const ParamType* Current() const override { return current_value_.get(); } bool Equals(const ParamIteratorInterface& other) const override { // Having the same base generator guarantees that the other // iterator is of the same type and we can downcast. GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) << "The program attempted to compare iterators " << "from different generators." 
<< std::endl; const IteratorImpl* typed_other = CheckedDowncastToActualType(&other); // We must report iterators equal if they both point beyond their // respective ranges. That can happen in a variety of fashions, // so we have to consult AtEnd(). if (AtEnd() && typed_other->AtEnd()) return true; bool same = true; bool dummy[] = { (same = same && std::get(current_) == std::get(typed_other->current_))...}; (void)dummy; return same; } private: template void AdvanceIfEnd() { if (std::get(current_) != std::get(end_)) return; bool last = ThisI == 0; if (last) { // We are done. Nothing else to propagate. return; } constexpr size_t NextI = ThisI - (ThisI != 0); std::get(current_) = std::get(begin_); ++std::get(current_); AdvanceIfEnd(); } void ComputeCurrentValue() { if (!AtEnd()) current_value_ = std::make_shared(*std::get(current_)...); } bool AtEnd() const { bool at_end = false; bool dummy[] = { (at_end = at_end || std::get(current_) == std::get(end_))...}; (void)dummy; return at_end; } const ParamGeneratorInterface* const base_; std::tuple::iterator...> begin_; std::tuple::iterator...> end_; std::tuple::iterator...> current_; std::shared_ptr current_value_; }; using Iterator = IteratorImpl::type>; std::tuple...> generators_; }; template class CartesianProductHolder { public: CartesianProductHolder(const Gen&... g) : generators_(g...) {} template operator ParamGenerator<::std::tuple>() const { return ParamGenerator<::std::tuple>( new CartesianProductGenerator(generators_)); } private: std::tuple generators_; }; } // namespace internal } // namespace testing #endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PARAM_UTIL_H_ namespace testing { // Functions producing parameter generators. // // Google Test uses these generators to produce parameters for value- // parameterized tests. When a parameterized test suite is instantiated // with a particular generator, Google Test creates and runs tests // for each element in the sequence produced by the generator. 
// // In the following sample, tests from test suite FooTest are instantiated // each three times with parameter values 3, 5, and 8: // // class FooTest : public TestWithParam { ... }; // // TEST_P(FooTest, TestThis) { // } // TEST_P(FooTest, TestThat) { // } // INSTANTIATE_TEST_SUITE_P(TestSequence, FooTest, Values(3, 5, 8)); // // Range() returns generators providing sequences of values in a range. // // Synopsis: // Range(start, end) // - returns a generator producing a sequence of values {start, start+1, // start+2, ..., }. // Range(start, end, step) // - returns a generator producing a sequence of values {start, start+step, // start+step+step, ..., }. // Notes: // * The generated sequences never include end. For example, Range(1, 5) // returns a generator producing a sequence {1, 2, 3, 4}. Range(1, 9, 2) // returns a generator producing {1, 3, 5, 7}. // * start and end must have the same type. That type may be any integral or // floating-point type or a user defined type satisfying these conditions: // * It must be assignable (have operator=() defined). // * It must have operator+() (operator+(int-compatible type) for // two-operand version). // * It must have operator<() defined. // Elements in the resulting sequences will also have that type. // * Condition start < end must be satisfied in order for resulting sequences // to contain any elements. // template internal::ParamGenerator Range(T start, T end, IncrementT step) { return internal::ParamGenerator( new internal::RangeGenerator(start, end, step)); } template internal::ParamGenerator Range(T start, T end) { return Range(start, end, 1); } // ValuesIn() function allows generation of tests with parameters coming from // a container. // // Synopsis: // ValuesIn(const T (&array)[N]) // - returns a generator producing sequences with elements from // a C-style array. // ValuesIn(const Container& container) // - returns a generator producing sequences with elements from // an STL-style container. 
// ValuesIn(Iterator begin, Iterator end) // - returns a generator producing sequences with elements from // a range [begin, end) defined by a pair of STL-style iterators. These // iterators can also be plain C pointers. // // Please note that ValuesIn copies the values from the containers // passed in and keeps them to generate tests in RUN_ALL_TESTS(). // // Examples: // // This instantiates tests from test suite StringTest // each with C-string values of "foo", "bar", and "baz": // // const char* strings[] = {"foo", "bar", "baz"}; // INSTANTIATE_TEST_SUITE_P(StringSequence, StringTest, ValuesIn(strings)); // // This instantiates tests from test suite StlStringTest // each with STL strings with values "a" and "b": // // ::std::vector< ::std::string> GetParameterStrings() { // ::std::vector< ::std::string> v; // v.push_back("a"); // v.push_back("b"); // return v; // } // // INSTANTIATE_TEST_SUITE_P(CharSequence, // StlStringTest, // ValuesIn(GetParameterStrings())); // // // This will also instantiate tests from CharTest // each with parameter values 'a' and 'b': // // ::std::list GetParameterChars() { // ::std::list list; // list.push_back('a'); // list.push_back('b'); // return list; // } // ::std::list l = GetParameterChars(); // INSTANTIATE_TEST_SUITE_P(CharSequence2, // CharTest, // ValuesIn(l.begin(), l.end())); // template internal::ParamGenerator< typename std::iterator_traits::value_type> ValuesIn(ForwardIterator begin, ForwardIterator end) { typedef typename std::iterator_traits::value_type ParamType; return internal::ParamGenerator( new internal::ValuesInIteratorRangeGenerator(begin, end)); } template internal::ParamGenerator ValuesIn(const T (&array)[N]) { return ValuesIn(array, array + N); } template internal::ParamGenerator ValuesIn( const Container& container) { return ValuesIn(container.begin(), container.end()); } // Values() allows generating tests from explicitly specified list of // parameters. 
// // Synopsis: // Values(T v1, T v2, ..., T vN) // - returns a generator producing sequences with elements v1, v2, ..., vN. // // For example, this instantiates tests from test suite BarTest each // with values "one", "two", and "three": // // INSTANTIATE_TEST_SUITE_P(NumSequence, // BarTest, // Values("one", "two", "three")); // // This instantiates tests from test suite BazTest each with values 1, 2, 3.5. // The exact type of values will depend on the type of parameter in BazTest. // // INSTANTIATE_TEST_SUITE_P(FloatingNumbers, BazTest, Values(1, 2, 3.5)); // // template internal::ValueArray Values(T... v) { return internal::ValueArray(std::move(v)...); } // Bool() allows generating tests with parameters in a set of (false, true). // // Synopsis: // Bool() // - returns a generator producing sequences with elements {false, true}. // // It is useful when testing code that depends on Boolean flags. Combinations // of multiple flags can be tested when several Bool()'s are combined using // Combine() function. // // In the following example all tests in the test suite FlagDependentTest // will be instantiated twice with parameters false and true. // // class FlagDependentTest : public testing::TestWithParam { // virtual void SetUp() { // external_flag = GetParam(); // } // } // INSTANTIATE_TEST_SUITE_P(BoolSequence, FlagDependentTest, Bool()); // inline internal::ParamGenerator Bool() { return Values(false, true); } // Combine() allows the user to combine two or more sequences to produce // values of a Cartesian product of those sequences' elements. // // Synopsis: // Combine(gen1, gen2, ..., genN) // - returns a generator producing sequences with elements coming from // the Cartesian product of elements from the sequences generated by // gen1, gen2, ..., genN. The sequence elements will have a type of // std::tuple where T1, T2, ..., TN are the types // of elements from sequences produces by gen1, gen2, ..., genN. // // Combine can have up to 10 arguments. 
// // Example: // // This will instantiate tests in test suite AnimalTest each one with // the parameter values tuple("cat", BLACK), tuple("cat", WHITE), // tuple("dog", BLACK), and tuple("dog", WHITE): // // enum Color { BLACK, GRAY, WHITE }; // class AnimalTest // : public testing::TestWithParam > {...}; // // TEST_P(AnimalTest, AnimalLooksNice) {...} // // INSTANTIATE_TEST_SUITE_P(AnimalVariations, AnimalTest, // Combine(Values("cat", "dog"), // Values(BLACK, WHITE))); // // This will instantiate tests in FlagDependentTest with all variations of two // Boolean flags: // // class FlagDependentTest // : public testing::TestWithParam > { // virtual void SetUp() { // // Assigns external_flag_1 and external_flag_2 values from the tuple. // std::tie(external_flag_1, external_flag_2) = GetParam(); // } // }; // // TEST_P(FlagDependentTest, TestFeature1) { // // Test your code using external_flag_1 and external_flag_2 here. // } // INSTANTIATE_TEST_SUITE_P(TwoBoolSequence, FlagDependentTest, // Combine(Bool(), Bool())); // template internal::CartesianProductHolder Combine(const Generator&... 
g) { return internal::CartesianProductHolder(g...); } #define TEST_P(test_suite_name, test_name) \ class GTEST_TEST_CLASS_NAME_(test_suite_name, test_name) \ : public test_suite_name { \ public: \ GTEST_TEST_CLASS_NAME_(test_suite_name, test_name)() {} \ virtual void TestBody(); \ \ private: \ static int AddToRegistry() { \ ::testing::UnitTest::GetInstance() \ ->parameterized_test_registry() \ .GetTestSuitePatternHolder( \ #test_suite_name, \ ::testing::internal::CodeLocation(__FILE__, __LINE__)) \ ->AddTestPattern( \ GTEST_STRINGIFY_(test_suite_name), GTEST_STRINGIFY_(test_name), \ new ::testing::internal::TestMetaFactory()); \ return 0; \ } \ static int gtest_registering_dummy_ GTEST_ATTRIBUTE_UNUSED_; \ GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_suite_name, \ test_name)); \ }; \ int GTEST_TEST_CLASS_NAME_(test_suite_name, \ test_name)::gtest_registering_dummy_ = \ GTEST_TEST_CLASS_NAME_(test_suite_name, test_name)::AddToRegistry(); \ void GTEST_TEST_CLASS_NAME_(test_suite_name, test_name)::TestBody() // The last argument to INSTANTIATE_TEST_SUITE_P allows the user to specify // generator and an optional function or functor that generates custom test name // suffixes based on the test parameters. Such a function or functor should // accept one argument of type testing::TestParamInfo, and // return std::string. // // testing::PrintToStringParamName is a builtin test suffix generator that // returns the value of testing::PrintToString(GetParam()). // // Note: test names must be non-empty, unique, and may only contain ASCII // alphanumeric characters or underscore. Because PrintToString adds quotes // to std::string and C strings, it won't work for these types. #define GTEST_EXPAND_(arg) arg #define GTEST_GET_FIRST_(first, ...) first #define GTEST_GET_SECOND_(first, second, ...) second #define INSTANTIATE_TEST_SUITE_P(prefix, test_suite_name, ...) 
\ static ::testing::internal::ParamGenerator \ gtest_##prefix##test_suite_name##_EvalGenerator_() { \ return GTEST_EXPAND_(GTEST_GET_FIRST_(__VA_ARGS__, DUMMY_PARAM_)); \ } \ static ::std::string gtest_##prefix##test_suite_name##_EvalGenerateName_( \ const ::testing::TestParamInfo& info) { \ if (::testing::internal::AlwaysFalse()) { \ ::testing::internal::TestNotEmpty(GTEST_EXPAND_(GTEST_GET_SECOND_( \ __VA_ARGS__, \ ::testing::internal::DefaultParamName, \ DUMMY_PARAM_))); \ auto t = std::make_tuple(__VA_ARGS__); \ static_assert(std::tuple_size::value <= 2, \ "Too Many Args!"); \ } \ return ((GTEST_EXPAND_(GTEST_GET_SECOND_( \ __VA_ARGS__, \ ::testing::internal::DefaultParamName, \ DUMMY_PARAM_))))(info); \ } \ static int gtest_##prefix##test_suite_name##_dummy_ \ GTEST_ATTRIBUTE_UNUSED_ = \ ::testing::UnitTest::GetInstance() \ ->parameterized_test_registry() \ .GetTestSuitePatternHolder( \ #test_suite_name, \ ::testing::internal::CodeLocation(__FILE__, __LINE__)) \ ->AddTestSuiteInstantiation( \ #prefix, >est_##prefix##test_suite_name##_EvalGenerator_, \ >est_##prefix##test_suite_name##_EvalGenerateName_, \ __FILE__, __LINE__) // Legacy API is deprecated but still available #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ #define INSTANTIATE_TEST_CASE_P \ static_assert(::testing::internal::InstantiateTestCase_P_IsDeprecated(), \ ""); \ INSTANTIATE_TEST_SUITE_P #endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ } // namespace testing #endif // GTEST_INCLUDE_GTEST_GTEST_PARAM_TEST_H_ // Copyright 2006, Google Inc. // All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. 
// * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // // Google C++ Testing and Mocking Framework definitions useful in production code. // GOOGLETEST_CM0003 DO NOT DELETE #ifndef GTEST_INCLUDE_GTEST_GTEST_PROD_H_ #define GTEST_INCLUDE_GTEST_GTEST_PROD_H_ // When you need to test the private or protected members of a class, // use the FRIEND_TEST macro to declare your tests as friends of the // class. For example: // // class MyClass { // private: // void PrivateMethod(); // FRIEND_TEST(MyClassTest, PrivateMethodWorks); // }; // // class MyClassTest : public testing::Test { // // ... // }; // // TEST_F(MyClassTest, PrivateMethodWorks) { // // Can call MyClass::PrivateMethod() here. // } // // Note: The test class must be in the same namespace as the class being tested. 
// For example, putting MyClassTest in an anonymous namespace will not work. #define FRIEND_TEST(test_case_name, test_name)\ friend class test_case_name##_##test_name##_Test #endif // GTEST_INCLUDE_GTEST_GTEST_PROD_H_ // Copyright 2008, Google Inc. // All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
// // GOOGLETEST_CM0001 DO NOT DELETE #ifndef GTEST_INCLUDE_GTEST_GTEST_TEST_PART_H_ #define GTEST_INCLUDE_GTEST_GTEST_TEST_PART_H_ #include #include GTEST_DISABLE_MSC_WARNINGS_PUSH_(4251 \ /* class A needs to have dll-interface to be used by clients of class B */) namespace testing { // A copyable object representing the result of a test part (i.e. an // assertion or an explicit FAIL(), ADD_FAILURE(), or SUCCESS()). // // Don't inherit from TestPartResult as its destructor is not virtual. class GTEST_API_ TestPartResult { public: // The possible outcomes of a test part (i.e. an assertion or an // explicit SUCCEED(), FAIL(), or ADD_FAILURE()). enum Type { kSuccess, // Succeeded. kNonFatalFailure, // Failed but the test can continue. kFatalFailure, // Failed and the test should be terminated. kSkip // Skipped. }; // C'tor. TestPartResult does NOT have a default constructor. // Always use this constructor (with parameters) to create a // TestPartResult object. TestPartResult(Type a_type, const char* a_file_name, int a_line_number, const char* a_message) : type_(a_type), file_name_(a_file_name == nullptr ? "" : a_file_name), line_number_(a_line_number), summary_(ExtractSummary(a_message)), message_(a_message) {} // Gets the outcome of the test part. Type type() const { return type_; } // Gets the name of the source file where the test part took place, or // NULL if it's unknown. const char* file_name() const { return file_name_.empty() ? nullptr : file_name_.c_str(); } // Gets the line in the source file where the test part took place, // or -1 if it's unknown. int line_number() const { return line_number_; } // Gets the summary of the failure message. const char* summary() const { return summary_.c_str(); } // Gets the message associated with the test part. const char* message() const { return message_.c_str(); } // Returns true if and only if the test part was skipped. 
bool skipped() const { return type_ == kSkip; } // Returns true if and only if the test part passed. bool passed() const { return type_ == kSuccess; } // Returns true if and only if the test part non-fatally failed. bool nonfatally_failed() const { return type_ == kNonFatalFailure; } // Returns true if and only if the test part fatally failed. bool fatally_failed() const { return type_ == kFatalFailure; } // Returns true if and only if the test part failed. bool failed() const { return fatally_failed() || nonfatally_failed(); } private: Type type_; // Gets the summary of the failure message by omitting the stack // trace in it. static std::string ExtractSummary(const char* message); // The name of the source file where the test part took place, or // "" if the source file is unknown. std::string file_name_; // The line in the source file where the test part took place, or -1 // if the line number is unknown. int line_number_; std::string summary_; // The test failure summary. std::string message_; // The test failure message. }; // Prints a TestPartResult object. std::ostream& operator<<(std::ostream& os, const TestPartResult& result); // An array of TestPartResult objects. // // Don't inherit from TestPartResultArray as its destructor is not // virtual. class GTEST_API_ TestPartResultArray { public: TestPartResultArray() {} // Appends the given TestPartResult to the array. void Append(const TestPartResult& result); // Returns the TestPartResult at the given index (0-based). const TestPartResult& GetTestPartResult(int index) const; // Returns the number of TestPartResult objects in the array. int size() const; private: std::vector array_; GTEST_DISALLOW_COPY_AND_ASSIGN_(TestPartResultArray); }; // This interface knows how to report a test part result. 
class GTEST_API_ TestPartResultReporterInterface {
 public:
  // Virtual so that implementations can be destroyed through a pointer to
  // this interface.
  virtual ~TestPartResultReporterInterface() {}

  // Pure virtual: each concrete reporter decides what to do with `result`.
  virtual void ReportTestPartResult(const TestPartResult& result) = 0;
};

namespace internal {

// This helper class is used by {ASSERT|EXPECT}_NO_FATAL_FAILURE to check if a
// statement generates new fatal failures. To do so it registers itself as the
// current test part result reporter. Besides checking if fatal failures were
// reported, it only delegates the reporting to the former result reporter.
// The original result reporter is restored in the destructor.
// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM.
class GTEST_API_ HasNewFatalFailureHelper
    : public TestPartResultReporterInterface {
 public:
  // Constructor and destructor are defined out of line (not visible here).
  HasNewFatalFailureHelper();
  ~HasNewFatalFailureHelper() override;
  void ReportTestPartResult(const TestPartResult& result) override;
  // Accessor for the flag below.
  bool has_new_fatal_failure() const { return has_new_fatal_failure_; }
 private:
  // Value reported by has_new_fatal_failure().
  bool has_new_fatal_failure_;
  // NOTE(review): presumably the reporter that was current before this helper
  // registered itself, per the class comment -- confirm in the definition.
  TestPartResultReporterInterface* original_reporter_;

  GTEST_DISALLOW_COPY_AND_ASSIGN_(HasNewFatalFailureHelper);
};

}  // namespace internal

}  // namespace testing

GTEST_DISABLE_MSC_WARNINGS_POP_()  //  4251

#endif  // GTEST_INCLUDE_GTEST_GTEST_TEST_PART_H_
// Copyright 2008 Google Inc.
// All Rights Reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
//     * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // GOOGLETEST_CM0001 DO NOT DELETE #ifndef GTEST_INCLUDE_GTEST_GTEST_TYPED_TEST_H_ #define GTEST_INCLUDE_GTEST_GTEST_TYPED_TEST_H_ // This header implements typed tests and type-parameterized tests. // Typed (aka type-driven) tests repeat the same test for types in a // list. You must know which types you want to test with when writing // typed tests. Here's how you do it: #if 0 // First, define a fixture class template. It should be parameterized // by a type. Remember to derive it from testing::Test. template class FooTest : public testing::Test { public: ... typedef std::list List; static T shared_; T value_; }; // Next, associate a list of types with the test suite, which will be // repeated for each type in the list. The typedef is necessary for // the macro to parse correctly. typedef testing::Types MyTypes; TYPED_TEST_SUITE(FooTest, MyTypes); // If the type list contains only one type, you can write that type // directly without Types<...>: // TYPED_TEST_SUITE(FooTest, int); // Then, use TYPED_TEST() instead of TEST_F() to define as many typed // tests for this test suite as you want. 
TYPED_TEST(FooTest, DoesBlah) { // Inside a test, refer to the special name TypeParam to get the type // parameter. Since we are inside a derived class template, C++ requires // us to visit the members of FooTest via 'this'. TypeParam n = this->value_; // To visit static members of the fixture, add the TestFixture:: // prefix. n += TestFixture::shared_; // To refer to typedefs in the fixture, add the "typename // TestFixture::" prefix. typename TestFixture::List values; values.push_back(n); ... } TYPED_TEST(FooTest, HasPropertyA) { ... } // TYPED_TEST_SUITE takes an optional third argument which allows to specify a // class that generates custom test name suffixes based on the type. This should // be a class which has a static template function GetName(int index) returning // a string for each type. The provided integer index equals the index of the // type in the provided type list. In many cases the index can be ignored. // // For example: // class MyTypeNames { // public: // template // static std::string GetName(int) { // if (std::is_same()) return "char"; // if (std::is_same()) return "int"; // if (std::is_same()) return "unsignedInt"; // } // }; // TYPED_TEST_SUITE(FooTest, MyTypes, MyTypeNames); #endif // 0 // Type-parameterized tests are abstract test patterns parameterized // by a type. Compared with typed tests, type-parameterized tests // allow you to define the test pattern without knowing what the type // parameters are. The defined pattern can be instantiated with // different types any number of times, in any number of translation // units. // // If you are designing an interface or concept, you can define a // suite of type-parameterized tests to verify properties that any // valid implementation of the interface/concept should have. Then, // each implementation can easily instantiate the test suite to verify // that it conforms to the requirements, without having to write // similar tests repeatedly. 
Here's an example: #if 0 // First, define a fixture class template. It should be parameterized // by a type. Remember to derive it from testing::Test. template class FooTest : public testing::Test { ... }; // Next, declare that you will define a type-parameterized test suite // (the _P suffix is for "parameterized" or "pattern", whichever you // prefer): TYPED_TEST_SUITE_P(FooTest); // Then, use TYPED_TEST_P() to define as many type-parameterized tests // for this type-parameterized test suite as you want. TYPED_TEST_P(FooTest, DoesBlah) { // Inside a test, refer to TypeParam to get the type parameter. TypeParam n = 0; ... } TYPED_TEST_P(FooTest, HasPropertyA) { ... } // Now the tricky part: you need to register all test patterns before // you can instantiate them. The first argument of the macro is the // test suite name; the rest are the names of the tests in this test // case. REGISTER_TYPED_TEST_SUITE_P(FooTest, DoesBlah, HasPropertyA); // Finally, you are free to instantiate the pattern with the types you // want. If you put the above code in a header file, you can #include // it in multiple C++ source files and instantiate it multiple times. // // To distinguish different instances of the pattern, the first // argument to the INSTANTIATE_* macro is a prefix that will be added // to the actual test suite name. Remember to pick unique prefixes for // different instances. typedef testing::Types MyTypes; INSTANTIATE_TYPED_TEST_SUITE_P(My, FooTest, MyTypes); // If the type list contains only one type, you can write that type // directly without Types<...>: // INSTANTIATE_TYPED_TEST_SUITE_P(My, FooTest, int); // // Similar to the optional argument of TYPED_TEST_SUITE above, // INSTANTIATE_TEST_SUITE_P takes an optional fourth argument which allows to // generate custom names. // INSTANTIATE_TYPED_TEST_SUITE_P(My, FooTest, MyTypes, MyTypeNames); #endif // 0 // Implements typed tests. #if GTEST_HAS_TYPED_TEST // INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. 
// // Expands to the name of the typedef for the type parameters of the // given test suite. #define GTEST_TYPE_PARAMS_(TestSuiteName) gtest_type_params_##TestSuiteName##_ // Expands to the name of the typedef for the NameGenerator, responsible for // creating the suffixes of the name. #define GTEST_NAME_GENERATOR_(TestSuiteName) \ gtest_type_params_##TestSuiteName##_NameGenerator #define TYPED_TEST_SUITE(CaseName, Types, ...) \ typedef ::testing::internal::TypeList::type GTEST_TYPE_PARAMS_( \ CaseName); \ typedef ::testing::internal::NameGeneratorSelector<__VA_ARGS__>::type \ GTEST_NAME_GENERATOR_(CaseName) # define TYPED_TEST(CaseName, TestName) \ template \ class GTEST_TEST_CLASS_NAME_(CaseName, TestName) \ : public CaseName { \ private: \ typedef CaseName TestFixture; \ typedef gtest_TypeParam_ TypeParam; \ virtual void TestBody(); \ }; \ static bool gtest_##CaseName##_##TestName##_registered_ \ GTEST_ATTRIBUTE_UNUSED_ = \ ::testing::internal::TypeParameterizedTest< \ CaseName, \ ::testing::internal::TemplateSel, \ GTEST_TYPE_PARAMS_( \ CaseName)>::Register("", \ ::testing::internal::CodeLocation( \ __FILE__, __LINE__), \ #CaseName, #TestName, 0, \ ::testing::internal::GenerateNames< \ GTEST_NAME_GENERATOR_(CaseName), \ GTEST_TYPE_PARAMS_(CaseName)>()); \ template \ void GTEST_TEST_CLASS_NAME_(CaseName, \ TestName)::TestBody() // Legacy API is deprecated but still available #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ #define TYPED_TEST_CASE \ static_assert(::testing::internal::TypedTestCaseIsDeprecated(), ""); \ TYPED_TEST_SUITE #endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ #endif // GTEST_HAS_TYPED_TEST // Implements type-parameterized tests. #if GTEST_HAS_TYPED_TEST_P // INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. // // Expands to the namespace name that the type-parameterized tests for // the given type-parameterized test suite are defined in. The exact // name of the namespace is subject to change without notice. 
#define GTEST_SUITE_NAMESPACE_(TestSuiteName) gtest_suite_##TestSuiteName##_ // INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. // // Expands to the name of the variable used to remember the names of // the defined tests in the given test suite. #define GTEST_TYPED_TEST_SUITE_P_STATE_(TestSuiteName) \ gtest_typed_test_suite_p_state_##TestSuiteName##_ // INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE DIRECTLY. // // Expands to the name of the variable used to remember the names of // the registered tests in the given test suite. #define GTEST_REGISTERED_TEST_NAMES_(TestSuiteName) \ gtest_registered_test_names_##TestSuiteName##_ // The variables defined in the type-parameterized test macros are // static as typically these macros are used in a .h file that can be // #included in multiple translation units linked together. #define TYPED_TEST_SUITE_P(SuiteName) \ static ::testing::internal::TypedTestSuitePState \ GTEST_TYPED_TEST_SUITE_P_STATE_(SuiteName) // Legacy API is deprecated but still available #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ #define TYPED_TEST_CASE_P \ static_assert(::testing::internal::TypedTestCase_P_IsDeprecated(), ""); \ TYPED_TEST_SUITE_P #endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ #define TYPED_TEST_P(SuiteName, TestName) \ namespace GTEST_SUITE_NAMESPACE_(SuiteName) { \ template \ class TestName : public SuiteName { \ private: \ typedef SuiteName TestFixture; \ typedef gtest_TypeParam_ TypeParam; \ virtual void TestBody(); \ }; \ static bool gtest_##TestName##_defined_ GTEST_ATTRIBUTE_UNUSED_ = \ GTEST_TYPED_TEST_SUITE_P_STATE_(SuiteName).AddTestName( \ __FILE__, __LINE__, #SuiteName, #TestName); \ } \ template \ void GTEST_SUITE_NAMESPACE_( \ SuiteName)::TestName::TestBody() #define REGISTER_TYPED_TEST_SUITE_P(SuiteName, ...) 
\ namespace GTEST_SUITE_NAMESPACE_(SuiteName) { \ typedef ::testing::internal::Templates<__VA_ARGS__>::type gtest_AllTests_; \ } \ static const char* const GTEST_REGISTERED_TEST_NAMES_( \ SuiteName) GTEST_ATTRIBUTE_UNUSED_ = \ GTEST_TYPED_TEST_SUITE_P_STATE_(SuiteName).VerifyRegisteredTestNames( \ __FILE__, __LINE__, #__VA_ARGS__) // Legacy API is deprecated but still available #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ #define REGISTER_TYPED_TEST_CASE_P \ static_assert(::testing::internal::RegisterTypedTestCase_P_IsDeprecated(), \ ""); \ REGISTER_TYPED_TEST_SUITE_P #endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ #define INSTANTIATE_TYPED_TEST_SUITE_P(Prefix, SuiteName, Types, ...) \ static bool gtest_##Prefix##_##SuiteName GTEST_ATTRIBUTE_UNUSED_ = \ ::testing::internal::TypeParameterizedTestSuite< \ SuiteName, GTEST_SUITE_NAMESPACE_(SuiteName)::gtest_AllTests_, \ ::testing::internal::TypeList::type>:: \ Register(#Prefix, \ ::testing::internal::CodeLocation(__FILE__, __LINE__), \ >EST_TYPED_TEST_SUITE_P_STATE_(SuiteName), #SuiteName, \ GTEST_REGISTERED_TEST_NAMES_(SuiteName), \ ::testing::internal::GenerateNames< \ ::testing::internal::NameGeneratorSelector< \ __VA_ARGS__>::type, \ ::testing::internal::TypeList::type>()) // Legacy API is deprecated but still available #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ #define INSTANTIATE_TYPED_TEST_CASE_P \ static_assert( \ ::testing::internal::InstantiateTypedTestCase_P_IsDeprecated(), ""); \ INSTANTIATE_TYPED_TEST_SUITE_P #endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ #endif // GTEST_HAS_TYPED_TEST_P #endif // GTEST_INCLUDE_GTEST_GTEST_TYPED_TEST_H_ GTEST_DISABLE_MSC_WARNINGS_PUSH_(4251 \ /* class A needs to have dll-interface to be used by clients of class B */) namespace testing { // Silence C4100 (unreferenced formal parameter) and 4805 // unsafe mix of type 'const int' and type 'const bool' #ifdef _MSC_VER # pragma warning(push) # pragma warning(disable:4805) # pragma warning(disable:4100) #endif // Declares the flags. 
// This flag temporary enables the disabled tests. GTEST_DECLARE_bool_(also_run_disabled_tests); // This flag brings the debugger on an assertion failure. GTEST_DECLARE_bool_(break_on_failure); // This flag controls whether Google Test catches all test-thrown exceptions // and logs them as failures. GTEST_DECLARE_bool_(catch_exceptions); // This flag enables using colors in terminal output. Available values are // "yes" to enable colors, "no" (disable colors), or "auto" (the default) // to let Google Test decide. GTEST_DECLARE_string_(color); // This flag sets up the filter to select by name using a glob pattern // the tests to run. If the filter is not given all tests are executed. GTEST_DECLARE_string_(filter); // This flag controls whether Google Test installs a signal handler that dumps // debugging information when fatal signals are raised. GTEST_DECLARE_bool_(install_failure_signal_handler); // This flag causes the Google Test to list tests. None of the tests listed // are actually run if the flag is provided. GTEST_DECLARE_bool_(list_tests); // This flag controls whether Google Test emits a detailed XML report to a file // in addition to its normal textual output. GTEST_DECLARE_string_(output); // This flags control whether Google Test prints the elapsed time for each // test. GTEST_DECLARE_bool_(print_time); // This flags control whether Google Test prints UTF8 characters as text. GTEST_DECLARE_bool_(print_utf8); // This flag specifies the random number seed. GTEST_DECLARE_int32_(random_seed); // This flag sets how many times the tests are repeated. The default value // is 1. If the value is -1 the tests are repeating forever. GTEST_DECLARE_int32_(repeat); // This flag controls whether Google Test includes Google Test internal // stack frames in failure stack traces. GTEST_DECLARE_bool_(show_internal_stack_frames); // When this flag is specified, tests' order is randomized on every iteration. 
GTEST_DECLARE_bool_(shuffle); // This flag specifies the maximum number of stack frames to be // printed in a failure message. GTEST_DECLARE_int32_(stack_trace_depth); // When this flag is specified, a failed assertion will throw an // exception if exceptions are enabled, or exit the program with a // non-zero code otherwise. For use with an external test framework. GTEST_DECLARE_bool_(throw_on_failure); // When this flag is specified, list of skipped test names is printed in // summary GTEST_DECLARE_bool_(print_skipped); // When this flag is set with a "host:port" string, on supported // platforms test results are streamed to the specified port on // the specified host machine. GTEST_DECLARE_string_(stream_result_to); #if GTEST_USE_OWN_FLAGFILE_FLAG_ GTEST_DECLARE_string_(flagfile); #endif // GTEST_USE_OWN_FLAGFILE_FLAG_ // The upper limit for valid stack trace depths. const int kMaxStackTraceDepth = 100; namespace internal { class AssertHelper; class DefaultGlobalTestPartResultReporter; class ExecDeathTest; class NoExecDeathTest; class FinalSuccessChecker; class GTestFlagSaver; class StreamingListenerTest; class TestResultAccessor; class TestEventListenersAccessor; class TestEventRepeater; class UnitTestRecordPropertyTestHelper; class WindowsDeathTest; class FuchsiaDeathTest; class UnitTestImpl* GetUnitTestImpl(); void ReportFailureInUnknownLocation(TestPartResult::Type result_type, const std::string& message); } // namespace internal // The friend relationship of some of these classes is cyclic. // If we don't forward declare them the compiler might confuse the classes // in friendship clauses with same named classes on the scope. class Test; class TestSuite; // Old API is still available but deprecated #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ using TestCase = TestSuite; #endif class TestInfo; class UnitTest; // A class for indicating whether an assertion was successful. 
When // the assertion wasn't successful, the AssertionResult object // remembers a non-empty message that describes how it failed. // // To create an instance of this class, use one of the factory functions // (AssertionSuccess() and AssertionFailure()). // // This class is useful for two purposes: // 1. Defining predicate functions to be used with Boolean test assertions // EXPECT_TRUE/EXPECT_FALSE and their ASSERT_ counterparts // 2. Defining predicate-format functions to be // used with predicate assertions (ASSERT_PRED_FORMAT*, etc). // // For example, if you define IsEven predicate: // // testing::AssertionResult IsEven(int n) { // if ((n % 2) == 0) // return testing::AssertionSuccess(); // else // return testing::AssertionFailure() << n << " is odd"; // } // // Then the failed expectation EXPECT_TRUE(IsEven(Fib(5))) // will print the message // // Value of: IsEven(Fib(5)) // Actual: false (5 is odd) // Expected: true // // instead of a more opaque // // Value of: IsEven(Fib(5)) // Actual: false // Expected: true // // in case IsEven is a simple Boolean predicate. // // If you expect your predicate to be reused and want to support informative // messages in EXPECT_FALSE and ASSERT_FALSE (negative assertions show up // about half as often as positive ones in our tests), supply messages for // both success and failure cases: // // testing::AssertionResult IsEven(int n) { // if ((n % 2) == 0) // return testing::AssertionSuccess() << n << " is even"; // else // return testing::AssertionFailure() << n << " is odd"; // } // // Then a statement EXPECT_FALSE(IsEven(Fib(6))) will print // // Value of: IsEven(Fib(6)) // Actual: true (8 is even) // Expected: false // // NB: Predicates that support negative Boolean assertions have reduced // performance in positive ones so be careful not to use them in tests // that have lots (tens of thousands) of positive Boolean assertions. 
// // To use this class with EXPECT_PRED_FORMAT assertions such as: // // // Verifies that Foo() returns an even number. // EXPECT_PRED_FORMAT1(IsEven, Foo()); // // you need to define: // // testing::AssertionResult IsEven(const char* expr, int n) { // if ((n % 2) == 0) // return testing::AssertionSuccess(); // else // return testing::AssertionFailure() // << "Expected: " << expr << " is even\n Actual: it's " << n; // } // // If Foo() returns 5, you will see the following message: // // Expected: Foo() is even // Actual: it's 5 // class GTEST_API_ AssertionResult { public: // Copy constructor. // Used in EXPECT_TRUE/FALSE(assertion_result). AssertionResult(const AssertionResult& other); #if defined(_MSC_VER) && _MSC_VER < 1910 GTEST_DISABLE_MSC_WARNINGS_PUSH_(4800 /* forcing value to bool */) #endif // Used in the EXPECT_TRUE/FALSE(bool_expression). // // T must be contextually convertible to bool. // // The second parameter prevents this overload from being considered if // the argument is implicitly convertible to AssertionResult. In that case // we want AssertionResult's copy constructor to be used. template explicit AssertionResult( const T& success, typename std::enable_if< !std::is_convertible::value>::type* /*enabler*/ = nullptr) : success_(success) {} #if defined(_MSC_VER) && _MSC_VER < 1910 GTEST_DISABLE_MSC_WARNINGS_POP_() #endif // Assignment operator. AssertionResult& operator=(AssertionResult other) { swap(other); return *this; } // Returns true if and only if the assertion succeeded. operator bool() const { return success_; } // NOLINT // Returns the assertion's negation. Used with EXPECT/ASSERT_FALSE. AssertionResult operator!() const; // Returns the text streamed into this AssertionResult. Test assertions // use it when they fail (i.e., the predicate's outcome doesn't match the // assertion's expectation). When nothing has been streamed into the // object, returns an empty string. const char* message() const { return message_.get() != nullptr ? 
message_->c_str() : ""; } // Deprecated; please use message() instead. const char* failure_message() const { return message(); } // Streams a custom failure message into this object. template AssertionResult& operator<<(const T& value) { AppendMessage(Message() << value); return *this; } // Allows streaming basic output manipulators such as endl or flush into // this object. AssertionResult& operator<<( ::std::ostream& (*basic_manipulator)(::std::ostream& stream)) { AppendMessage(Message() << basic_manipulator); return *this; } private: // Appends the contents of message to message_. void AppendMessage(const Message& a_message) { if (message_.get() == nullptr) message_.reset(new ::std::string); message_->append(a_message.GetString().c_str()); } // Swap the contents of this AssertionResult with other. void swap(AssertionResult& other); // Stores result of the assertion predicate. bool success_; // Stores the message describing the condition in case the expectation // construct is not satisfied with the predicate's outcome. // Referenced via a pointer to avoid taking too much stack frame space // with test assertions. std::unique_ptr< ::std::string> message_; }; // Makes a successful assertion result. GTEST_API_ AssertionResult AssertionSuccess(); // Makes a failed assertion result. GTEST_API_ AssertionResult AssertionFailure(); // Makes a failed assertion result with the given failure message. // Deprecated; use AssertionFailure() << msg. GTEST_API_ AssertionResult AssertionFailure(const Message& msg); } // namespace testing // Includes the auto-generated header that implements a family of generic // predicate assertion macros. This include comes late because it relies on // APIs declared above. // Copyright 2006, Google Inc. // All rights reserved. 
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // This file is AUTOMATICALLY GENERATED on 01/02/2019 by command // 'gen_gtest_pred_impl.py 5'. DO NOT EDIT BY HAND! // // Implements a family of generic predicate assertion macros. // GOOGLETEST_CM0001 DO NOT DELETE #ifndef GTEST_INCLUDE_GTEST_GTEST_PRED_IMPL_H_ #define GTEST_INCLUDE_GTEST_GTEST_PRED_IMPL_H_ namespace testing { // This header implements a family of generic predicate assertion // macros: // // ASSERT_PRED_FORMAT1(pred_format, v1) // ASSERT_PRED_FORMAT2(pred_format, v1, v2) // ... 
// // where pred_format is a function or functor that takes n (in the // case of ASSERT_PRED_FORMATn) values and their source expression // text, and returns a testing::AssertionResult. See the definition // of ASSERT_EQ in gtest.h for an example. // // If you don't care about formatting, you can use the more // restrictive version: // // ASSERT_PRED1(pred, v1) // ASSERT_PRED2(pred, v1, v2) // ... // // where pred is an n-ary function or functor that returns bool, // and the values v1, v2, ..., must support the << operator for // streaming to std::ostream. // // We also define the EXPECT_* variations. // // For now we only support predicates whose arity is at most 5. // Please email googletestframework@googlegroups.com if you need // support for higher arities. // GTEST_ASSERT_ is the basic statement to which all of the assertions // in this file reduce. Don't use this in your code. #define GTEST_ASSERT_(expression, on_failure) \ GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ if (const ::testing::AssertionResult gtest_ar = (expression)) \ ; \ else \ on_failure(gtest_ar.failure_message()) // Helper function for implementing {EXPECT|ASSERT}_PRED1. Don't use // this in your code. template AssertionResult AssertPred1Helper(const char* pred_text, const char* e1, Pred pred, const T1& v1) { if (pred(v1)) return AssertionSuccess(); return AssertionFailure() << pred_text << "(" << e1 << ") evaluates to false, where" << "\n" << e1 << " evaluates to " << ::testing::PrintToString(v1); } // Internal macro for implementing {EXPECT|ASSERT}_PRED_FORMAT1. // Don't use this in your code. #define GTEST_PRED_FORMAT1_(pred_format, v1, on_failure)\ GTEST_ASSERT_(pred_format(#v1, v1), \ on_failure) // Internal macro for implementing {EXPECT|ASSERT}_PRED1. Don't use // this in your code. #define GTEST_PRED1_(pred, v1, on_failure)\ GTEST_ASSERT_(::testing::AssertPred1Helper(#pred, \ #v1, \ pred, \ v1), on_failure) // Unary predicate assertion macros. 
#define EXPECT_PRED_FORMAT1(pred_format, v1) \ GTEST_PRED_FORMAT1_(pred_format, v1, GTEST_NONFATAL_FAILURE_) #define EXPECT_PRED1(pred, v1) \ GTEST_PRED1_(pred, v1, GTEST_NONFATAL_FAILURE_) #define ASSERT_PRED_FORMAT1(pred_format, v1) \ GTEST_PRED_FORMAT1_(pred_format, v1, GTEST_FATAL_FAILURE_) #define ASSERT_PRED1(pred, v1) \ GTEST_PRED1_(pred, v1, GTEST_FATAL_FAILURE_) // Helper function for implementing {EXPECT|ASSERT}_PRED2. Don't use // this in your code. template AssertionResult AssertPred2Helper(const char* pred_text, const char* e1, const char* e2, Pred pred, const T1& v1, const T2& v2) { if (pred(v1, v2)) return AssertionSuccess(); return AssertionFailure() << pred_text << "(" << e1 << ", " << e2 << ") evaluates to false, where" << "\n" << e1 << " evaluates to " << ::testing::PrintToString(v1) << "\n" << e2 << " evaluates to " << ::testing::PrintToString(v2); } // Internal macro for implementing {EXPECT|ASSERT}_PRED_FORMAT2. // Don't use this in your code. #define GTEST_PRED_FORMAT2_(pred_format, v1, v2, on_failure)\ GTEST_ASSERT_(pred_format(#v1, #v2, v1, v2), \ on_failure) // Internal macro for implementing {EXPECT|ASSERT}_PRED2. Don't use // this in your code. #define GTEST_PRED2_(pred, v1, v2, on_failure)\ GTEST_ASSERT_(::testing::AssertPred2Helper(#pred, \ #v1, \ #v2, \ pred, \ v1, \ v2), on_failure) // Binary predicate assertion macros. #define EXPECT_PRED_FORMAT2(pred_format, v1, v2) \ GTEST_PRED_FORMAT2_(pred_format, v1, v2, GTEST_NONFATAL_FAILURE_) #define EXPECT_PRED2(pred, v1, v2) \ GTEST_PRED2_(pred, v1, v2, GTEST_NONFATAL_FAILURE_) #define ASSERT_PRED_FORMAT2(pred_format, v1, v2) \ GTEST_PRED_FORMAT2_(pred_format, v1, v2, GTEST_FATAL_FAILURE_) #define ASSERT_PRED2(pred, v1, v2) \ GTEST_PRED2_(pred, v1, v2, GTEST_FATAL_FAILURE_) // Helper function for implementing {EXPECT|ASSERT}_PRED3. Don't use // this in your code. 
template AssertionResult AssertPred3Helper(const char* pred_text, const char* e1, const char* e2, const char* e3, Pred pred, const T1& v1, const T2& v2, const T3& v3) { if (pred(v1, v2, v3)) return AssertionSuccess(); return AssertionFailure() << pred_text << "(" << e1 << ", " << e2 << ", " << e3 << ") evaluates to false, where" << "\n" << e1 << " evaluates to " << ::testing::PrintToString(v1) << "\n" << e2 << " evaluates to " << ::testing::PrintToString(v2) << "\n" << e3 << " evaluates to " << ::testing::PrintToString(v3); } // Internal macro for implementing {EXPECT|ASSERT}_PRED_FORMAT3. // Don't use this in your code. #define GTEST_PRED_FORMAT3_(pred_format, v1, v2, v3, on_failure)\ GTEST_ASSERT_(pred_format(#v1, #v2, #v3, v1, v2, v3), \ on_failure) // Internal macro for implementing {EXPECT|ASSERT}_PRED3. Don't use // this in your code. #define GTEST_PRED3_(pred, v1, v2, v3, on_failure)\ GTEST_ASSERT_(::testing::AssertPred3Helper(#pred, \ #v1, \ #v2, \ #v3, \ pred, \ v1, \ v2, \ v3), on_failure) // Ternary predicate assertion macros. #define EXPECT_PRED_FORMAT3(pred_format, v1, v2, v3) \ GTEST_PRED_FORMAT3_(pred_format, v1, v2, v3, GTEST_NONFATAL_FAILURE_) #define EXPECT_PRED3(pred, v1, v2, v3) \ GTEST_PRED3_(pred, v1, v2, v3, GTEST_NONFATAL_FAILURE_) #define ASSERT_PRED_FORMAT3(pred_format, v1, v2, v3) \ GTEST_PRED_FORMAT3_(pred_format, v1, v2, v3, GTEST_FATAL_FAILURE_) #define ASSERT_PRED3(pred, v1, v2, v3) \ GTEST_PRED3_(pred, v1, v2, v3, GTEST_FATAL_FAILURE_) // Helper function for implementing {EXPECT|ASSERT}_PRED4. Don't use // this in your code. 
template AssertionResult AssertPred4Helper(const char* pred_text, const char* e1, const char* e2, const char* e3, const char* e4, Pred pred, const T1& v1, const T2& v2, const T3& v3, const T4& v4) { if (pred(v1, v2, v3, v4)) return AssertionSuccess(); return AssertionFailure() << pred_text << "(" << e1 << ", " << e2 << ", " << e3 << ", " << e4 << ") evaluates to false, where" << "\n" << e1 << " evaluates to " << ::testing::PrintToString(v1) << "\n" << e2 << " evaluates to " << ::testing::PrintToString(v2) << "\n" << e3 << " evaluates to " << ::testing::PrintToString(v3) << "\n" << e4 << " evaluates to " << ::testing::PrintToString(v4); } // Internal macro for implementing {EXPECT|ASSERT}_PRED_FORMAT4. // Don't use this in your code. #define GTEST_PRED_FORMAT4_(pred_format, v1, v2, v3, v4, on_failure)\ GTEST_ASSERT_(pred_format(#v1, #v2, #v3, #v4, v1, v2, v3, v4), \ on_failure) // Internal macro for implementing {EXPECT|ASSERT}_PRED4. Don't use // this in your code. #define GTEST_PRED4_(pred, v1, v2, v3, v4, on_failure)\ GTEST_ASSERT_(::testing::AssertPred4Helper(#pred, \ #v1, \ #v2, \ #v3, \ #v4, \ pred, \ v1, \ v2, \ v3, \ v4), on_failure) // 4-ary predicate assertion macros. #define EXPECT_PRED_FORMAT4(pred_format, v1, v2, v3, v4) \ GTEST_PRED_FORMAT4_(pred_format, v1, v2, v3, v4, GTEST_NONFATAL_FAILURE_) #define EXPECT_PRED4(pred, v1, v2, v3, v4) \ GTEST_PRED4_(pred, v1, v2, v3, v4, GTEST_NONFATAL_FAILURE_) #define ASSERT_PRED_FORMAT4(pred_format, v1, v2, v3, v4) \ GTEST_PRED_FORMAT4_(pred_format, v1, v2, v3, v4, GTEST_FATAL_FAILURE_) #define ASSERT_PRED4(pred, v1, v2, v3, v4) \ GTEST_PRED4_(pred, v1, v2, v3, v4, GTEST_FATAL_FAILURE_) // Helper function for implementing {EXPECT|ASSERT}_PRED5. Don't use // this in your code. 
template AssertionResult AssertPred5Helper(const char* pred_text, const char* e1, const char* e2, const char* e3, const char* e4, const char* e5, Pred pred, const T1& v1, const T2& v2, const T3& v3, const T4& v4, const T5& v5) { if (pred(v1, v2, v3, v4, v5)) return AssertionSuccess(); return AssertionFailure() << pred_text << "(" << e1 << ", " << e2 << ", " << e3 << ", " << e4 << ", " << e5 << ") evaluates to false, where" << "\n" << e1 << " evaluates to " << ::testing::PrintToString(v1) << "\n" << e2 << " evaluates to " << ::testing::PrintToString(v2) << "\n" << e3 << " evaluates to " << ::testing::PrintToString(v3) << "\n" << e4 << " evaluates to " << ::testing::PrintToString(v4) << "\n" << e5 << " evaluates to " << ::testing::PrintToString(v5); } // Internal macro for implementing {EXPECT|ASSERT}_PRED_FORMAT5. // Don't use this in your code. #define GTEST_PRED_FORMAT5_(pred_format, v1, v2, v3, v4, v5, on_failure)\ GTEST_ASSERT_(pred_format(#v1, #v2, #v3, #v4, #v5, v1, v2, v3, v4, v5), \ on_failure) // Internal macro for implementing {EXPECT|ASSERT}_PRED5. Don't use // this in your code. #define GTEST_PRED5_(pred, v1, v2, v3, v4, v5, on_failure)\ GTEST_ASSERT_(::testing::AssertPred5Helper(#pred, \ #v1, \ #v2, \ #v3, \ #v4, \ #v5, \ pred, \ v1, \ v2, \ v3, \ v4, \ v5), on_failure) // 5-ary predicate assertion macros. #define EXPECT_PRED_FORMAT5(pred_format, v1, v2, v3, v4, v5) \ GTEST_PRED_FORMAT5_(pred_format, v1, v2, v3, v4, v5, GTEST_NONFATAL_FAILURE_) #define EXPECT_PRED5(pred, v1, v2, v3, v4, v5) \ GTEST_PRED5_(pred, v1, v2, v3, v4, v5, GTEST_NONFATAL_FAILURE_) #define ASSERT_PRED_FORMAT5(pred_format, v1, v2, v3, v4, v5) \ GTEST_PRED_FORMAT5_(pred_format, v1, v2, v3, v4, v5, GTEST_FATAL_FAILURE_) #define ASSERT_PRED5(pred, v1, v2, v3, v4, v5) \ GTEST_PRED5_(pred, v1, v2, v3, v4, v5, GTEST_FATAL_FAILURE_) } // namespace testing #endif // GTEST_INCLUDE_GTEST_GTEST_PRED_IMPL_H_ namespace testing { // The abstract class that all tests inherit from. 
// // In Google Test, a unit test program contains one or many TestSuites, and // each TestSuite contains one or many Tests. // // When you define a test using the TEST macro, you don't need to // explicitly derive from Test - the TEST macro automatically does // this for you. // // The only time you derive from Test is when defining a test fixture // to be used in a TEST_F. For example: // // class FooTest : public testing::Test { // protected: // void SetUp() override { ... } // void TearDown() override { ... } // ... // }; // // TEST_F(FooTest, Bar) { ... } // TEST_F(FooTest, Baz) { ... } // // Test is not copyable. class GTEST_API_ Test { public: friend class TestInfo; // The d'tor is virtual as we intend to inherit from Test. virtual ~Test(); // Sets up the stuff shared by all tests in this test case. // // Google Test will call Foo::SetUpTestSuite() before running the first // test in test case Foo. Hence a sub-class can define its own // SetUpTestSuite() method to shadow the one defined in the super // class. // Failures that happen during SetUpTestSuite are logged but otherwise // ignored. static void SetUpTestSuite() {} // Tears down the stuff shared by all tests in this test suite. // // Google Test will call Foo::TearDownTestSuite() after running the last // test in test case Foo. Hence a sub-class can define its own // TearDownTestSuite() method to shadow the one defined in the super // class. // Failures that happen during TearDownTestSuite are logged but otherwise // ignored. static void TearDownTestSuite() {} // Legacy API is deprecated but still available #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ static void TearDownTestCase() {} static void SetUpTestCase() {} #endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ // Returns true if and only if the current test has a fatal failure. static bool HasFatalFailure(); // Returns true if and only if the current test has a non-fatal failure. 
static bool HasNonfatalFailure(); // Returns true if and only if the current test was skipped. static bool IsSkipped(); // Returns true if and only if the current test has a (either fatal or // non-fatal) failure. static bool HasFailure() { return HasFatalFailure() || HasNonfatalFailure(); } // Logs a property for the current test, test suite, or for the entire // invocation of the test program when used outside of the context of a // test suite. Only the last value for a given key is remembered. These // are public static so they can be called from utility functions that are // not members of the test fixture. Calls to RecordProperty made during // lifespan of the test (from the moment its constructor starts to the // moment its destructor finishes) will be output in XML as attributes of // the element. Properties recorded from fixture's // SetUpTestSuite or TearDownTestSuite are logged as attributes of the // corresponding element. Calls to RecordProperty made in the // global context (before or after invocation of RUN_ALL_TESTS and from // SetUp/TearDown method of Environment objects registered with Google // Test) will be output as attributes of the element. static void RecordProperty(const std::string& key, const std::string& value); static void RecordProperty(const std::string& key, int value); protected: // Creates a Test object. Test(); // Sets up the test fixture. virtual void SetUp(); // Tears down the test fixture. virtual void TearDown(); private: // Returns true if and only if the current test has the same fixture class // as the first test in the current test suite. static bool HasSameFixtureClass(); // Runs the test after the test fixture has been set up. // // A sub-class must implement this to define the test logic. // // DO NOT OVERRIDE THIS FUNCTION DIRECTLY IN A USER PROGRAM. // Instead, use the TEST or TEST_F macro. virtual void TestBody() = 0; // Sets up, executes, and tears down the test. void Run(); // Deletes self. 
We deliberately pick an unusual name for this // internal method to avoid clashing with names used in user TESTs. void DeleteSelf_() { delete this; } const std::unique_ptr gtest_flag_saver_; // Often a user misspells SetUp() as Setup() and spends a long time // wondering why it is never called by Google Test. The declaration of // the following method is solely for catching such an error at // compile time: // // - The return type is deliberately chosen to be not void, so it // will be a conflict if void Setup() is declared in the user's // test fixture. // // - This method is private, so it will be another compiler error // if the method is called from the user's test fixture. // // DO NOT OVERRIDE THIS FUNCTION. // // If you see an error about overriding the following function or // about it being private, you have mis-spelled SetUp() as Setup(). struct Setup_should_be_spelled_SetUp {}; virtual Setup_should_be_spelled_SetUp* Setup() { return nullptr; } // We disallow copying Tests. GTEST_DISALLOW_COPY_AND_ASSIGN_(Test); }; typedef internal::TimeInMillis TimeInMillis; // A copyable object representing a user specified test property which can be // output as a key/value string pair. // // Don't inherit from TestProperty as its destructor is not virtual. class TestProperty { public: // C'tor. TestProperty does NOT have a default constructor. // Always use this constructor (with parameters) to create a // TestProperty object. TestProperty(const std::string& a_key, const std::string& a_value) : key_(a_key), value_(a_value) { } // Gets the user supplied key. const char* key() const { return key_.c_str(); } // Gets the user supplied value. const char* value() const { return value_.c_str(); } // Sets a new value, overriding the one supplied in the constructor. void SetValue(const std::string& new_value) { value_ = new_value; } private: // The key supplied by the user. std::string key_; // The value supplied by the user. 
std::string value_; }; // The result of a single Test. This includes a list of // TestPartResults, a list of TestProperties, a count of how many // death tests there are in the Test, and how much time it took to run // the Test. // // TestResult is not copyable. class GTEST_API_ TestResult { public: // Creates an empty TestResult. TestResult(); // D'tor. Do not inherit from TestResult. ~TestResult(); // Gets the number of all test parts. This is the sum of the number // of successful test parts and the number of failed test parts. int total_part_count() const; // Returns the number of the test properties. int test_property_count() const; // Returns true if and only if the test passed (i.e. no test part failed). bool Passed() const { return !Skipped() && !Failed(); } // Returns true if and only if the test was skipped. bool Skipped() const; // Returns true if and only if the test failed. bool Failed() const; // Returns true if and only if the test fatally failed. bool HasFatalFailure() const; // Returns true if and only if the test has a non-fatal failure. bool HasNonfatalFailure() const; // Returns the elapsed time, in milliseconds. TimeInMillis elapsed_time() const { return elapsed_time_; } // Gets the time of the test case start, in ms from the start of the // UNIX epoch. TimeInMillis start_timestamp() const { return start_timestamp_; } // Returns the i-th test part result among all the results. i can range from 0 // to total_part_count() - 1. If i is not in that range, aborts the program. const TestPartResult& GetTestPartResult(int i) const; // Returns the i-th test property. i can range from 0 to // test_property_count() - 1. If i is not in that range, aborts the // program. 
const TestProperty& GetTestProperty(int i) const; private: friend class TestInfo; friend class TestSuite; friend class UnitTest; friend class internal::DefaultGlobalTestPartResultReporter; friend class internal::ExecDeathTest; friend class internal::TestResultAccessor; friend class internal::UnitTestImpl; friend class internal::WindowsDeathTest; friend class internal::FuchsiaDeathTest; // Gets the vector of TestPartResults. const std::vector& test_part_results() const { return test_part_results_; } // Gets the vector of TestProperties. const std::vector& test_properties() const { return test_properties_; } // Sets the start time. void set_start_timestamp(TimeInMillis start) { start_timestamp_ = start; } // Sets the elapsed time. void set_elapsed_time(TimeInMillis elapsed) { elapsed_time_ = elapsed; } // Adds a test property to the list. The property is validated and may add // a non-fatal failure if invalid (e.g., if it conflicts with reserved // key names). If a property is already recorded for the same key, the // value will be updated, rather than storing multiple values for the same // key. xml_element specifies the element for which the property is being // recorded and is used for validation. void RecordProperty(const std::string& xml_element, const TestProperty& test_property); // Adds a failure if the key is a reserved attribute of Google Test // testsuite tags. Returns true if the property is valid. // FIXME: Validate attribute names are legal and human readable. static bool ValidateTestProperty(const std::string& xml_element, const TestProperty& test_property); // Adds a test part result to the list. void AddTestPartResult(const TestPartResult& test_part_result); // Returns the death test count. int death_test_count() const { return death_test_count_; } // Increments the death test count, returning the new count. int increment_death_test_count() { return ++death_test_count_; } // Clears the test part results. 
void ClearTestPartResults(); // Clears the object. void Clear(); // Protects mutable state of the property vector and of owned // properties, whose values may be updated. internal::Mutex test_properites_mutex_; // The vector of TestPartResults std::vector test_part_results_; // The vector of TestProperties std::vector test_properties_; // Running count of death tests. int death_test_count_; // The start time, in milliseconds since UNIX Epoch. TimeInMillis start_timestamp_; // The elapsed time, in milliseconds. TimeInMillis elapsed_time_; // We disallow copying TestResult. GTEST_DISALLOW_COPY_AND_ASSIGN_(TestResult); }; // class TestResult // A TestInfo object stores the following information about a test: // // Test suite name // Test name // Whether the test should be run // A function pointer that creates the test object when invoked // Test result // // The constructor of TestInfo registers itself with the UnitTest // singleton such that the RUN_ALL_TESTS() macro knows which tests to // run. class GTEST_API_ TestInfo { public: // Destructs a TestInfo object. This function is not virtual, so // don't inherit from TestInfo. ~TestInfo(); // Returns the test suite name. const char* test_suite_name() const { return test_suite_name_.c_str(); } // Legacy API is deprecated but still available #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ const char* test_case_name() const { return test_suite_name(); } #endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ // Returns the test name. const char* name() const { return name_.c_str(); } // Returns the name of the parameter type, or NULL if this is not a typed // or a type-parameterized test. const char* type_param() const { if (type_param_.get() != nullptr) return type_param_->c_str(); return nullptr; } // Returns the text representation of the value parameter, or NULL if this // is not a value-parameterized test. 
const char* value_param() const { if (value_param_.get() != nullptr) return value_param_->c_str(); return nullptr; } // Returns the file name where this test is defined. const char* file() const { return location_.file.c_str(); } // Returns the line where this test is defined. int line() const { return location_.line; } // Return true if this test should not be run because it's in another shard. bool is_in_another_shard() const { return is_in_another_shard_; } // Returns true if this test should run, that is if the test is not // disabled (or it is disabled but the also_run_disabled_tests flag has // been specified) and its full name matches the user-specified filter. // // Google Test allows the user to filter the tests by their full names. // The full name of a test Bar in test suite Foo is defined as // "Foo.Bar". Only the tests that match the filter will run. // // A filter is a colon-separated list of glob (not regex) patterns, // optionally followed by a '-' and a colon-separated list of // negative patterns (tests to exclude). A test is run if it // matches one of the positive patterns and does not match any of // the negative patterns. // // For example, *A*:Foo.* is a filter that matches any string that // contains the character 'A' or starts with "Foo.". bool should_run() const { return should_run_; } // Returns true if and only if this test will appear in the XML report. bool is_reportable() const { // The XML report includes tests matching the filter, excluding those // run in other shards. return matches_filter_ && !is_in_another_shard_; } // Returns the result of the test. 
const TestResult* result() const { return &result_; } private: #if GTEST_HAS_DEATH_TEST friend class internal::DefaultDeathTestFactory; #endif // GTEST_HAS_DEATH_TEST friend class Test; friend class TestSuite; friend class internal::UnitTestImpl; friend class internal::StreamingListenerTest; friend TestInfo* internal::MakeAndRegisterTestInfo( const char* test_suite_name, const char* name, const char* type_param, const char* value_param, internal::CodeLocation code_location, internal::TypeId fixture_class_id, internal::SetUpTestSuiteFunc set_up_tc, internal::TearDownTestSuiteFunc tear_down_tc, internal::TestFactoryBase* factory); // Constructs a TestInfo object. The newly constructed instance assumes // ownership of the factory object. TestInfo(const std::string& test_suite_name, const std::string& name, const char* a_type_param, // NULL if not a type-parameterized test const char* a_value_param, // NULL if not a value-parameterized test internal::CodeLocation a_code_location, internal::TypeId fixture_class_id, internal::TestFactoryBase* factory); // Increments the number of death tests encountered in this test so // far. int increment_death_test_count() { return result_.increment_death_test_count(); } // Creates the test object, runs it, records its result, and then // deletes it. void Run(); static void ClearTestResult(TestInfo* test_info) { test_info->result_.Clear(); } // These fields are immutable properties of the test. const std::string test_suite_name_; // test suite name const std::string name_; // Test name // Name of the parameter type, or NULL if this is not a typed or a // type-parameterized test. const std::unique_ptr type_param_; // Text representation of the value parameter, or NULL if this is not a // value-parameterized test. 
const std::unique_ptr value_param_; internal::CodeLocation location_; const internal::TypeId fixture_class_id_; // ID of the test fixture class bool should_run_; // True if and only if this test should run bool is_disabled_; // True if and only if this test is disabled bool matches_filter_; // True if this test matches the // user-specified filter. bool is_in_another_shard_; // Will be run in another shard. internal::TestFactoryBase* const factory_; // The factory that creates // the test object // This field is mutable and needs to be reset before running the // test for the second time. TestResult result_; GTEST_DISALLOW_COPY_AND_ASSIGN_(TestInfo); }; // A test suite, which consists of a vector of TestInfos. // // TestSuite is not copyable. class GTEST_API_ TestSuite { public: // Creates a TestSuite with the given name. // // TestSuite does NOT have a default constructor. Always use this // constructor to create a TestSuite object. // // Arguments: // // name: name of the test suite // a_type_param: the name of the test's type parameter, or NULL if // this is not a type-parameterized test. // set_up_tc: pointer to the function that sets up the test suite // tear_down_tc: pointer to the function that tears down the test suite TestSuite(const char* name, const char* a_type_param, internal::SetUpTestSuiteFunc set_up_tc, internal::TearDownTestSuiteFunc tear_down_tc); // Destructor of TestSuite. virtual ~TestSuite(); // Gets the name of the TestSuite. const char* name() const { return name_.c_str(); } // Returns the name of the parameter type, or NULL if this is not a // type-parameterized test suite. const char* type_param() const { if (type_param_.get() != nullptr) return type_param_->c_str(); return nullptr; } // Returns true if any test in this test suite should run. bool should_run() const { return should_run_; } // Gets the number of successful tests in this test suite. int successful_test_count() const; // Gets the number of skipped tests in this test suite. 
int skipped_test_count() const; // Gets the number of failed tests in this test suite. int failed_test_count() const; // Gets the number of disabled tests that will be reported in the XML report. int reportable_disabled_test_count() const; // Gets the number of disabled tests in this test suite. int disabled_test_count() const; // Gets the number of tests to be printed in the XML report. int reportable_test_count() const; // Get the number of tests in this test suite that should run. int test_to_run_count() const; // Gets the number of all tests in this test suite. int total_test_count() const; // Returns true if and only if the test suite passed. bool Passed() const { return !Failed(); } // Returns true if and only if the test suite failed. bool Failed() const { return failed_test_count() > 0; } // Returns the elapsed time, in milliseconds. TimeInMillis elapsed_time() const { return elapsed_time_; } // Gets the time of the test suite start, in ms from the start of the // UNIX epoch. TimeInMillis start_timestamp() const { return start_timestamp_; } // Returns the i-th test among all the tests. i can range from 0 to // total_test_count() - 1. If i is not in that range, returns NULL. const TestInfo* GetTestInfo(int i) const; // Returns the TestResult that holds test properties recorded during // execution of SetUpTestSuite and TearDownTestSuite. const TestResult& ad_hoc_test_result() const { return ad_hoc_test_result_; } private: friend class Test; friend class internal::UnitTestImpl; // Gets the (mutable) vector of TestInfos in this TestSuite. std::vector& test_info_list() { return test_info_list_; } // Gets the (immutable) vector of TestInfos in this TestSuite. const std::vector& test_info_list() const { return test_info_list_; } // Returns the i-th test among all the tests. i can range from 0 to // total_test_count() - 1. If i is not in that range, returns NULL. TestInfo* GetMutableTestInfo(int i); // Sets the should_run member. 
void set_should_run(bool should) { should_run_ = should; } // Adds a TestInfo to this test suite. Will delete the TestInfo upon // destruction of the TestSuite object. void AddTestInfo(TestInfo * test_info); // Clears the results of all tests in this test suite. void ClearResult(); // Clears the results of all tests in the given test suite. static void ClearTestSuiteResult(TestSuite* test_suite) { test_suite->ClearResult(); } // Runs every test in this TestSuite. void Run(); // Runs SetUpTestSuite() for this TestSuite. This wrapper is needed // for catching exceptions thrown from SetUpTestSuite(). void RunSetUpTestSuite() { if (set_up_tc_ != nullptr) { (*set_up_tc_)(); } } // Runs TearDownTestSuite() for this TestSuite. This wrapper is // needed for catching exceptions thrown from TearDownTestSuite(). void RunTearDownTestSuite() { if (tear_down_tc_ != nullptr) { (*tear_down_tc_)(); } } // Returns true if and only if test passed. static bool TestPassed(const TestInfo* test_info) { return test_info->should_run() && test_info->result()->Passed(); } // Returns true if and only if test skipped. static bool TestSkipped(const TestInfo* test_info) { return test_info->should_run() && test_info->result()->Skipped(); } // Returns true if and only if test failed. static bool TestFailed(const TestInfo* test_info) { return test_info->should_run() && test_info->result()->Failed(); } // Returns true if and only if the test is disabled and will be reported in // the XML report. static bool TestReportableDisabled(const TestInfo* test_info) { return test_info->is_reportable() && test_info->is_disabled_; } // Returns true if and only if test is disabled. static bool TestDisabled(const TestInfo* test_info) { return test_info->is_disabled_; } // Returns true if and only if this test will appear in the XML report. static bool TestReportable(const TestInfo* test_info) { return test_info->is_reportable(); } // Returns true if the given test should run. 
static bool ShouldRunTest(const TestInfo* test_info) { return test_info->should_run(); } // Shuffles the tests in this test suite. void ShuffleTests(internal::Random* random); // Restores the test order to before the first shuffle. void UnshuffleTests(); // Name of the test suite. std::string name_; // Name of the parameter type, or NULL if this is not a typed or a // type-parameterized test. const std::unique_ptr type_param_; // The vector of TestInfos in their original order. It owns the // elements in the vector. std::vector test_info_list_; // Provides a level of indirection for the test list to allow easy // shuffling and restoring the test order. The i-th element in this // vector is the index of the i-th test in the shuffled test list. std::vector test_indices_; // Pointer to the function that sets up the test suite. internal::SetUpTestSuiteFunc set_up_tc_; // Pointer to the function that tears down the test suite. internal::TearDownTestSuiteFunc tear_down_tc_; // True if and only if any test in this test suite should run. bool should_run_; // The start time, in milliseconds since UNIX Epoch. TimeInMillis start_timestamp_; // Elapsed time, in milliseconds. TimeInMillis elapsed_time_; // Holds test properties recorded during execution of SetUpTestSuite and // TearDownTestSuite. TestResult ad_hoc_test_result_; // We disallow copying TestSuites. GTEST_DISALLOW_COPY_AND_ASSIGN_(TestSuite); }; // An Environment object is capable of setting up and tearing down an // environment. You should subclass this to define your own // environment(s). // // An Environment object does the set-up and tear-down in virtual // methods SetUp() and TearDown() instead of the constructor and the // destructor, as: // // 1. You cannot safely throw from a destructor. This is a problem // as in some cases Google Test is used where exceptions are enabled, and // we may want to implement ASSERT_* using exceptions where they are // available. // 2. 
//      You cannot use ASSERT_* directly in a constructor or
//      destructor.
class Environment {
 public:
  // The d'tor is virtual as we need to subclass Environment.
  virtual ~Environment() {}

  // Override this to define how to set up the environment.
  // The default implementation does nothing.
  virtual void SetUp() {}

  // Override this to define how to tear down the environment.
  // The default implementation does nothing.
  virtual void TearDown() {}

 private:
  // If you see an error about overriding the following function or
  // about it being private, you have mis-spelled SetUp() as Setup().
  struct Setup_should_be_spelled_SetUp {};
  virtual Setup_should_be_spelled_SetUp* Setup() { return nullptr; }
};

#if GTEST_HAS_EXCEPTIONS

// Exception which can be thrown from TestEventListener::OnTestPartResult.
class GTEST_API_ AssertionException
    : public internal::GoogleTestFailureException {
 public:
  // Forwards the given test part result to the GoogleTestFailureException
  // base-class constructor.
  explicit AssertionException(const TestPartResult& result)
      : GoogleTestFailureException(result) {}
};

#endif  // GTEST_HAS_EXCEPTIONS

// The interface for tracing execution of tests. The methods are organized in
// the order the corresponding events are fired.
//
// The suite-level (and legacy case-level) hooks have empty default bodies;
// every other hook is pure virtual and must be implemented by subclasses.
class TestEventListener {
 public:
  virtual ~TestEventListener() {}

  // Fired before any test activity starts.
  virtual void OnTestProgramStart(const UnitTest& unit_test) = 0;

  // Fired before each iteration of tests starts.  There may be more than
  // one iteration if GTEST_FLAG(repeat) is set. iteration is the iteration
  // index, starting from 0.
  virtual void OnTestIterationStart(const UnitTest& unit_test,
                                    int iteration) = 0;

  // Fired before environment set-up for each iteration of tests starts.
  virtual void OnEnvironmentsSetUpStart(const UnitTest& unit_test) = 0;

  // Fired after environment set-up for each iteration of tests ends.
  virtual void OnEnvironmentsSetUpEnd(const UnitTest& unit_test) = 0;

  // Fired before the test suite starts.
  virtual void OnTestSuiteStart(const TestSuite& /*test_suite*/) {}

  //  Legacy API is deprecated but still available
#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_
  virtual void OnTestCaseStart(const TestCase& /*test_case*/) {}
#endif  //  GTEST_REMOVE_LEGACY_TEST_CASEAPI_

  // Fired before the test starts.
  virtual void OnTestStart(const TestInfo& test_info) = 0;

  // Fired after a failed assertion or a SUCCEED() invocation.
  // If you want to throw an exception from this function to skip to the next
  // TEST, it must be AssertionException defined above, or inherited from it.
  virtual void OnTestPartResult(const TestPartResult& test_part_result) = 0;

  // Fired after the test ends.
  virtual void OnTestEnd(const TestInfo& test_info) = 0;

  // Fired after the test suite ends.
  virtual void OnTestSuiteEnd(const TestSuite& /*test_suite*/) {}

//  Legacy API is deprecated but still available
#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_
  virtual void OnTestCaseEnd(const TestCase& /*test_case*/) {}
#endif  //  GTEST_REMOVE_LEGACY_TEST_CASEAPI_

  // Fired before environment tear-down for each iteration of tests starts.
  virtual void OnEnvironmentsTearDownStart(const UnitTest& unit_test) = 0;

  // Fired after environment tear-down for each iteration of tests ends.
  virtual void OnEnvironmentsTearDownEnd(const UnitTest& unit_test) = 0;

  // Fired after each iteration of tests finishes.
  virtual void OnTestIterationEnd(const UnitTest& unit_test,
                                  int iteration) = 0;

  // Fired after all test activities have ended.
  virtual void OnTestProgramEnd(const UnitTest& unit_test) = 0;
};

// The convenience class for users who need to override just one or two
// methods and are not concerned that a possible change to a signature of
// the methods they override will not be caught during the build.  For
// comments about each method please see the definition of TestEventListener
// above.
class EmptyTestEventListener : public TestEventListener {
 public:
  // Every override below has an intentionally empty body.
  void OnTestProgramStart(const UnitTest& /*unit_test*/) override {}
  void OnTestIterationStart(const UnitTest& /*unit_test*/,
                            int /*iteration*/) override {}
  void OnEnvironmentsSetUpStart(const UnitTest& /*unit_test*/) override {}
  void OnEnvironmentsSetUpEnd(const UnitTest& /*unit_test*/) override {}
  void OnTestSuiteStart(const TestSuite& /*test_suite*/) override {}
//  Legacy API is deprecated but still available
#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_
  void OnTestCaseStart(const TestCase& /*test_case*/) override {}
#endif  //  GTEST_REMOVE_LEGACY_TEST_CASEAPI_

  void OnTestStart(const TestInfo& /*test_info*/) override {}
  void OnTestPartResult(const TestPartResult& /*test_part_result*/) override {}
  void OnTestEnd(const TestInfo& /*test_info*/) override {}
  void OnTestSuiteEnd(const TestSuite& /*test_suite*/) override {}
#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_
  void OnTestCaseEnd(const TestCase& /*test_case*/) override {}
#endif  //  GTEST_REMOVE_LEGACY_TEST_CASEAPI_

  void OnEnvironmentsTearDownStart(const UnitTest& /*unit_test*/) override {}
  void OnEnvironmentsTearDownEnd(const UnitTest& /*unit_test*/) override {}
  void OnTestIterationEnd(const UnitTest& /*unit_test*/,
                          int /*iteration*/) override {}
  void OnTestProgramEnd(const UnitTest& /*unit_test*/) override {}
};

// TestEventListeners lets users add listeners to track events in Google Test.
class GTEST_API_ TestEventListeners {
 public:
  TestEventListeners();
  ~TestEventListeners();

  // Appends an event listener to the end of the list. Google Test assumes
  // the ownership of the listener (i.e. it will delete the listener when
  // the test program finishes).
  void Append(TestEventListener* listener);

  // Removes the given event listener from the list and returns it.  It then
  // becomes the caller's responsibility to delete the listener. Returns
  // NULL if the listener is not found in the list.
  TestEventListener* Release(TestEventListener* listener);

  // Returns the standard listener responsible for the default console
  // output.  Can be removed from the listeners list to shut down default
  // console output.  Note that removing this object from the listener list
  // with Release transfers its ownership to the caller and makes this
  // function return NULL the next time.
  TestEventListener* default_result_printer() const {
    return default_result_printer_;
  }

  // Returns the standard listener responsible for the default XML output
  // controlled by the --gtest_output=xml flag.  Can be removed from the
  // listeners list by users who want to shut down the default XML output
  // controlled by this flag and substitute it with custom one.  Note that
  // removing this object from the listener list with Release transfers its
  // ownership to the caller and makes this function return NULL the next
  // time.
  TestEventListener* default_xml_generator() const {
    return default_xml_generator_;
  }

 private:
  friend class TestSuite;
  friend class TestInfo;
  friend class internal::DefaultGlobalTestPartResultReporter;
  friend class internal::NoExecDeathTest;
  friend class internal::TestEventListenersAccessor;
  friend class internal::UnitTestImpl;

  // Returns repeater that broadcasts the TestEventListener events to all
  // subscribers.
  TestEventListener* repeater();

  // Sets the default_result_printer attribute to the provided listener.
  // The listener is also added to the listener list and previous
  // default_result_printer is removed from it and deleted. The listener can
  // also be NULL in which case it will not be added to the list. Does
  // nothing if the previous and the current listener objects are the same.
  void SetDefaultResultPrinter(TestEventListener* listener);

  // Sets the default_xml_generator attribute to the provided listener.  The
  // listener is also added to the listener list and previous
  // default_xml_generator is removed from it and deleted. The listener can
  // also be NULL in which case it will not be added to the list. Does
  // nothing if the previous and the current listener objects are the same.
  void SetDefaultXmlGenerator(TestEventListener* listener);

  // Controls whether events will be forwarded by the repeater to the
  // listeners in the list.
  bool EventForwardingEnabled() const;
  void SuppressEventForwarding();

  // The actual list of listeners.
  internal::TestEventRepeater* repeater_;
  // Listener responsible for the standard result output.
  TestEventListener* default_result_printer_;
  // Listener responsible for the creation of the XML output file.
  TestEventListener* default_xml_generator_;

  // We disallow copying TestEventListeners.
  GTEST_DISALLOW_COPY_AND_ASSIGN_(TestEventListeners);
};

// A UnitTest consists of a vector of TestSuites.
//
// This is a singleton class.  The only instance of UnitTest is
// created when UnitTest::GetInstance() is first called.  This
// instance is never deleted.
//
// UnitTest is not copyable.
//
// This class is thread-safe as long as the methods are called
// according to their specification.
class GTEST_API_ UnitTest {
 public:
  // Gets the singleton UnitTest object.  The first time this method
  // is called, a UnitTest object is constructed and returned.
  // Consecutive calls will return the same object.
  static UnitTest* GetInstance();

  // Runs all tests in this UnitTest object and prints the result.
  // Returns 0 if successful, or 1 otherwise.
  //
  // This method can only be called from the main thread.
  //
  // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM.
  int Run() GTEST_MUST_USE_RESULT_;

  // Returns the working directory when the first TEST() or TEST_F()
  // was executed.  The UnitTest object owns the string.
  const char* original_working_dir() const;

  // Returns the TestSuite object for the test that's currently running,
  // or NULL if no test is running.
  const TestSuite* current_test_suite() const GTEST_LOCK_EXCLUDED_(mutex_);

// Legacy API is still available but deprecated
#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_
  const TestCase* current_test_case() const GTEST_LOCK_EXCLUDED_(mutex_);
#endif

  // Returns the TestInfo object for the test that's currently running,
  // or NULL if no test is running.
  const TestInfo* current_test_info() const
      GTEST_LOCK_EXCLUDED_(mutex_);

  // Returns the random seed used at the start of the current test run.
  int random_seed() const;

  // Returns the ParameterizedTestSuiteRegistry object used to keep track of
  // value-parameterized tests and instantiate and register them.
  //
  // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM.
  internal::ParameterizedTestSuiteRegistry& parameterized_test_registry()
      GTEST_LOCK_EXCLUDED_(mutex_);

  // Gets the number of successful test suites.
  int successful_test_suite_count() const;

  // Gets the number of failed test suites.
  int failed_test_suite_count() const;

  // Gets the number of all test suites.
  int total_test_suite_count() const;

  // Gets the number of all test suites that contain at least one test
  // that should run.
  int test_suite_to_run_count() const;

  //  Legacy API is deprecated but still available
#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_
  int successful_test_case_count() const;
  int failed_test_case_count() const;
  int total_test_case_count() const;
  int test_case_to_run_count() const;
#endif  //  GTEST_REMOVE_LEGACY_TEST_CASEAPI_

  // Gets the number of successful tests.
  int successful_test_count() const;

  // Gets the number of skipped tests.
  int skipped_test_count() const;

  // Gets the number of failed tests.
  int failed_test_count() const;

  // Gets the number of disabled tests that will be reported in the XML report.
  int reportable_disabled_test_count() const;

  // Gets the number of disabled tests.
  int disabled_test_count() const;

  // Gets the number of tests to be printed in the XML report.
  int reportable_test_count() const;

  // Gets the number of all tests.
  int total_test_count() const;

  // Gets the number of tests that should run.
  int test_to_run_count() const;

  // Gets the time of the test program start, in ms from the start of the
  // UNIX epoch.
  TimeInMillis start_timestamp() const;

  // Gets the elapsed time, in milliseconds.
  TimeInMillis elapsed_time() const;

  // Returns true if and only if the unit test passed (i.e. all test suites
  // passed).
  bool Passed() const;

  // Returns true if and only if the unit test failed (i.e. some test suite
  // failed or something outside of all tests failed).
  bool Failed() const;

  // Gets the i-th test suite among all the test suites. i can range from 0 to
  // total_test_suite_count() - 1. If i is not in that range, returns NULL.
  const TestSuite* GetTestSuite(int i) const;

//  Legacy API is deprecated but still available
#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_
  const TestCase* GetTestCase(int i) const;
#endif  //  GTEST_REMOVE_LEGACY_TEST_CASEAPI_

  // Returns the TestResult containing information on test failures and
  // properties logged outside of individual test suites.
  const TestResult& ad_hoc_test_result() const;

  // Returns the list of event listeners that can be used to track events
  // inside Google Test.
  TestEventListeners& listeners();

 private:
  // Registers and returns a global test environment.  When a test
  // program is run, all global test environments will be set-up in
  // the order they were registered.  After all tests in the program
  // have finished, all global test environments will be torn-down in
  // the *reverse* order they were registered.
  //
  // The UnitTest object takes ownership of the given environment.
  //
  // This method can only be called from the main thread.
  Environment* AddEnvironment(Environment* env);

  // Adds a TestPartResult to the current TestResult object.  All
  // Google Test assertion macros (e.g. ASSERT_TRUE, EXPECT_EQ, etc)
  // eventually call this to report their results.  The user code
  // should use the assertion macros instead of calling this directly.
  void AddTestPartResult(TestPartResult::Type result_type,
                         const char* file_name,
                         int line_number,
                         const std::string& message,
                         const std::string& os_stack_trace)
      GTEST_LOCK_EXCLUDED_(mutex_);

  // Adds a TestProperty to the current TestResult object when invoked from
  // inside a test, to current TestSuite's ad_hoc_test_result_ when invoked
  // from SetUpTestSuite or TearDownTestSuite, or to the global property set
  // when invoked elsewhere.  If the result already contains a property with
  // the same key, the value will be updated.
  void RecordProperty(const std::string& key, const std::string& value);

  // Gets the i-th test suite among all the test suites. i can range from 0 to
  // total_test_suite_count() - 1. If i is not in that range, returns NULL.
  TestSuite* GetMutableTestSuite(int i);

  // Accessors for the implementation object.
  internal::UnitTestImpl* impl() { return impl_; }
  const internal::UnitTestImpl* impl() const { return impl_; }

  // These classes and functions are friends as they need to access private
  // members of UnitTest.
  friend class ScopedTrace;
  friend class Test;
  friend class internal::AssertHelper;
  friend class internal::StreamingListenerTest;
  friend class internal::UnitTestRecordPropertyTestHelper;
  friend Environment* AddGlobalTestEnvironment(Environment* env);
  friend internal::UnitTestImpl* internal::GetUnitTestImpl();
  friend void internal::ReportFailureInUnknownLocation(
      TestPartResult::Type result_type,
      const std::string& message);

  // Creates an empty UnitTest.
  UnitTest();

  // D'tor
  virtual ~UnitTest();

  // Pushes a trace defined by SCOPED_TRACE() on to the per-thread
  // Google Test trace stack.
  void PushGTestTrace(const internal::TraceInfo& trace)
      GTEST_LOCK_EXCLUDED_(mutex_);

  // Pops a trace from the per-thread Google Test trace stack.
  void PopGTestTrace()
      GTEST_LOCK_EXCLUDED_(mutex_);

  // Protects mutable state in *impl_.  This is mutable as some const
  // methods need to lock it too.
  mutable internal::Mutex mutex_;

  // Opaque implementation object.  This field is never changed once
  // the object is constructed.  We don't mark it as const here, as
  // doing so will cause a warning in the constructor of UnitTest.
  // Mutable state in *impl_ is protected by mutex_.
  internal::UnitTestImpl* impl_;

  // We disallow copying UnitTest.
  GTEST_DISALLOW_COPY_AND_ASSIGN_(UnitTest);
};

// A convenient wrapper for adding an environment for the test
// program.
//
// You should call this before RUN_ALL_TESTS() is called, probably in
// main().  If you use gtest_main, you need to call this before main()
// starts for it to take effect.  For example, you can define a global
// variable like this:
//
//   testing::Environment* const foo_env =
//       testing::AddGlobalTestEnvironment(new FooEnvironment);
//
// However, we strongly recommend that you write your own main() and
// call AddGlobalTestEnvironment() there, as relying on initialization
// of global variables makes the code harder to read and may cause
// problems when you register multiple environments from different
// translation units and the environments have dependencies among them
// (remember that the compiler doesn't guarantee the order in which
// global variables from different translation units are initialized).
inline Environment* AddGlobalTestEnvironment(Environment* env) {
  return UnitTest::GetInstance()->AddEnvironment(env);
}

// Initializes Google Test.  This must be called before calling
// RUN_ALL_TESTS().  In particular, it parses a command line for the
// flags that Google Test recognizes.  Whenever a Google Test flag is
// seen, it is removed from argv, and *argc is decremented.
//
// No value is returned.  Instead, the Google Test flag variables are
// updated.
//
// Calling the function for the second time has no user-visible effect.
GTEST_API_ void InitGoogleTest(int* argc, char** argv);

// This overloaded version can be used in Windows programs compiled in
// UNICODE mode.
GTEST_API_ void InitGoogleTest(int* argc, wchar_t** argv); // This overloaded version can be used on Arduino/embedded platforms where // there is no argc/argv. GTEST_API_ void InitGoogleTest(); namespace internal { // Separate the error generating code from the code path to reduce the stack // frame size of CmpHelperEQ. This helps reduce the overhead of some sanitizers // when calling EXPECT_* in a tight loop. template AssertionResult CmpHelperEQFailure(const char* lhs_expression, const char* rhs_expression, const T1& lhs, const T2& rhs) { return EqFailure(lhs_expression, rhs_expression, FormatForComparisonFailureMessage(lhs, rhs), FormatForComparisonFailureMessage(rhs, lhs), false); } // This block of code defines operator==/!= // to block lexical scope lookup. // It prevents using invalid operator==/!= defined at namespace scope. struct faketype {}; inline bool operator==(faketype, faketype) { return true; } inline bool operator!=(faketype, faketype) { return false; } // The helper function for {ASSERT|EXPECT}_EQ. template AssertionResult CmpHelperEQ(const char* lhs_expression, const char* rhs_expression, const T1& lhs, const T2& rhs) { if (lhs == rhs) { return AssertionSuccess(); } return CmpHelperEQFailure(lhs_expression, rhs_expression, lhs, rhs); } // With this overloaded version, we allow anonymous enums to be used // in {ASSERT|EXPECT}_EQ when compiled with gcc 4, as anonymous enums // can be implicitly cast to BiggestInt. GTEST_API_ AssertionResult CmpHelperEQ(const char* lhs_expression, const char* rhs_expression, BiggestInt lhs, BiggestInt rhs); class EqHelper { public: // This templatized version is for the general case. template < typename T1, typename T2, // Disable this overload for cases where one argument is a pointer // and the other is the null pointer constant. 
typename std::enable_if::value || !std::is_pointer::value>::type* = nullptr> static AssertionResult Compare(const char* lhs_expression, const char* rhs_expression, const T1& lhs, const T2& rhs) { return CmpHelperEQ(lhs_expression, rhs_expression, lhs, rhs); } // With this overloaded version, we allow anonymous enums to be used // in {ASSERT|EXPECT}_EQ when compiled with gcc 4, as anonymous // enums can be implicitly cast to BiggestInt. // // Even though its body looks the same as the above version, we // cannot merge the two, as it will make anonymous enums unhappy. static AssertionResult Compare(const char* lhs_expression, const char* rhs_expression, BiggestInt lhs, BiggestInt rhs) { return CmpHelperEQ(lhs_expression, rhs_expression, lhs, rhs); } template static AssertionResult Compare( const char* lhs_expression, const char* rhs_expression, // Handle cases where '0' is used as a null pointer literal. std::nullptr_t /* lhs */, T* rhs) { // We already know that 'lhs' is a null pointer. return CmpHelperEQ(lhs_expression, rhs_expression, static_cast(nullptr), rhs); } }; // Separate the error generating code from the code path to reduce the stack // frame size of CmpHelperOP. This helps reduce the overhead of some sanitizers // when calling EXPECT_OP in a tight loop. template AssertionResult CmpHelperOpFailure(const char* expr1, const char* expr2, const T1& val1, const T2& val2, const char* op) { return AssertionFailure() << "Expected: (" << expr1 << ") " << op << " (" << expr2 << "), actual: " << FormatForComparisonFailureMessage(val1, val2) << " vs " << FormatForComparisonFailureMessage(val2, val1); } // A macro for implementing the helper functions needed to implement // ASSERT_?? and EXPECT_??. It is here just to avoid copy-and-paste // of similar code. // // For each templatized helper function, we also define an overloaded // version for BiggestInt in order to reduce code bloat and allow // anonymous enums to be used with {ASSERT|EXPECT}_?? 
when compiled // with gcc 4. // // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. #define GTEST_IMPL_CMP_HELPER_(op_name, op)\ template \ AssertionResult CmpHelper##op_name(const char* expr1, const char* expr2, \ const T1& val1, const T2& val2) {\ if (val1 op val2) {\ return AssertionSuccess();\ } else {\ return CmpHelperOpFailure(expr1, expr2, val1, val2, #op);\ }\ }\ GTEST_API_ AssertionResult CmpHelper##op_name(\ const char* expr1, const char* expr2, BiggestInt val1, BiggestInt val2) // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. // Implements the helper function for {ASSERT|EXPECT}_NE GTEST_IMPL_CMP_HELPER_(NE, !=); // Implements the helper function for {ASSERT|EXPECT}_LE GTEST_IMPL_CMP_HELPER_(LE, <=); // Implements the helper function for {ASSERT|EXPECT}_LT GTEST_IMPL_CMP_HELPER_(LT, <); // Implements the helper function for {ASSERT|EXPECT}_GE GTEST_IMPL_CMP_HELPER_(GE, >=); // Implements the helper function for {ASSERT|EXPECT}_GT GTEST_IMPL_CMP_HELPER_(GT, >); #undef GTEST_IMPL_CMP_HELPER_ // The helper function for {ASSERT|EXPECT}_STREQ. // // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. GTEST_API_ AssertionResult CmpHelperSTREQ(const char* s1_expression, const char* s2_expression, const char* s1, const char* s2); // The helper function for {ASSERT|EXPECT}_STRCASEEQ. // // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. GTEST_API_ AssertionResult CmpHelperSTRCASEEQ(const char* s1_expression, const char* s2_expression, const char* s1, const char* s2); // The helper function for {ASSERT|EXPECT}_STRNE. // // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. GTEST_API_ AssertionResult CmpHelperSTRNE(const char* s1_expression, const char* s2_expression, const char* s1, const char* s2); // The helper function for {ASSERT|EXPECT}_STRCASENE. // // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. 
GTEST_API_ AssertionResult CmpHelperSTRCASENE(const char* s1_expression, const char* s2_expression, const char* s1, const char* s2); // Helper function for *_STREQ on wide strings. // // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. GTEST_API_ AssertionResult CmpHelperSTREQ(const char* s1_expression, const char* s2_expression, const wchar_t* s1, const wchar_t* s2); // Helper function for *_STRNE on wide strings. // // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. GTEST_API_ AssertionResult CmpHelperSTRNE(const char* s1_expression, const char* s2_expression, const wchar_t* s1, const wchar_t* s2); } // namespace internal // IsSubstring() and IsNotSubstring() are intended to be used as the // first argument to {EXPECT,ASSERT}_PRED_FORMAT2(), not by // themselves. They check whether needle is a substring of haystack // (NULL is considered a substring of itself only), and return an // appropriate error message when they fail. // // The {needle,haystack}_expr arguments are the stringified // expressions that generated the two real arguments. 
GTEST_API_ AssertionResult IsSubstring( const char* needle_expr, const char* haystack_expr, const char* needle, const char* haystack); GTEST_API_ AssertionResult IsSubstring( const char* needle_expr, const char* haystack_expr, const wchar_t* needle, const wchar_t* haystack); GTEST_API_ AssertionResult IsNotSubstring( const char* needle_expr, const char* haystack_expr, const char* needle, const char* haystack); GTEST_API_ AssertionResult IsNotSubstring( const char* needle_expr, const char* haystack_expr, const wchar_t* needle, const wchar_t* haystack); GTEST_API_ AssertionResult IsSubstring( const char* needle_expr, const char* haystack_expr, const ::std::string& needle, const ::std::string& haystack); GTEST_API_ AssertionResult IsNotSubstring( const char* needle_expr, const char* haystack_expr, const ::std::string& needle, const ::std::string& haystack); #if GTEST_HAS_STD_WSTRING GTEST_API_ AssertionResult IsSubstring( const char* needle_expr, const char* haystack_expr, const ::std::wstring& needle, const ::std::wstring& haystack); GTEST_API_ AssertionResult IsNotSubstring( const char* needle_expr, const char* haystack_expr, const ::std::wstring& needle, const ::std::wstring& haystack); #endif // GTEST_HAS_STD_WSTRING namespace internal { // Helper template function for comparing floating-points. // // Template parameter: // // RawType: the raw floating-point type (either float or double) // // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. 
template AssertionResult CmpHelperFloatingPointEQ(const char* lhs_expression, const char* rhs_expression, RawType lhs_value, RawType rhs_value) { const FloatingPoint lhs(lhs_value), rhs(rhs_value); if (lhs.AlmostEquals(rhs)) { return AssertionSuccess(); } ::std::stringstream lhs_ss; lhs_ss << std::setprecision(std::numeric_limits::digits10 + 2) << lhs_value; ::std::stringstream rhs_ss; rhs_ss << std::setprecision(std::numeric_limits::digits10 + 2) << rhs_value; return EqFailure(lhs_expression, rhs_expression, StringStreamToString(&lhs_ss), StringStreamToString(&rhs_ss), false); } // Helper function for implementing ASSERT_NEAR. // // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. GTEST_API_ AssertionResult DoubleNearPredFormat(const char* expr1, const char* expr2, const char* abs_error_expr, double val1, double val2, double abs_error); // INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. // A class that enables one to stream messages to assertion macros class GTEST_API_ AssertHelper { public: // Constructor. AssertHelper(TestPartResult::Type type, const char* file, int line, const char* message); ~AssertHelper(); // Message assignment is a semantic trick to enable assertion // streaming; see the GTEST_MESSAGE_ macro below. void operator=(const Message& message) const; private: // We put our data in a struct so that the size of the AssertHelper class can // be as small as possible. This is important because gcc is incapable of // re-using stack space even for temporary variables, so every EXPECT_EQ // reserves stack space for another AssertHelper. 
struct AssertHelperData { AssertHelperData(TestPartResult::Type t, const char* srcfile, int line_num, const char* msg) : type(t), file(srcfile), line(line_num), message(msg) { } TestPartResult::Type const type; const char* const file; int const line; std::string const message; private: GTEST_DISALLOW_COPY_AND_ASSIGN_(AssertHelperData); }; AssertHelperData* const data_; GTEST_DISALLOW_COPY_AND_ASSIGN_(AssertHelper); }; enum GTestColor { COLOR_DEFAULT, COLOR_RED, COLOR_GREEN, COLOR_YELLOW }; GTEST_API_ GTEST_ATTRIBUTE_PRINTF_(2, 3) void ColoredPrintf(GTestColor color, const char* fmt, ...); } // namespace internal // The pure interface class that all value-parameterized tests inherit from. // A value-parameterized class must inherit from both ::testing::Test and // ::testing::WithParamInterface. In most cases that just means inheriting // from ::testing::TestWithParam, but more complicated test hierarchies // may need to inherit from Test and WithParamInterface at different levels. // // This interface has support for accessing the test parameter value via // the GetParam() method. // // Use it with one of the parameter generator defining functions, like Range(), // Values(), ValuesIn(), Bool(), and Combine(). // // class FooTest : public ::testing::TestWithParam { // protected: // FooTest() { // // Can use GetParam() here. // } // ~FooTest() override { // // Can use GetParam() here. // } // void SetUp() override { // // Can use GetParam() here. // } // void TearDown override { // // Can use GetParam() here. // } // }; // TEST_P(FooTest, DoesBar) { // // Can use GetParam() method here. // Foo foo; // ASSERT_TRUE(foo.DoesBar(GetParam())); // } // INSTANTIATE_TEST_SUITE_P(OneToTenRange, FooTest, ::testing::Range(1, 10)); template class WithParamInterface { public: typedef T ParamType; virtual ~WithParamInterface() {} // The current parameter value. Is also available in the test fixture's // constructor. 
static const ParamType& GetParam() { GTEST_CHECK_(parameter_ != nullptr) << "GetParam() can only be called inside a value-parameterized test " << "-- did you intend to write TEST_P instead of TEST_F?"; return *parameter_; } private: // Sets parameter value. The caller is responsible for making sure the value // remains alive and unchanged throughout the current test. static void SetParam(const ParamType* parameter) { parameter_ = parameter; } // Static value used for accessing parameter during a test lifetime. static const ParamType* parameter_; // TestClass must be a subclass of WithParamInterface and Test. template friend class internal::ParameterizedTestFactory; }; template const T* WithParamInterface::parameter_ = nullptr; // Most value-parameterized classes can ignore the existence of // WithParamInterface, and can just inherit from ::testing::TestWithParam. template class TestWithParam : public Test, public WithParamInterface { }; // Macros for indicating success/failure in test code. // Skips test in runtime. // Skipping test aborts current function. // Skipped tests are neither successful nor failed. #define GTEST_SKIP() GTEST_SKIP_("Skipped") // ADD_FAILURE unconditionally adds a failure to the current test. // SUCCEED generates a success - it doesn't automatically make the // current test successful, as a test is only successful when it has // no failure. // // EXPECT_* verifies that a certain condition is satisfied. If not, // it behaves like ADD_FAILURE. In particular: // // EXPECT_TRUE verifies that a Boolean condition is true. // EXPECT_FALSE verifies that a Boolean condition is false. // // FAIL and ASSERT_* are similar to ADD_FAILURE and EXPECT_*, except // that they will also abort the current function on failure. People // usually want the fail-fast behavior of FAIL and ASSERT_*, but those // writing data-driven tests often find themselves using ADD_FAILURE // and EXPECT_* more. // Generates a nonfatal failure with a generic message. 
#define ADD_FAILURE() GTEST_NONFATAL_FAILURE_("Failed")

// Generates a nonfatal failure at the given source file location with
// a generic message.
#define ADD_FAILURE_AT(file, line) \
  GTEST_MESSAGE_AT_(file, line, "Failed", \
                    ::testing::TestPartResult::kNonFatalFailure)

// Generates a fatal failure with a generic message.
#define GTEST_FAIL() GTEST_FATAL_FAILURE_("Failed")

// Like GTEST_FAIL(), but at the given source file location.
#define GTEST_FAIL_AT(file, line)         \
  GTEST_MESSAGE_AT_(file, line, "Failed", \
                    ::testing::TestPartResult::kFatalFailure)

// Define this macro to 1 to omit the definition of FAIL(), which is a
// generic name and clashes with some other libraries.
#if !GTEST_DONT_DEFINE_FAIL
# define FAIL() GTEST_FAIL()
#endif

// Generates a success with a generic message.
#define GTEST_SUCCEED() GTEST_SUCCESS_("Succeeded")

// Define this macro to 1 to omit the definition of SUCCEED(), which
// is a generic name and clashes with some other libraries.
#if !GTEST_DONT_DEFINE_SUCCEED
# define SUCCEED() GTEST_SUCCEED()
#endif

// Macros for testing exceptions.
//
//    * {ASSERT|EXPECT}_THROW(statement, expected_exception):
//         Tests that the statement throws the expected exception.
//    * {ASSERT|EXPECT}_NO_THROW(statement):
//         Tests that the statement doesn't throw any exception.
//    * {ASSERT|EXPECT}_ANY_THROW(statement):
//         Tests that the statement throws an exception.
#define EXPECT_THROW(statement, expected_exception) \
  GTEST_TEST_THROW_(statement, expected_exception, GTEST_NONFATAL_FAILURE_)
#define EXPECT_NO_THROW(statement) \
  GTEST_TEST_NO_THROW_(statement, GTEST_NONFATAL_FAILURE_)
#define EXPECT_ANY_THROW(statement) \
  GTEST_TEST_ANY_THROW_(statement, GTEST_NONFATAL_FAILURE_)
#define ASSERT_THROW(statement, expected_exception) \
  GTEST_TEST_THROW_(statement, expected_exception, GTEST_FATAL_FAILURE_)
#define ASSERT_NO_THROW(statement) \
  GTEST_TEST_NO_THROW_(statement, GTEST_FATAL_FAILURE_)
#define ASSERT_ANY_THROW(statement) \
  GTEST_TEST_ANY_THROW_(statement, GTEST_FATAL_FAILURE_)

// Boolean assertions. Condition can be either a Boolean expression or an
// AssertionResult. For more information on how to use AssertionResult with
// these macros see comments on that class.
#define EXPECT_TRUE(condition) \
  GTEST_TEST_BOOLEAN_(condition, #condition, false, true, \
                      GTEST_NONFATAL_FAILURE_)
#define EXPECT_FALSE(condition) \
  GTEST_TEST_BOOLEAN_(!(condition), #condition, true, false, \
                      GTEST_NONFATAL_FAILURE_)
#define ASSERT_TRUE(condition) \
  GTEST_TEST_BOOLEAN_(condition, #condition, false, true, \
                      GTEST_FATAL_FAILURE_)
#define ASSERT_FALSE(condition) \
  GTEST_TEST_BOOLEAN_(!(condition), #condition, true, false, \
                      GTEST_FATAL_FAILURE_)

// Macros for testing equalities and inequalities.
//
//    * {ASSERT|EXPECT}_EQ(v1, v2): Tests that v1 == v2
//    * {ASSERT|EXPECT}_NE(v1, v2): Tests that v1 != v2
//    * {ASSERT|EXPECT}_LT(v1, v2): Tests that v1 < v2
//    * {ASSERT|EXPECT}_LE(v1, v2): Tests that v1 <= v2
//    * {ASSERT|EXPECT}_GT(v1, v2): Tests that v1 > v2
//    * {ASSERT|EXPECT}_GE(v1, v2): Tests that v1 >= v2
//
// When they are not, Google Test prints both the tested expressions and
// their actual values.  The values must be compatible built-in types,
// or you will get a compiler error.  By "compatible" we mean that the
// values can be compared by the respective operator.
//
// Note:
//
//   1.
// It is possible to make a user-defined type work with
//   {ASSERT|EXPECT}_??(), but that requires overloading the
//   comparison operators and is thus discouraged by the Google C++
//   Usage Guide.  Therefore, you are advised to use the
//   {ASSERT|EXPECT}_TRUE() macro to assert that two objects are
//   equal.
//
//   2. The {ASSERT|EXPECT}_??() macros do pointer comparisons on
//   pointers (in particular, C strings).  Therefore, if you use it
//   with two C strings, you are testing how their locations in memory
//   are related, not how their content is related.  To compare two C
//   strings by content, use {ASSERT|EXPECT}_STR*().
//
//   3. {ASSERT|EXPECT}_EQ(v1, v2) is preferred to
//   {ASSERT|EXPECT}_TRUE(v1 == v2), as the former tells you
//   what the actual value is when it fails, and similarly for the
//   other comparisons.
//
//   4. Do not depend on the order in which {ASSERT|EXPECT}_??()
//   evaluate their arguments, which is undefined.
//
//   5. These macros evaluate their arguments exactly once.
//
// Examples:
//
//   EXPECT_NE(Foo(), 5);
//   EXPECT_EQ(a_pointer, NULL);
//   ASSERT_LT(i, array_size);
//   ASSERT_GT(records.size(), 0) << "There is no record left.";

#define EXPECT_EQ(val1, val2) \
  EXPECT_PRED_FORMAT2(::testing::internal::EqHelper::Compare, val1, val2)
#define EXPECT_NE(val1, val2) \
  EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperNE, val1, val2)
#define EXPECT_LE(val1, val2) \
  EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperLE, val1, val2)
#define EXPECT_LT(val1, val2) \
  EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperLT, val1, val2)
#define EXPECT_GE(val1, val2) \
  EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperGE, val1, val2)
#define EXPECT_GT(val1, val2) \
  EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperGT, val1, val2)

#define GTEST_ASSERT_EQ(val1, val2) \
  ASSERT_PRED_FORMAT2(::testing::internal::EqHelper::Compare, val1, val2)
#define GTEST_ASSERT_NE(val1, val2) \
  ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperNE, val1, val2)
#define GTEST_ASSERT_LE(val1, val2) \
  ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperLE, val1, val2)
#define GTEST_ASSERT_LT(val1, val2) \
  ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperLT, val1, val2)
#define GTEST_ASSERT_GE(val1, val2) \
  ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperGE, val1, val2)
#define GTEST_ASSERT_GT(val1, val2) \
  ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperGT, val1, val2)

// Define macro GTEST_DONT_DEFINE_ASSERT_XY to 1 to omit the definition of
// ASSERT_XY(), which clashes with some users' own code.

#if !GTEST_DONT_DEFINE_ASSERT_EQ
# define ASSERT_EQ(val1, val2) GTEST_ASSERT_EQ(val1, val2)
#endif

#if !GTEST_DONT_DEFINE_ASSERT_NE
# define ASSERT_NE(val1, val2) GTEST_ASSERT_NE(val1, val2)
#endif

#if !GTEST_DONT_DEFINE_ASSERT_LE
# define ASSERT_LE(val1, val2) GTEST_ASSERT_LE(val1, val2)
#endif

#if !GTEST_DONT_DEFINE_ASSERT_LT
# define ASSERT_LT(val1, val2) GTEST_ASSERT_LT(val1, val2)
#endif

#if !GTEST_DONT_DEFINE_ASSERT_GE
# define ASSERT_GE(val1, val2) GTEST_ASSERT_GE(val1, val2)
#endif

#if !GTEST_DONT_DEFINE_ASSERT_GT
# define ASSERT_GT(val1, val2) GTEST_ASSERT_GT(val1, val2)
#endif

// C-string Comparisons.  All tests treat NULL and any non-NULL string
// as different.  Two NULLs are equal.
//
//    * {ASSERT|EXPECT}_STREQ(s1, s2):     Tests that s1 == s2
//    * {ASSERT|EXPECT}_STRNE(s1, s2):     Tests that s1 != s2
//    * {ASSERT|EXPECT}_STRCASEEQ(s1, s2): Tests that s1 == s2, ignoring case
//    * {ASSERT|EXPECT}_STRCASENE(s1, s2): Tests that s1 != s2, ignoring case
//
// For wide or narrow string objects, you can use the
// {ASSERT|EXPECT}_??() macros.
//
// Don't depend on the order in which the arguments are evaluated,
// which is undefined.
//
// These macros evaluate their arguments exactly once.
#define EXPECT_STREQ(s1, s2) \
  EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperSTREQ, s1, s2)
#define EXPECT_STRNE(s1, s2) \
  EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperSTRNE, s1, s2)
#define EXPECT_STRCASEEQ(s1, s2) \
  EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperSTRCASEEQ, s1, s2)
#define EXPECT_STRCASENE(s1, s2)\
  EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperSTRCASENE, s1, s2)

#define ASSERT_STREQ(s1, s2) \
  ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperSTREQ, s1, s2)
#define ASSERT_STRNE(s1, s2) \
  ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperSTRNE, s1, s2)
#define ASSERT_STRCASEEQ(s1, s2) \
  ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperSTRCASEEQ, s1, s2)
#define ASSERT_STRCASENE(s1, s2)\
  ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperSTRCASENE, s1, s2)

// Macros for comparing floating-point numbers.
//
//    * {ASSERT|EXPECT}_FLOAT_EQ(val1, val2):
//         Tests that two float values are almost equal.
//    * {ASSERT|EXPECT}_DOUBLE_EQ(val1, val2):
//         Tests that two double values are almost equal.
//    * {ASSERT|EXPECT}_NEAR(v1, v2, abs_error):
//         Tests that v1 and v2 are within the given distance to each other.
//
// Google Test uses ULP-based comparison to automatically pick a default
// error bound that is appropriate for the operands.  See the
// FloatingPoint template class in gtest-internal.h if you are
// interested in the implementation details.
#define EXPECT_FLOAT_EQ(val1, val2)\ EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ val1, val2) #define EXPECT_DOUBLE_EQ(val1, val2)\ EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ val1, val2) #define ASSERT_FLOAT_EQ(val1, val2)\ ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ val1, val2) #define ASSERT_DOUBLE_EQ(val1, val2)\ ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ val1, val2) #define EXPECT_NEAR(val1, val2, abs_error)\ EXPECT_PRED_FORMAT3(::testing::internal::DoubleNearPredFormat, \ val1, val2, abs_error) #define ASSERT_NEAR(val1, val2, abs_error)\ ASSERT_PRED_FORMAT3(::testing::internal::DoubleNearPredFormat, \ val1, val2, abs_error) // These predicate format functions work on floating-point values, and // can be used in {ASSERT|EXPECT}_PRED_FORMAT2*(), e.g. // // EXPECT_PRED_FORMAT2(testing::DoubleLE, Foo(), 5.0); // Asserts that val1 is less than, or almost equal to, val2. Fails // otherwise. In particular, it fails if either val1 or val2 is NaN. GTEST_API_ AssertionResult FloatLE(const char* expr1, const char* expr2, float val1, float val2); GTEST_API_ AssertionResult DoubleLE(const char* expr1, const char* expr2, double val1, double val2); #if GTEST_OS_WINDOWS // Macros that test for HRESULT failure and success, these are only useful // on Windows, and rely on Windows SDK macros and APIs to compile. // // * {ASSERT|EXPECT}_HRESULT_{SUCCEEDED|FAILED}(expr) // // When expr unexpectedly fails or succeeds, Google Test prints the // expected result and the actual result with both a human-readable // string representation of the error, if available, as well as the // hex result code. 
# define EXPECT_HRESULT_SUCCEEDED(expr) \ EXPECT_PRED_FORMAT1(::testing::internal::IsHRESULTSuccess, (expr)) # define ASSERT_HRESULT_SUCCEEDED(expr) \ ASSERT_PRED_FORMAT1(::testing::internal::IsHRESULTSuccess, (expr)) # define EXPECT_HRESULT_FAILED(expr) \ EXPECT_PRED_FORMAT1(::testing::internal::IsHRESULTFailure, (expr)) # define ASSERT_HRESULT_FAILED(expr) \ ASSERT_PRED_FORMAT1(::testing::internal::IsHRESULTFailure, (expr)) #endif // GTEST_OS_WINDOWS // Macros that execute statement and check that it doesn't generate new fatal // failures in the current thread. // // * {ASSERT|EXPECT}_NO_FATAL_FAILURE(statement); // // Examples: // // EXPECT_NO_FATAL_FAILURE(Process()); // ASSERT_NO_FATAL_FAILURE(Process()) << "Process() failed"; // #define ASSERT_NO_FATAL_FAILURE(statement) \ GTEST_TEST_NO_FATAL_FAILURE_(statement, GTEST_FATAL_FAILURE_) #define EXPECT_NO_FATAL_FAILURE(statement) \ GTEST_TEST_NO_FATAL_FAILURE_(statement, GTEST_NONFATAL_FAILURE_) // Causes a trace (including the given source file path and line number, // and the given message) to be included in every test failure message generated // by code in the scope of the lifetime of an instance of this class. The effect // is undone with the destruction of the instance. // // The message argument can be anything streamable to std::ostream. // // Example: // testing::ScopedTrace trace("file.cc", 123, "message"); // class GTEST_API_ ScopedTrace { public: // The c'tor pushes the given source file location and message onto // a trace stack maintained by Google Test. // Template version. Uses Message() to convert the values into strings. // Slow, but flexible. template ScopedTrace(const char* file, int line, const T& message) { PushTrace(file, line, (Message() << message).GetString()); } // Optimize for some known types. ScopedTrace(const char* file, int line, const char* message) { PushTrace(file, line, message ? 
message : "(null)"); } ScopedTrace(const char* file, int line, const std::string& message) { PushTrace(file, line, message); } // The d'tor pops the info pushed by the c'tor. // // Note that the d'tor is not virtual in order to be efficient. // Don't inherit from ScopedTrace! ~ScopedTrace(); private: void PushTrace(const char* file, int line, std::string message); GTEST_DISALLOW_COPY_AND_ASSIGN_(ScopedTrace); } GTEST_ATTRIBUTE_UNUSED_; // A ScopedTrace object does its job in its // c'tor and d'tor. Therefore it doesn't // need to be used otherwise. // Causes a trace (including the source file path, the current line // number, and the given message) to be included in every test failure // message generated by code in the current scope. The effect is // undone when the control leaves the current scope. // // The message argument can be anything streamable to std::ostream. // // In the implementation, we include the current line number as part // of the dummy variable name, thus allowing multiple SCOPED_TRACE()s // to appear in the same block - as long as they are on different // lines. // // Assuming that each thread maintains its own stack of traces. // Therefore, a SCOPED_TRACE() would (correctly) only affect the // assertions in its own thread. #define SCOPED_TRACE(message) \ ::testing::ScopedTrace GTEST_CONCAT_TOKEN_(gtest_trace_, __LINE__)(\ __FILE__, __LINE__, (message)) // Compile-time assertion for type equality. // StaticAssertTypeEq() compiles if and only if type1 and type2 // are the same type. The value it returns is not interesting. // // Instead of making StaticAssertTypeEq a class template, we make it a // function template that invokes a helper class template. This // prevents a user from misusing StaticAssertTypeEq by // defining objects of that type. // // CAVEAT: // // When used inside a method of a class template, // StaticAssertTypeEq() is effective ONLY IF the method is // instantiated. 
// For example, given:
//
//   template <typename T> class Foo {
//    public:
//     void Bar() { testing::StaticAssertTypeEq<int, T>(); }
//   };
//
// the code:
//
//   void Test1() { Foo<bool> foo; }
//
// will NOT generate a compiler error, as Foo<bool>::Bar() is never
// actually instantiated.  Instead, you need:
//
//   void Test2() { Foo<bool> foo; foo.Bar(); }
//
// to cause a compiler error.
template <typename T1, typename T2>
constexpr bool StaticAssertTypeEq() noexcept {
  static_assert(std::is_same<T1, T2>::value,
                "type1 and type2 are not the same type");
  return true;
}

// Defines a test.
//
// The first parameter is the name of the test suite, and the second
// parameter is the name of the test within the test suite.
//
// The convention is to end the test suite name with "Test".  For
// example, a test suite for the Foo class can be named FooTest.
//
// Test code should appear between braces after an invocation of
// this macro.  Example:
//
//   TEST(FooTest, InitializesCorrectly) {
//     Foo foo;
//     EXPECT_TRUE(foo.StatusIsOK());
//   }

// Note that we call GetTestTypeId() instead of GetTypeId<
// ::testing::Test>() here to get the type ID of testing::Test.  This
// is to work around a suspected linker bug when using Google Test as
// a framework on Mac OS X.  The bug causes GetTypeId<
// ::testing::Test>() to return different values depending on whether
// the call is from the Google Test framework itself or from user test
// code.  GetTestTypeId() is guaranteed to always return the same
// value, as it always calls GetTypeId<>() from the Google Test
// framework.
#define GTEST_TEST(test_suite_name, test_name)             \
  GTEST_TEST_(test_suite_name, test_name, ::testing::Test, \
              ::testing::internal::GetTestTypeId())

// Define this macro to 1 to omit the definition of TEST(), which
// is a generic name and clashes with some other libraries.
#if !GTEST_DONT_DEFINE_TEST
#define TEST(test_suite_name, test_name) GTEST_TEST(test_suite_name, test_name)
#endif

// Defines a test that uses a test fixture.
// // The first parameter is the name of the test fixture class, which // also doubles as the test suite name. The second parameter is the // name of the test within the test suite. // // A test fixture class must be declared earlier. The user should put // the test code between braces after using this macro. Example: // // class FooTest : public testing::Test { // protected: // void SetUp() override { b_.AddElement(3); } // // Foo a_; // Foo b_; // }; // // TEST_F(FooTest, InitializesCorrectly) { // EXPECT_TRUE(a_.StatusIsOK()); // } // // TEST_F(FooTest, ReturnsElementCountCorrectly) { // EXPECT_EQ(a_.size(), 0); // EXPECT_EQ(b_.size(), 1); // } // // GOOGLETEST_CM0011 DO NOT DELETE #define TEST_F(test_fixture, test_name)\ GTEST_TEST_(test_fixture, test_name, test_fixture, \ ::testing::internal::GetTypeId()) // Returns a path to temporary directory. // Tries to determine an appropriate directory for the platform. GTEST_API_ std::string TempDir(); #ifdef _MSC_VER # pragma warning(pop) #endif // Dynamically registers a test with the framework. // // This is an advanced API only to be used when the `TEST` macros are // insufficient. The macros should be preferred when possible, as they avoid // most of the complexity of calling this function. // // The `factory` argument is a factory callable (move-constructible) object or // function pointer that creates a new instance of the Test object. It // handles ownership to the caller. The signature of the callable is // `Fixture*()`, where `Fixture` is the test fixture class for the test. All // tests registered with the same `test_suite_name` must return the same // fixture type. This is checked at runtime. // // The framework will infer the fixture class from the factory and will call // the `SetUpTestSuite` and `TearDownTestSuite` for it. // // Must be called before `RUN_ALL_TESTS()` is invoked, otherwise behavior is // undefined. 
// // Use case example: // // class MyFixture : public ::testing::Test { // public: // // All of these optional, just like in regular macro usage. // static void SetUpTestSuite() { ... } // static void TearDownTestSuite() { ... } // void SetUp() override { ... } // void TearDown() override { ... } // }; // // class MyTest : public MyFixture { // public: // explicit MyTest(int data) : data_(data) {} // void TestBody() override { ... } // // private: // int data_; // }; // // void RegisterMyTests(const std::vector& values) { // for (int v : values) { // ::testing::RegisterTest( // "MyFixture", ("Test" + std::to_string(v)).c_str(), nullptr, // std::to_string(v).c_str(), // __FILE__, __LINE__, // // Important to use the fixture type as the return type here. // [=]() -> MyFixture* { return new MyTest(v); }); // } // } // ... // int main(int argc, char** argv) { // std::vector values_to_test = LoadValuesFromConfig(); // RegisterMyTests(values_to_test); // ... // return RUN_ALL_TESTS(); // } // template TestInfo* RegisterTest(const char* test_suite_name, const char* test_name, const char* type_param, const char* value_param, const char* file, int line, Factory factory) { using TestT = typename std::remove_pointer::type; class FactoryImpl : public internal::TestFactoryBase { public: explicit FactoryImpl(Factory f) : factory_(std::move(f)) {} Test* CreateTest() override { return factory_(); } private: Factory factory_; }; return internal::MakeAndRegisterTestInfo( test_suite_name, test_name, type_param, value_param, internal::CodeLocation(file, line), internal::GetTypeId(), internal::SuiteApiResolver::GetSetUpCaseOrSuite(file, line), internal::SuiteApiResolver::GetTearDownCaseOrSuite(file, line), new FactoryImpl{std::move(factory)}); } } // namespace testing // Use this function in main() to run all tests. It returns 0 if all // tests are successful, or 1 otherwise. // // RUN_ALL_TESTS() should be invoked after the command line has been // parsed by InitGoogleTest(). 
// // This function was formerly a macro; thus, it is in the global // namespace and has an all-caps name. int RUN_ALL_TESTS() GTEST_MUST_USE_RESULT_; inline int RUN_ALL_TESTS() { return ::testing::UnitTest::GetInstance()->Run(); } GTEST_DISABLE_MSC_WARNINGS_POP_() // 4251 #endif // GTEST_INCLUDE_GTEST_GTEST_H_ openucx-ucc-ec0bc8a/test/gtest/asym_mem/0000775000175000017500000000000015133731560020616 5ustar alastairalastairopenucx-ucc-ec0bc8a/test/gtest/asym_mem/test_asymmetric_memory.cc0000664000175000017500000004645715133731560025751 0ustar alastairalastair/** * Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * See file LICENSE for terms. */ #include "common/test_ucc.h" #ifdef HAVE_CUDA using Param = std::tuple; class test_asymmetric_memory : public ucc::test, public ::testing::WithParamInterface { public: UccCollCtxVec ctxs; void data_init(ucc_coll_type_t coll_type, ucc_memory_type_t src_mem_type, ucc_memory_type_t dst_mem_type, UccTeam_h team, bool persistent = false) { ucc_rank_t tsize = team->procs.size(); int root = 0; size_t msglen = 2048; size_t src_modifier = 1; size_t dst_modifier = 1; ctxs.resize(tsize); if (coll_type == UCC_COLL_TYPE_GATHER) { dst_modifier = tsize; } else if (coll_type == UCC_COLL_TYPE_SCATTER) { src_modifier = tsize; } for (int i = 0; i < tsize; i++) { ctxs[i] = (gtest_ucc_coll_ctx_t*) calloc(1, sizeof(gtest_ucc_coll_ctx_t)); ucc_coll_args_t *coll = (ucc_coll_args_t*) calloc(1, sizeof(ucc_coll_args_t)); ctxs[i]->args = coll; coll->coll_type = coll_type; coll->src.info.mem_type = src_mem_type; coll->src.info.count = (ucc_count_t)msglen * src_modifier; coll->src.info.datatype = UCC_DT_INT8; coll->root = root; if (persistent) { coll->mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll->flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; } if (i == root || coll_type != UCC_COLL_TYPE_SCATTER) { UCC_CHECK(ucc_mc_alloc(&ctxs[i]->src_mc_header, msglen * src_modifier, src_mem_type)); coll->src.info.buffer = ctxs[i]->src_mc_header->addr; 
ctxs[i]->init_buf = ucc_malloc(msglen * src_modifier, "init buf"); EXPECT_NE(ctxs[i]->init_buf, nullptr); uint8_t *sbuf = (uint8_t*)ctxs[i]->init_buf; for (int j = 0; j < msglen * src_modifier; j++) { sbuf[j] = (uint8_t) 1; } UCC_CHECK(ucc_mc_memcpy(coll->src.info.buffer, ctxs[i]->init_buf, msglen * src_modifier, src_mem_type, UCC_MEMORY_TYPE_HOST)); } ctxs[i]->rbuf_size = msglen * dst_modifier; if (i == root || coll_type == UCC_COLL_TYPE_SCATTER) { UCC_CHECK(ucc_mc_alloc(&ctxs[i]->dst_mc_header, ctxs[i]->rbuf_size, dst_mem_type)); coll->dst.info.buffer = ctxs[i]->dst_mc_header->addr; coll->dst.info.count = (ucc_count_t)ctxs[i]->rbuf_size; coll->dst.info.datatype = UCC_DT_INT8; coll->dst.info.mem_type = dst_mem_type; } } } void data_fini() { for (int i = 0; i < ctxs.size(); i++) { gtest_ucc_coll_ctx_t *ctx = ctxs[i]; if (!ctx) { continue; } ucc_coll_args_t* coll = ctx->args; if (i == coll->root || coll->coll_type != UCC_COLL_TYPE_SCATTER) { ucc_free(ctx->init_buf); UCC_CHECK(ucc_mc_free(ctx->src_mc_header)); } if (i == coll->root || coll->coll_type == UCC_COLL_TYPE_SCATTER) { UCC_CHECK(ucc_mc_free(ctx->dst_mc_header)); } free(coll); free(ctx); } ctxs.clear(); } bool data_validate(uint8_t data = 1) { bool ret = true; int root = 0; uint8_t result = data; ucc_memory_type_t dst_mem_type; uint8_t *rst; if (ctxs[0]->args->coll_type == UCC_COLL_TYPE_REDUCE) { result *= (uint8_t) ctxs.size(); } for (int i = 0; i < ctxs.size(); i++) { if (!ctxs[i]) { continue; } root = ctxs[i]->args->root; if (i == root || ctxs[i]->args->coll_type == UCC_COLL_TYPE_SCATTER) { dst_mem_type = ctxs[i]->args->dst.info.mem_type; rst = (uint8_t*) ucc_malloc(ctxs[i]->rbuf_size, "validation buf"); EXPECT_NE(rst, nullptr); UCC_CHECK(ucc_mc_memcpy(rst, ctxs[i]->args->dst.info.buffer, ctxs[i]->rbuf_size, UCC_MEMORY_TYPE_HOST, dst_mem_type)); for (int j = 0; j < ctxs[i]->rbuf_size; j++) { if (result != rst[j]) { ret = false; break; } } ucc_free(rst); } } return ret; } void data_update(uint8_t data) { 
ucc_rank_t tsize = ctxs.size(); size_t msglen = 2048; size_t src_modifier = 1; ucc_coll_type_t coll_type = ctxs[0]->args->coll_type; int root = ctxs[0]->args->root; ucc_memory_type_t src_mem_type = ctxs[0]->args->src.info.mem_type; if (coll_type == UCC_COLL_TYPE_SCATTER) { src_modifier = tsize; } for (int i = 0; i < tsize; i++) { if (i == root || coll_type != UCC_COLL_TYPE_SCATTER) { ucc_coll_args_t *coll = ctxs[i]->args; uint8_t *sbuf = (uint8_t*)ctxs[i]->init_buf; for (int j = 0; j < msglen * src_modifier; j++) { sbuf[j] = (uint8_t) data; } UCC_CHECK(ucc_mc_memcpy(coll->src.info.buffer, ctxs[i]->init_buf, msglen * src_modifier, src_mem_type, UCC_MEMORY_TYPE_HOST)); } } } }; class test_asymmetric_memory_v : public ucc::test, public ::testing::WithParamInterface { public: UccCollCtxVec ctxs; void data_init(ucc_coll_type_t coll_type, ucc_memory_type_t src_mem_type, ucc_memory_type_t dst_mem_type, UccTeam_h team) { int nprocs = team->n_procs; size_t count = 2048; ucc_rank_t root = 0; ucc_coll_args_t *coll; int *counts, *displs; size_t my_count, all_counts; ctxs.resize(nprocs); for (auto r = 0; r < nprocs; r++) { coll = (ucc_coll_args_t *)calloc(1, sizeof(ucc_coll_args_t)); my_count = (nprocs - r) * count; ctxs[r] = (gtest_ucc_coll_ctx_t *)calloc(1, sizeof(gtest_ucc_coll_ctx_t)); ctxs[r]->args = coll; coll->mask = 0; coll->flags = 0; coll->coll_type = coll_type; coll->root = root; if (coll_type == UCC_COLL_TYPE_GATHERV) { coll->src.info.mem_type = src_mem_type; coll->src.info.count = (ucc_count_t)my_count; coll->src.info.datatype = UCC_DT_INT8; ctxs[r]->init_buf = ucc_malloc(ucc_dt_size(UCC_DT_INT8) * my_count, "init buf"); EXPECT_NE(ctxs[r]->init_buf, nullptr); for (int i = 0; i < my_count * ucc_dt_size(UCC_DT_INT8); i++) { uint8_t *sbuf = (uint8_t *)ctxs[r]->init_buf; sbuf[i] = ((i + r) % 256); } if (r == root) { all_counts = 0; counts = (int*)malloc(sizeof(int) * nprocs); EXPECT_NE(counts, nullptr); displs = (int*)malloc(sizeof(int) * nprocs); EXPECT_NE(displs, 
nullptr); for (int i = 0; i < nprocs; i++) { counts[i] = (nprocs - i) * count; displs[i] = all_counts; all_counts += counts[i]; } coll->dst.info_v.mem_type = dst_mem_type; coll->dst.info_v.counts = (ucc_count_t *)counts; coll->dst.info_v.displacements = (ucc_aint_t *)displs; coll->dst.info_v.datatype = UCC_DT_INT8; ctxs[r]->rbuf_size = ucc_dt_size(UCC_DT_INT8) * all_counts; UCC_CHECK(ucc_mc_alloc(&ctxs[r]->dst_mc_header, ctxs[r]->rbuf_size, dst_mem_type)); coll->dst.info_v.buffer = ctxs[r]->dst_mc_header->addr; } UCC_CHECK(ucc_mc_alloc(&ctxs[r]->src_mc_header, ucc_dt_size(UCC_DT_INT8) * my_count, src_mem_type)); coll->src.info.buffer = ctxs[r]->src_mc_header->addr; UCC_CHECK(ucc_mc_memcpy(coll->src.info.buffer, ctxs[r]->init_buf, ucc_dt_size(UCC_DT_INT8) * my_count, src_mem_type, UCC_MEMORY_TYPE_HOST)); } else { // scatterv coll->dst.info.mem_type = dst_mem_type; coll->dst.info.count = (ucc_count_t)my_count; coll->dst.info.datatype = UCC_DT_INT8; if (r == root) { all_counts = 0; counts = (int*)malloc(sizeof(int) * nprocs); EXPECT_NE(counts, nullptr); displs = (int*)malloc(sizeof(int) * nprocs); EXPECT_NE(displs, nullptr); for (int i = 0; i < nprocs; i++) { counts[i] = (nprocs - i) * count; displs[i] = all_counts; all_counts += counts[i]; } ctxs[r]->init_buf = ucc_malloc(ucc_dt_size(UCC_DT_INT8) * all_counts, "init buf"); EXPECT_NE(ctxs[r]->init_buf, nullptr); uint8_t *sbuf = (uint8_t*)ctxs[r]->init_buf; for (int p = 0; p < nprocs; p++) { for (int i = 0; i < ucc_dt_size(UCC_DT_INT8) * counts[p]; i++) { sbuf[(displs[p] * ucc_dt_size(UCC_DT_INT8) + i)] = (uint8_t)((i + p) % 256); } } coll->src.info_v.mem_type = src_mem_type; coll->src.info_v.counts = (ucc_count_t *)counts; coll->src.info_v.displacements = (ucc_aint_t *)displs; coll->src.info_v.datatype = UCC_DT_INT8; UCC_CHECK(ucc_mc_alloc(&ctxs[r]->src_mc_header, ucc_dt_size(UCC_DT_INT8) * all_counts, src_mem_type)); coll->src.info_v.buffer = ctxs[r]->src_mc_header->addr; 
UCC_CHECK(ucc_mc_memcpy(coll->src.info_v.buffer, ctxs[r]->init_buf, ucc_dt_size(UCC_DT_INT8) * all_counts, src_mem_type, UCC_MEMORY_TYPE_HOST)); } UCC_CHECK(ucc_mc_alloc(&ctxs[r]->dst_mc_header, ucc_dt_size(UCC_DT_INT8) * my_count, dst_mem_type)); coll->dst.info.buffer = ctxs[r]->dst_mc_header->addr; } } } bool data_validate() { bool ret = true; int root = ctxs[0]->args->root; int *displs = (int*)ctxs[root]->args->dst.info_v.displacements; size_t dt_size; ucc_memory_type_t dst_mem_type; ucc_count_t my_count; uint8_t *dsts; if (ctxs[root]->args->coll_type == UCC_COLL_TYPE_GATHERV) { dt_size = ucc_dt_size(ctxs[root]->args->src.info.datatype); dst_mem_type = ctxs[root]->args->dst.info_v.mem_type; if (UCC_MEMORY_TYPE_HOST != dst_mem_type) { dsts = (uint8_t *)ucc_malloc(ctxs[root]->rbuf_size, "dsts buf"); ucc_assert(dsts != nullptr); UCC_CHECK(ucc_mc_memcpy(dsts, ctxs[root]->args->dst.info_v.buffer, ctxs[root]->rbuf_size, UCC_MEMORY_TYPE_HOST, dst_mem_type)); } else { dsts = (uint8_t *)ctxs[root]->args->dst.info_v.buffer; } for (int r = 0; r < ctxs.size(); r++) { my_count = ctxs[r]->args->src.info.count; for (int i = 0; i < my_count * dt_size; i++) { if ((uint8_t)((i + r) % 256) != dsts[(displs[r] * dt_size + i)]) { ret = false; break; } } } if (UCC_MEMORY_TYPE_HOST != dst_mem_type) { ucc_free(dsts); } } else { // scatterv dst_mem_type = ctxs[root]->args->dst.info.mem_type; for (auto r = 0; r < ctxs.size(); r++) { dt_size = ucc_dt_size((ctxs[r])->args->dst.info.datatype); my_count = (ctxs[r])->args->dst.info.count; if (UCC_MEMORY_TYPE_HOST != dst_mem_type) { dsts = (uint8_t *)ucc_malloc(my_count * dt_size, "dsts buf"); ucc_assert(dsts != nullptr); UCC_CHECK(ucc_mc_memcpy(dsts, ctxs[r]->args->dst.info.buffer, my_count * dt_size, UCC_MEMORY_TYPE_HOST, dst_mem_type)); } else { dsts = (uint8_t *)ctxs[r]->args->dst.info.buffer; } for (int i = 0; i < my_count * dt_size; i++) { if ((uint8_t)((i + r) % 256) != dsts[i]) { ret = false; break; } } if (UCC_MEMORY_TYPE_HOST != 
dst_mem_type) { ucc_free(dsts); if (!ret) { break; } } } } return ret; } void data_fini() { int root = ctxs[0]->args->root; for (auto r = 0; r < ctxs.size(); r++) { ucc_coll_args_t *coll = ctxs[r]->args; if (coll->coll_type == UCC_COLL_TYPE_GATHERV) { if (r == root) { UCC_CHECK(ucc_mc_free(ctxs[r]->dst_mc_header)); free(coll->dst.info_v.counts); free(coll->dst.info_v.displacements); } UCC_CHECK(ucc_mc_free(ctxs[r]->src_mc_header)); } else { // scatterv if (r == root) { UCC_CHECK(ucc_mc_free(ctxs[r]->src_mc_header)); free(coll->src.info_v.counts); free(coll->src.info_v.displacements); } UCC_CHECK(ucc_mc_free(ctxs[r]->dst_mc_header)); } ucc_free(ctxs[r]->init_buf); free(coll); free(ctxs[r]); } ctxs.clear(); } }; #define TEST_ASYM_DECLARE \ const ucc_coll_type_t coll_type = std::get<0>(GetParam()); \ const ucc_memory_type_t src_mem_type = std::get<1>(GetParam()); \ const ucc_memory_type_t dst_mem_type = std::get<2>(GetParam()); \ const int n_procs = std::get<3>(GetParam()); \ \ UccJob job(n_procs, UccJob::UCC_JOB_CTX_GLOBAL); \ UccTeam_h team = job.create_team(n_procs); \ \ data_init(coll_type, src_mem_type, dst_mem_type, team); \ UccReq req(team, ctxs); \ if (req.status != UCC_OK) { \ data_fini(); \ GTEST_SKIP() << "ucc_collective_init returned " \ << ucc_status_string(req.status); \ } \ req.start(); \ req.wait(); \ EXPECT_EQ(true, data_validate()); \ data_fini(); UCC_TEST_P(test_asymmetric_memory, single) { TEST_ASYM_DECLARE } UCC_TEST_P(test_asymmetric_memory, persistent) { const ucc_coll_type_t coll_type = std::get<0>(GetParam()); const ucc_memory_type_t src_mem_type = std::get<1>(GetParam()); const ucc_memory_type_t dst_mem_type = std::get<2>(GetParam()); const int n_procs = std::get<3>(GetParam()); int times = 3; UccJob job(n_procs, UccJob::UCC_JOB_CTX_GLOBAL); UccTeam_h team = job.create_team(n_procs); data_init(coll_type, src_mem_type, dst_mem_type, team, /*persistent*/true); UccReq req(team, ctxs); if (req.status != UCC_OK) { data_fini(); GTEST_SKIP() << 
"ucc_collective_init returned " << ucc_status_string(req.status); } for (; times > 0; times--) { data_update(times); // Set each element in src to times req.start(); req.wait(); EXPECT_EQ(true, data_validate(times)); // Check that the dst was correct based on times } data_fini(); } INSTANTIATE_TEST_CASE_P ( , test_asymmetric_memory, ::testing::Combine ( ::testing::Values(UCC_COLL_TYPE_REDUCE, UCC_COLL_TYPE_GATHER, UCC_COLL_TYPE_SCATTER), // coll type (scatter may be skipped because tl/ucp does not support scatter) ::testing::Values(UCC_MEMORY_TYPE_HOST, UCC_MEMORY_TYPE_CUDA), // src mem type ::testing::Values(UCC_MEMORY_TYPE_HOST, UCC_MEMORY_TYPE_CUDA), // dst mem type ::testing::Values(8) // n_procs ) ); UCC_TEST_P(test_asymmetric_memory_v, single_v) { TEST_ASYM_DECLARE } INSTANTIATE_TEST_CASE_P ( , test_asymmetric_memory_v, ::testing::Combine ( ::testing::Values(UCC_COLL_TYPE_GATHERV, UCC_COLL_TYPE_SCATTERV), // coll type ::testing::Values(UCC_MEMORY_TYPE_HOST, UCC_MEMORY_TYPE_CUDA), // src mem type ::testing::Values(UCC_MEMORY_TYPE_HOST, UCC_MEMORY_TYPE_CUDA), // dst mem type ::testing::Values(8) // n_procs ) ); #endif openucx-ucc-ec0bc8a/test/gtest/core/0000775000175000017500000000000015133731560017737 5ustar alastairalastairopenucx-ucc-ec0bc8a/test/gtest/core/test_mem_map.cc0000664000175000017500000003022515133731560022722 0ustar alastairalastair/** * Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * See file LICENSE for terms. 
*/ #include "test_mem_map.h" #include #include test_mem_map::test_mem_map() : ctx_h(nullptr), ctx_config(nullptr) { memset(&ctx_params, 0, sizeof(ctx_params)); } test_mem_map::~test_mem_map() { } void test_mem_map::SetUp() { test_context_config::SetUp(); EXPECT_EQ(UCC_OK, ucc_context_config_read(lib_h, NULL, &ctx_config)); ctx_params.mask = UCC_CONTEXT_PARAM_FIELD_TYPE; ctx_params.type = UCC_CONTEXT_EXCLUSIVE; EXPECT_EQ(UCC_OK, ucc_context_create(lib_h, &ctx_params, ctx_config, &ctx_h)); } void test_mem_map::TearDown() { if (ctx_h) { EXPECT_EQ(UCC_OK, ucc_context_destroy(ctx_h)); } if (ctx_config) { ucc_context_config_release(ctx_config); } test_context_config::TearDown(); } test_mem_map_export::test_mem_map_export() : test_buffer(nullptr), buffer_size(0) { memset(&map_params, 0, sizeof(map_params)); memset(&segment, 0, sizeof(segment)); } test_mem_map_export::~test_mem_map_export() { } void test_mem_map_export::SetUp() { test_mem_map::SetUp(); /* Allocate test buffer */ buffer_size = 1024 * 1024; /* 1MB */ test_buffer = malloc(buffer_size); ASSERT_NE(nullptr, test_buffer); /* Initialize buffer with test data */ memset(test_buffer, 0xAA, buffer_size); /* Set up memory map parameters */ segment.address = test_buffer; segment.len = buffer_size; map_params.segments = &segment; map_params.n_segments = 1; } void test_mem_map_export::TearDown() { if (test_buffer) { free(test_buffer); test_buffer = nullptr; } test_mem_map::TearDown(); } test_mem_map_import::test_mem_map_import() : test_buffer(nullptr), buffer_size(0), memh(nullptr), memh_size(0) { } test_mem_map_import::~test_mem_map_import() { } void test_mem_map_import::SetUp() { test_mem_map::SetUp(); /* Allocate test buffer */ buffer_size = 1024 * 1024; /* 1MB */ test_buffer = malloc(buffer_size); ASSERT_NE(nullptr, test_buffer); /* Initialize buffer with test data */ memset(test_buffer, 0xBB, buffer_size); } void test_mem_map_import::TearDown() { if (test_buffer) { free(test_buffer); test_buffer = nullptr; } 
test_mem_map::TearDown(); } /* Test basic memory map export functionality */ UCC_TEST_F(test_mem_map_export, basic_export) { ucc_mem_map_mem_h memh = nullptr; size_t memh_size = 0; EXPECT_EQ(UCC_OK, ucc_mem_map(ctx_h, UCC_MEM_MAP_MODE_EXPORT, &map_params, &memh_size, &memh)); EXPECT_NE(nullptr, memh); EXPECT_GT(memh_size, 0); /* Test unmap */ EXPECT_EQ(UCC_OK, ucc_mem_unmap(&memh)); /* Note: ucc_mem_unmap doesn't set memh to nullptr, it only frees the memory */ } /* Test memory map export with different buffer sizes */ UCC_TEST_F(test_mem_map_export, different_sizes) { std::vector sizes = {1024, 4096, 65536, 1024 * 1024}; for (auto size : sizes) { /* Reallocate buffer with new size */ if (test_buffer) { free(test_buffer); } test_buffer = malloc(size); ASSERT_NE(nullptr, test_buffer); memset(test_buffer, 0xCC, size); segment.address = test_buffer; segment.len = size; ucc_mem_map_mem_h memh = nullptr; size_t memh_size = 0; EXPECT_EQ(UCC_OK, ucc_mem_map(ctx_h, UCC_MEM_MAP_MODE_EXPORT, &map_params, &memh_size, &memh)); EXPECT_NE(nullptr, memh); EXPECT_GT(memh_size, 0); EXPECT_EQ(UCC_OK, ucc_mem_unmap(&memh)); } } /* Test memory map export with multiple segments (should fail as UCC only supports one segment) */ UCC_TEST_F(test_mem_map_export, multiple_segments) { ucc_mem_map_mem_h memh = nullptr; size_t memh_size = 0; ucc_mem_map_t segments[2]; ucc_mem_map_params_t multi_params; /* Create two segments */ segments[0].address = test_buffer; segments[0].len = buffer_size / 2; segments[1].address = (char *)test_buffer + buffer_size / 2; segments[1].len = buffer_size / 2; multi_params.segments = segments; multi_params.n_segments = 2; /* This should fail as UCC only supports one segment per call */ EXPECT_EQ(UCC_ERR_INVALID_PARAM, ucc_mem_map(ctx_h, UCC_MEM_MAP_MODE_EXPORT, &multi_params, &memh_size, &memh)); EXPECT_EQ(nullptr, memh); } /* Test memory map export with invalid parameters */ UCC_TEST_F(test_mem_map_export, invalid_params) { ucc_mem_map_mem_h memh = nullptr; 
size_t memh_size = 0; /* Test with NULL params */ EXPECT_EQ(UCC_ERR_INVALID_PARAM, ucc_mem_map(ctx_h, UCC_MEM_MAP_MODE_EXPORT, nullptr, &memh_size, &memh)); /* Test with invalid mode */ ucc_mem_map_mode_t invalid_mode = UCC_MEM_MAP_MODE_LAST; EXPECT_EQ(UCC_ERR_INVALID_PARAM, ucc_mem_map(ctx_h, invalid_mode, &map_params, &memh_size, &memh)); } /* Test memory map export with zero length buffer */ UCC_TEST_F(test_mem_map_export, zero_length) { ucc_mem_map_mem_h memh = nullptr; size_t memh_size = 0; segment.len = 0; /* This might succeed or fail depending on implementation */ ucc_status_t status = ucc_mem_map(ctx_h, UCC_MEM_MAP_MODE_EXPORT, &map_params, &memh_size, &memh); if (status == UCC_OK) { EXPECT_NE(nullptr, memh); EXPECT_EQ(UCC_OK, ucc_mem_unmap(&memh)); } } /* Test memory map import functionality */ UCC_TEST_F(test_mem_map_import, basic_import) { ucc_mem_map_mem_h export_memh = nullptr; size_t export_memh_size = 0; ucc_mem_map_mem_h import_memh = nullptr; size_t import_memh_size = 0; ucc_mem_map_t export_segment; ucc_mem_map_params_t export_params; ucc_status_t export_status; ucc_status_t import_status; export_segment.address = test_buffer; export_segment.len = buffer_size; export_params.segments = &export_segment; export_params.n_segments = 1; /* Export the memory handle */ export_status = ucc_mem_map(ctx_h, UCC_MEM_MAP_MODE_EXPORT, &export_params, &export_memh_size, &export_memh); if (export_status != UCC_OK) { /* If export fails, skip the test */ GTEST_SKIP() << "Export failed, skipping import test"; return; } EXPECT_NE(nullptr, export_memh); EXPECT_GT(export_memh_size, 0); /* For import, we need to create a new memory handle with the exported data */ /* The import function expects the memh to be pre-allocated with the exported data */ import_memh = (ucc_mem_map_mem_h)malloc(export_memh_size); ASSERT_NE(nullptr, import_memh); memcpy(import_memh, export_memh, export_memh_size); import_status = ucc_mem_map(ctx_h, UCC_MEM_MAP_MODE_IMPORT, &export_params, 
&import_memh_size, &import_memh); if (import_status == UCC_OK) { EXPECT_NE(nullptr, import_memh); /* Cleanup import */ ucc_mem_unmap(&import_memh); } else { /* Import might not be supported, which is acceptable */ EXPECT_TRUE(import_status == UCC_ERR_NOT_SUPPORTED || import_status == UCC_ERR_NOT_IMPLEMENTED); /* Clean up the allocated memory if import failed */ free(import_memh); } /* Cleanup export */ ucc_mem_unmap(&export_memh); } /* Test memory map import with different buffer sizes */ UCC_TEST_F(test_mem_map_import, import_different_sizes) { std::vector sizes = {1024, 4096, 65536, 1024 * 1024}; ucc_mem_map_mem_h export_memh = nullptr; size_t export_memh_size = 0; ucc_mem_map_t export_segment; ucc_mem_map_params_t export_params; ucc_mem_map_mem_h import_memh; size_t import_memh_size; ucc_status_t import_status; for (auto size : sizes) { /* Reallocate buffer with new size */ if (test_buffer) { free(test_buffer); } test_buffer = malloc(size); ASSERT_NE(nullptr, test_buffer); memset(test_buffer, 0xDD, size); export_segment.address = test_buffer; export_segment.len = size; export_params.segments = &export_segment; export_params.n_segments = 1; /* Export the memory handle */ ucc_status_t export_status = ucc_mem_map(ctx_h, UCC_MEM_MAP_MODE_EXPORT, &export_params, &export_memh_size, &export_memh); if (export_status != UCC_OK) { continue; /* Skip this size if export fails */ } EXPECT_NE(nullptr, export_memh); EXPECT_GT(export_memh_size, 0); /* Test import */ import_memh = (ucc_mem_map_mem_h)malloc(export_memh_size); ASSERT_NE(nullptr, import_memh); memcpy(import_memh, export_memh, export_memh_size); import_status = ucc_mem_map(ctx_h, UCC_MEM_MAP_MODE_IMPORT, &export_params, &import_memh_size, &import_memh); if (import_status == UCC_OK) { EXPECT_NE(nullptr, import_memh); ucc_mem_unmap(&import_memh); } else { /* Import might not be supported for all sizes */ EXPECT_TRUE(import_status == UCC_ERR_NOT_SUPPORTED || import_status == UCC_ERR_NOT_IMPLEMENTED); /* Clean up the 
allocated memory if import failed */ free(import_memh); } /* Cleanup export */ ucc_mem_unmap(&export_memh); } } /* Test memory map import with invalid parameters */ UCC_TEST_F(test_mem_map_import, import_invalid_params) { /* Test import with NULL params */ ucc_mem_map_mem_h memh = nullptr; size_t memh_size = 0; ucc_mem_map_params_t params; ucc_mem_map_t segment; EXPECT_EQ(UCC_ERR_INVALID_PARAM, ucc_mem_map(ctx_h, UCC_MEM_MAP_MODE_IMPORT, nullptr, &memh_size, &memh)); /* Test import with NULL memh */ segment.address = test_buffer; segment.len = buffer_size; params.segments = &segment; params.n_segments = 1; EXPECT_EQ(UCC_ERR_INVALID_PARAM, ucc_mem_map(ctx_h, UCC_MEM_MAP_MODE_IMPORT, ¶ms, &memh_size, nullptr)); } /* Test memory map unmap with NULL handle */ UCC_TEST_F(test_mem_map_export, unmap_null) { ucc_mem_map_mem_h memh = nullptr; /* Should handle NULL gracefully */ EXPECT_EQ(UCC_ERR_INVALID_PARAM, ucc_mem_unmap(&memh)); } /* Test memory map with different modes */ UCC_TEST_F(test_mem_map_export, different_modes) { ucc_mem_map_mem_h memh1 = nullptr; size_t memh_size = 0; /* Test EXPORT mode */ EXPECT_EQ(UCC_OK, ucc_mem_map(ctx_h, UCC_MEM_MAP_MODE_EXPORT, &map_params, &memh_size, &memh1)); EXPECT_NE(nullptr, memh1); EXPECT_EQ(UCC_OK, ucc_mem_unmap(&memh1)); } /* Test memory map stress test */ UCC_TEST_F(test_mem_map_export, stress_test) { const int num_iterations = 100; ucc_mem_map_mem_h memh; size_t memh_size; for (int i = 0; i < num_iterations; i++) { memh = nullptr; memh_size = 0; /* Fill buffer with iteration-specific pattern */ memset(test_buffer, i % 256, buffer_size); EXPECT_EQ(UCC_OK, ucc_mem_map(ctx_h, UCC_MEM_MAP_MODE_EXPORT, &map_params, &memh_size, &memh)); EXPECT_NE(nullptr, memh); EXPECT_EQ(UCC_OK, ucc_mem_unmap(&memh)); } } openucx-ucc-ec0bc8a/test/gtest/core/test_service_coll.cc0000664000175000017500000001461415133731560023764 0ustar alastairalastair/** * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* See file LICENSE for terms. */ #include #include #include extern "C" { #include "core/ucc_team.h" #include "core/ucc_service_coll.h" } class test_service_coll { public: int * array; UccTeam_h team; std::vector subsets; std::vector reqs; test_service_coll(std::vector _subset, UccTeam_h _team) { team = _team; array = new int[_subset.size()]; memcpy(array, _subset.data(), sizeof(int) * _subset.size()); subsets.resize(_subset.size()); reqs.resize(_subset.size()); for (auto i = 0; i < _subset.size(); i++) { subsets[i].myrank = i; subsets[i].map.type = UCC_EP_MAP_ARRAY; subsets[i].map.array.map = (void *)array; subsets[i].map.array.elem_size = sizeof(int); subsets[i].map.ep_num = _subset.size(); } } ~test_service_coll() { delete[] array; } void progress() { for (auto i = 0; i < reqs.size(); i++) { ucc_context_progress(team.get()->procs[array[i]].p->ctx_h); } } void wait() { ucc_status_t status; bool ready; do { ready = true; progress(); for (auto &r : reqs) { status = ucc_service_coll_test(r); EXPECT_GE(status, 0); if (UCC_INPROGRESS == status) { ready = false; } } } while (!ready); for (auto &r : reqs) { ucc_service_coll_finalize(r); } } }; class test_service_allreduce : public test_service_coll { std::vector> sbuf; std::vector> rbuf; public: test_service_allreduce(std::vector _subset, size_t count, UccTeam_h _team) : test_service_coll(_subset, _team) { sbuf.resize(_subset.size()); rbuf.resize(_subset.size()); for (auto i = 0; i < _subset.size(); i++) { sbuf[i].resize(count); rbuf[i].resize(count); for (auto j = 0; j < count; j++) { sbuf[i][j] = i + j + 1; rbuf[i][j] = 0; } } } void start() { ucc_status_t status; for (auto i = 0; i < reqs.size(); i++) { auto r = array[i]; status = ucc_service_allreduce( team.get()->procs[r].team, sbuf[i].data(), rbuf[i].data(), UCC_DT_INT32, sbuf[i].size(), UCC_OP_SUM, subsets[i], &reqs[i]); EXPECT_EQ(UCC_OK, status); } } void check() { int size = reqs.size(); for (auto i = 0; i < size; i++) { for (auto j = 0; j < sbuf[i].size(); j++) 
{ int check = size * (size + 1) / 2 + j * size; EXPECT_EQ(check, rbuf[i][j]); } } } };

/*
 * Value-parameterized fixture for the service allreduce test below; it adds
 * nothing on top of its base classes.
 * NOTE(review): the template argument list of WithParamInterface looks
 * truncated in this copy, and the same holds for the bare std::vector uses
 * further down - confirm against the upstream file.
 */
class test_scoll_allreduce : public ucc::test,
                             public ::testing::WithParamInterface> {
};

/*
 * Builds a test_service_allreduce for the parameterized subset on the last
 * static team, then starts it, waits for it and checks the result.
 */
UCC_TEST_P(test_scoll_allreduce, allreduce)
{
    /* Reversed team of size staticUccJobSize - last one in static teams
       array */
    auto team = UccJob::getStaticTeams().back();
    /* The largest value used in the subsets instantiated below is 15 */
    ASSERT_EQ(team.get()->procs.size(), 16);
    std::vector subset = GetParam();
    test_service_allreduce t(subset, 4, team);
    t.start();
    t.wait();
    t.check();
}

/* Subsets exercised: pairs, the full 0..15 range, a prefix, a descending
   tail and an even-only selection. */
INSTANTIATE_TEST_CASE_P(
    , test_scoll_allreduce,
    ::testing::Values(std::vector({1, 0}),
                      std::vector({2, 3}),
                      std::vector({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}),
                      std::vector({0, 1, 2, 3, 4, 5, 6, 7, 8}),
                      std::vector({15, 14, 13, 12, 11, 10, 9}),
                      std::vector({0, 2, 4, 6, 8})));

/*
 * Service allgather helper: keeps one send buffer and one receive buffer per
 * subset member and posts ucc_service_allgather for each of them.
 * reqs, array, subsets and team are not declared here - presumably they are
 * members of test_service_coll; verify against that class.
 */
class test_service_allgather : public test_service_coll {
    std::vector> sbuf;
    std::vector> rbuf;
public:
    /*
     * Send buffer i gets count values following the pattern i + j + 1;
     * receive buffer i gets count * _subset.size() zeros.
     */
    test_service_allgather(std::vector _subset, size_t count, UccTeam_h _team) :
        test_service_coll(_subset, _team)
    {
        sbuf.resize(_subset.size());
        rbuf.resize(_subset.size());
        for (auto i = 0; i < _subset.size(); i++) {
            sbuf[i].resize(count);
            rbuf[i].resize(count * _subset.size());
            for (auto j = 0; j < count; j++) {
                sbuf[i][j] = i + j + 1;
            }
            for (auto j = 0; j < count * _subset.size(); j++) {
                rbuf[i][j] = 0;
            }
        }
    }

    /* Posts one allgather per request; the message size is given in bytes
       (element count times sizeof(int)). */
    void start()
    {
        ucc_status_t status;
        for (auto i = 0; i < reqs.size(); i++) {
            auto r = array[i];
            status = ucc_service_allgather(
                team.get()->procs[r].team, sbuf[i].data(), rbuf[i].data(),
                sbuf[i].size() * sizeof(int), subsets[i], &reqs[i]);
            EXPECT_EQ(UCC_OK, status);
        }
    }

    /*
     * Expects element j of every receive buffer to equal
     * (j / count) + (j % count) + 1, i.e. the send pattern of the member whose
     * index is j / count - this assumes the gathered blocks are laid out in
     * subset order (TODO confirm against ucc_service_allgather).
     */
    void check()
    {
        int size  = reqs.size();
        int count = sbuf[0].size();
        for (auto i = 0; i < size; i++) {
            for (auto j = 0; j < rbuf[i].size(); j++) {
                int check = (j % count) + 1 + (j / count);
                EXPECT_EQ(check, rbuf[i][j]);
            }
        }
    }
};

/* Value-parameterized fixture for the service allgather test below. */
class test_scoll_allgather : public ucc::test,
                             public ::testing::WithParamInterface> {
};

/*
 * Builds a test_service_allgather for the parameterized subset on the last
 * static team, then starts it, waits for it and checks the result.
 */
UCC_TEST_P(test_scoll_allgather, allgather)
{
    /* Reversed team of size staticUccJobSize - last one in static teams
       array */
    auto team = UccJob::getStaticTeams().back();
    ASSERT_EQ(team.get()->procs.size(), 16);
    std::vector subset = GetParam();
    test_service_allgather t(subset, 4, team);
    t.start();
    t.wait();
    t.check();
}

INSTANTIATE_TEST_CASE_P(
    , test_scoll_allgather,
    ::testing::Values(std::vector({1, 0}),
                      std::vector({2, 3}),
                      std::vector({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}),
                      std::vector({0, 1, 2, 3, 4, 5, 6, 7, 8}),
                      std::vector({15, 14, 13, 12, 11, 10, 9}),
                      std::vector({0, 2, 4, 6, 8})));

/* NOTE(review): the next run of text looks like an archive member header
   (path, mode, owner, size) rather than C++ source - confirm how this file
   was produced. */
openucx-ucc-ec0bc8a/test/gtest/core/test_lib.cc0000664000175000017500000000254015133731560022054 0ustar alastairalastair/**
 * Copyright (c) 2020, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * See file LICENSE for terms.
 */
/* NOTE(review): the include targets below are missing in this copy. */
#include
#include
#include

/* Plain fixture for the library init/finalize tests. */
class test_lib : public ucc::test {
};

/* Reads the lib config, initializes one library with the thread mode set to
   UCC_THREAD_SINGLE, releases the config and finalizes the library. */
UCC_TEST_F(test_lib, init_finalize)
{
    ucc_lib_config_h cfg;
    ucc_lib_params_t lib_params;
    ucc_lib_h        lib;
    EXPECT_EQ(UCC_OK, ucc_lib_config_read(NULL, NULL, &cfg));
    lib_params.mask        = UCC_LIB_PARAM_FIELD_THREAD_MODE;
    lib_params.thread_mode = UCC_THREAD_SINGLE;
    EXPECT_EQ(UCC_OK, ucc_init(&lib_params, cfg, &lib));
    ucc_lib_config_release(cfg);
    EXPECT_EQ(UCC_OK, ucc_finalize(lib));
}

/* Initializes n_libs libraries from the same config and finalizes them in a
   shuffled order (default-constructed random engine). */
UCC_TEST_F(test_lib, init_multiple)
{
    const int n_libs = 8;
    ucc_lib_config_h cfg;
    ucc_lib_params_t lib_params;
    ucc_lib_h        lib_h;
    std::vector libs;
    EXPECT_EQ(UCC_OK, ucc_lib_config_read(NULL, NULL, &cfg));
    lib_params.mask = UCC_LIB_PARAM_FIELD_THREAD_MODE;
    lib_params.thread_mode = UCC_THREAD_SINGLE;
    for (int i = 0; i < n_libs; i++) {
        EXPECT_EQ(UCC_OK, ucc_init(&lib_params, cfg, &lib_h));
        libs.push_back(lib_h);
    }
    ucc_lib_config_release(cfg);
    std::shuffle(libs.begin(), libs.end(), std::default_random_engine());
    for (auto lib_h : libs) {
        EXPECT_EQ(UCC_OK, ucc_finalize(lib_h));
    }
}

/* NOTE(review): archive member header for the next file, see the note
   earlier in this section. */
openucx-ucc-ec0bc8a/test/gtest/core/test_mem_map.h0000664000175000017500000000231615133731560022564 0ustar alastairalastair/**
 * Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* See file LICENSE for terms. */
#ifndef TEST_MEM_MAP_H
#define TEST_MEM_MAP_H

#include "../common/test_ucc.h"
#include "test_context.h"
/* NOTE(review): the targets of the next two includes are missing in this
   copy of the header - confirm against the upstream file. */
#include
#include

/*
 * Base fixture for the memory-map tests. It carries a context handle together
 * with context parameters and a context config; how they are used is defined
 * by SetUp/TearDown, whose bodies are not part of this header.
 */
class test_mem_map : public test_context_config {
protected:
    ucc_context_h        ctx_h;
    ucc_context_params_t ctx_params;
    ucc_context_config_h ctx_config;
public:
    test_mem_map();
    ~test_mem_map();
    void SetUp() override;
    void TearDown() override;
};

/*
 * Fixture for the export tests: adds a raw buffer with its size plus the
 * ucc_mem_map_params_t / ucc_mem_map_t objects (presumably describing that
 * buffer - verify in the implementation file).
 */
class test_mem_map_export : public test_mem_map {
protected:
    void * test_buffer;
    size_t buffer_size;
    ucc_mem_map_params_t map_params;
    ucc_mem_map_t        segment;
public:
    test_mem_map_export();
    ~test_mem_map_export();
    void SetUp() override;
    void TearDown() override;
};

/*
 * Fixture for the import tests: adds a raw buffer with its size plus a memory
 * handle and the size of that handle.
 */
class test_mem_map_import : public test_mem_map {
protected:
    void * test_buffer;
    size_t buffer_size;
    ucc_mem_map_mem_h memh;
    size_t            memh_size;
public:
    test_mem_map_import();
    ~test_mem_map_import();
    void SetUp() override;
    void TearDown() override;
};

#endif /* TEST_MEM_MAP_H */
/* NOTE(review): the next run of text looks like an archive member header
   (path, mode, owner, size) rather than C++ source. */
openucx-ucc-ec0bc8a/test/gtest/core/test_schedule.cc0000664000175000017500000000702715133731560023107 0ustar alastairalastair/**
 * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * See file LICENSE for terms.
*/
/* NOTE(review): the target of this include is missing in this copy. */
#include
extern "C" {
#include "schedule/ucc_schedule.h"
}

/*
 * ucc_coll_task_t that manages itself: the constructor runs
 * ucc_coll_task_construct and ucc_coll_task_init (expected to return UCC_OK),
 * the destructor runs ucc_coll_task_destruct.
 */
class test_coll_task : public ucc_coll_task_t {
public:
    test_coll_task()
    {
        ucc_coll_task_construct(this);
        EXPECT_EQ(UCC_OK, ucc_coll_task_init((ucc_coll_task_t *)this, NULL, NULL));
    }
    ~test_coll_task()
    {
        ucc_coll_task_destruct(this);
    }
};

/*
 * Record pushed by the handlers: element 0 is the task passed as parent,
 * element 1 is the handler number (1 or 2).
 * NOTE(review): the template arguments of std::tuple (and of the std::vector
 * uses below) are missing in this copy - confirm against the upstream file.
 */
typedef std::tuple rst_t;

/*
 * Fixture that is itself a task. Both handlers cast their second argument
 * back to test_schedule and append (parent, handler number) to rst, so rst
 * records the order in which the handlers were called.
 */
class test_schedule : public test_coll_task, public ucc::test {
public:
    std::vector rst;
    static ucc_status_t handler_1(ucc_coll_task_t *parent, ucc_coll_task_t *task)
    {
        test_schedule *ts = (test_schedule*)task;
        ts->rst.push_back(rst_t((test_coll_task*)parent, 1));
        return UCC_OK;
    }
    static ucc_status_t handler_2(ucc_coll_task_t *parent, ucc_coll_task_t *task)
    {
        test_schedule *ts = (test_schedule*)task;
        ts->rst.push_back(rst_t((test_coll_task*)parent, 2));
        return UCC_OK;
    }
};

/* The fixture task subscribes to EVENT_COMPLETED of 2 tasks with the same
   handler; both notifications must be recorded in order. */
UCC_TEST_F(test_schedule, single_handler)
{
    std::vector tasks(2);
    for (auto &t : tasks) {
        ucc_event_manager_subscribe(&t, UCC_EVENT_COMPLETED,
                                    (ucc_coll_task_t*)this,
                                    test_schedule::handler_1);
    }
    for (auto &t : tasks) {
        EXPECT_EQ(UCC_OK, ucc_event_manager_notify(&t, UCC_EVENT_COMPLETED));
    }
    EXPECT_EQ(2, rst.size());
    EXPECT_EQ(true, (std::get<0>(rst[0]) == &tasks[0]) &&
              (std::get<1>(rst[0]) == 1));
    EXPECT_EQ(true, (std::get<0>(rst[1]) == &tasks[1]) &&
              (std::get<1>(rst[1]) == 1));
}

/* The fixture task subscribes to EVENT_COMPLETED of 2 tasks with 2 different
   handlers; each notification must reach its own handler. */
UCC_TEST_F(test_schedule, different_handlers)
{
    std::vector tasks(2);
    ucc_event_manager_subscribe(&tasks[0], UCC_EVENT_COMPLETED,
                                (ucc_coll_task_t*)this,
                                test_schedule::handler_1);
    ucc_event_manager_subscribe(&tasks[1], UCC_EVENT_COMPLETED,
                                (ucc_coll_task_t*)this,
                                test_schedule::handler_2);
    for (auto &t : tasks) {
        EXPECT_EQ(UCC_OK, ucc_event_manager_notify(&t, UCC_EVENT_COMPLETED));
    }
    EXPECT_EQ(2, rst.size());
    EXPECT_EQ(true, (std::get<0>(rst[0]) == &tasks[0]) &&
              (std::get<1>(rst[0]) == 1));
    EXPECT_EQ(true, (std::get<0>(rst[1]) == &tasks[1]) &&
(std::get<1>(rst[1]) == 2));
}

/* Tasks subscribes to multiple tasks exceeding MAX_LISTENERS */
/* Even-indexed tasks are wired to handler_1, odd-indexed ones to handler_2;
   the recorded order and handler numbers must match the notification order. */
UCC_TEST_F(test_schedule, multiple)
{
    const int n_subscribers = 16;
    std::vector tasks(n_subscribers);
    for (int i = 0; i < n_subscribers; i++) {
        ucc_event_manager_subscribe(&tasks[i], UCC_EVENT_COMPLETED,
                                    (ucc_coll_task_t*)this,
                                    ((i % 2) == 0 ? test_schedule::handler_1
                                                  : test_schedule::handler_2));
    }
    for (auto &t : tasks) {
        EXPECT_EQ(UCC_OK, ucc_event_manager_notify(&t, UCC_EVENT_COMPLETED));
    }
    EXPECT_EQ(n_subscribers, rst.size());
    for (int i = 0; i < n_subscribers; i++) {
        EXPECT_EQ(true, (std::get<0>(rst[i]) == &tasks[i]) &&
                  (std::get<1>(rst[i]) == ((i % 2) + 1)));
    }
}

/* NOTE(review): the next run of text looks like an archive member header
   (path, mode, owner, size) rather than C++ source. */
openucx-ucc-ec0bc8a/test/gtest/core/test_mc_reduce.cc0000664000175000017500000003765015133731560023246 0ustar alastairalastair/**
 * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * See file LICENSE for terms.
 */
#include "test_mc_reduce.h"
extern "C" {
#include "components/ec/ucc_ec.h"
}
#ifdef HAVE_CUDA
/* NOTE(review): the target of this include is missing in this copy. */
#include
#endif

/*
 * Fixture for the executor-based reduction tests.
 * NOTE(review): the template parameter list is missing in this copy; the body
 * uses a type parameter T (T::type, T::dt, T::redop, T::do_op,
 * T::assert_equal) and a value named triggered - confirm the exact
 * declaration against the upstream file.
 */
template class test_mc_reduce : public testing::Test {
  protected:
    const int COUNT = 1024;
    ucc_memory_type_t  mem_type;
    ucc_ee_executor_t *executor;
    void              *ee_context = NULL;

    /* Runs ucc_constructor, initializes mc and ec with thread_mode set to
       UCC_THREAD_SINGLE and resets all buffer pointers and the executor. */
    virtual void SetUp() override
    {
        ucc_constructor();
        ucc_mc_params_t mc_params = {
            .thread_mode = UCC_THREAD_SINGLE,
        };
        ucc_ec_params_t ec_params = {
            .thread_mode = UCC_THREAD_SINGLE,
        };
        ucc_mc_init(&mc_params);
        ucc_ec_init(&ec_params);
        buf1_h = buf2_h = res_h = nullptr;
        buf1_d = buf2_d = res_d = nullptr;
        executor = nullptr;
    }

    /*
     * Picks the executor type that matches mtype, then initializes and starts
     * the executor. For UCC_MEMORY_TYPE_CUDA with triggered set (CUDA builds
     * only) a non-blocking stream is created first and stored in ee_context.
     * Returns UCC_OK on success, an error status otherwise.
     */
    ucc_status_t alloc_executor(ucc_memory_type_t mtype)
    {
        ucc_ee_executor_params_t params;
        ucc_ee_type_t            coll_ee_type;
        ucc_status_t             status;

        switch (mtype) {
        case UCC_MEMORY_TYPE_CUDA:
            coll_ee_type = UCC_EE_CUDA_STREAM;
#ifdef HAVE_CUDA
            if (triggered) {
                cudaStream_t stream;
                if (cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking) != cudaSuccess) {
                    std::cerr << "failed to create cuda stream" << std::endl;
                    return UCC_ERR_NO_RESOURCE;
                }
                ee_context = (void *)stream;
            }
#endif break; case UCC_MEMORY_TYPE_CUDA_MANAGED: coll_ee_type = UCC_EE_CUDA_STREAM; break; case UCC_MEMORY_TYPE_HOST: coll_ee_type = UCC_EE_CPU_THREAD; break; default: std::cerr << "invalid executor mem type\n"; return UCC_ERR_INVALID_PARAM; break; } params.mask = UCC_EE_EXECUTOR_PARAM_FIELD_TYPE; params.ee_type = coll_ee_type; status = ucc_ee_executor_init(¶ms, &executor); if (UCC_OK != status) { std::cerr << "failed to init executor: " << ucc_status_string(status) << std::endl; return status; } status = ucc_ee_executor_start(executor, ee_context); if (UCC_OK != status) { std::cerr << "failed to start executor: " << ucc_status_string(status) << std::endl; ucc_ee_executor_finalize(executor); } return status; } ucc_status_t free_executor() { ucc_status_t status; status = ucc_ee_executor_stop(executor); if (UCC_OK != status) { std::cerr << "failed to stop executor: " << ucc_status_string(status) << std::endl; } ucc_ee_executor_finalize(executor); #ifdef HAVE_CUDA if (triggered) { if (cudaStreamDestroy((cudaStream_t)ee_context) != cudaSuccess) { std::cerr << "failed to destory cuda stream" << std::endl; return UCC_ERR_NO_MESSAGE; } ee_context = NULL; } #endif return status; } ucc_status_t setup(ucc_memory_type_t mtype, size_t n) { ucc_status_t status; status = alloc_bufs(mtype, n); if (UCC_OK != status) { return status; } return alloc_executor(mtype); } ucc_status_t alloc_bufs(ucc_memory_type_t mtype, size_t n) { size_t n_bytes = COUNT*sizeof(typename T::type); mem_type = mtype; ucc_mc_alloc(&res_h_mc_header, n_bytes, UCC_MEMORY_TYPE_HOST); res_h = (typename T::type *)res_h_mc_header->addr; ucc_mc_alloc(&buf1_h_mc_header, n_bytes, UCC_MEMORY_TYPE_HOST); buf1_h = (typename T::type *)buf1_h_mc_header->addr; ucc_mc_alloc(&buf2_h_mc_header, n * n_bytes, UCC_MEMORY_TYPE_HOST); buf2_h = (typename T::type *)buf2_h_mc_header->addr; for (int i = 0; i < COUNT; i++) { res_h[i] = (typename T::type)(0); } for (int i = 0; i < COUNT; i++) { /* bFloat16 will be assigned with the 
floats matching the uint16_t bit pattern*/
            buf1_h[i] = (typename T::type)(i + 1);
        }
        /* Vector j of the second source holds 2 * i + j + 1 at position i */
        for (int j = 0; j < n; j++) {
            for (int i = 0; i < COUNT; i++) {
                buf2_h[i + j * COUNT] = (typename T::type)(2 * i + j + 1);
            }
        }
        if (mtype != UCC_MEMORY_TYPE_HOST) {
            /* Mirror the host buffers in the requested memory type and make
               the working pointers refer to those copies */
            ucc_mc_alloc(&res_d_mc_header, n_bytes, mtype);
            res_d = (typename T::type *)res_d_mc_header->addr;
            ucc_mc_alloc(&buf1_d_mc_header, n_bytes, mtype);
            buf1_d = (typename T::type *)buf1_d_mc_header->addr;
            ucc_mc_alloc(&buf2_d_mc_header, n * n_bytes, mtype);
            buf2_d = (typename T::type *)buf2_d_mc_header->addr;
            ucc_mc_memcpy(res_d, res_h, n_bytes, mtype, UCC_MEMORY_TYPE_HOST);
            ucc_mc_memcpy(buf1_d, buf1_h, n_bytes, mtype, UCC_MEMORY_TYPE_HOST);
            ucc_mc_memcpy(buf2_d, buf2_h, n * n_bytes, mtype, UCC_MEMORY_TYPE_HOST);
            buf1 = buf1_d;
            buf2 = buf2_d;
            res  = res_d;
        } else {
            buf1 = buf1_h;
            buf2 = buf2_h;
            res  = res_h;
        }
        return UCC_OK;
    }

    /* Frees every buffer whose pointer is non-null and always returns UCC_OK;
       the mtype argument is not used. */
    ucc_status_t free_bufs(ucc_memory_type_t mtype)
    {
        if (buf1_h != nullptr) {
            ucc_mc_free(buf1_h_mc_header);
        }
        if (buf2_h != nullptr) {
            ucc_mc_free(buf2_h_mc_header);
        }
        if (res_h != nullptr) {
            ucc_mc_free(res_h_mc_header);
        }
        if (buf1_d != nullptr) {
            ucc_mc_free(buf1_d_mc_header);
        }
        if (buf2_d != nullptr) {
            ucc_mc_free(buf2_d_mc_header);
        }
        if (res_d != nullptr) {
            ucc_mc_free(res_d_mc_header);
        }
        return UCC_OK;
    }

    /* Releases the buffers and finalizes mc.
       NOTE(review): SetUp also calls ucc_ec_init, but no ec finalize call is
       made here - confirm whether that is intended. */
    virtual void TearDown() override
    {
        free_bufs(mem_type);
        ucc_mc_finalize();
    }

    /*
     * Posts one UCC_EE_EXECUTOR_TASK_REDUCE_STRIDED task on the executor and
     * waits for it to finish.
     *   src1, src2  - first source and the n_src2 strided second sources
     *   dst         - destination buffer
     *   count       - number of elements
     *   stride      - value stored in reduce_strided.stride (the callers pass
     *                 a byte count)
     *   dt, op      - datatype and reduction operation
     *   with_alpha  - when true the REDUCE_WITH_ALPHA flag is set
     *   alpha       - value stored in reduce_strided.alpha
     * Returns the post status on failure, otherwise the final test status.
     */
    ucc_status_t do_reduce(void *src1, void *src2, void *dst, size_t count,
                           uint16_t n_src2, size_t stride, ucc_datatype_t dt,
                           ucc_reduction_op_t op, bool with_alpha, double alpha)
    {
        ucc_ee_executor_task_args_t eargs;
        ucc_status_t                status;
        ucc_ee_executor_task_t *    task;

        eargs.flags = with_alpha ?
UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA : 0;
        eargs.task_type             = UCC_EE_EXECUTOR_TASK_REDUCE_STRIDED;
        eargs.reduce_strided.count  = count;
        eargs.reduce_strided.dt     = dt;
        eargs.reduce_strided.op     = op;
        eargs.reduce_strided.n_src2 = n_src2;
        eargs.reduce_strided.dst    = dst;
        eargs.reduce_strided.src1   = src1;
        eargs.reduce_strided.src2   = src2;
        eargs.reduce_strided.stride = stride;
        eargs.reduce_strided.alpha  = alpha;
        status = ucc_ee_executor_task_post(executor, &eargs, &task);
        if (UCC_OK != status) {
            std::cerr << "failed to post executor task: "
                      << ucc_status_string(status) << std::endl;
            return status;
        }
        /* Poll while the test call keeps returning a positive status */
        while (0 < (status = ucc_ee_executor_task_test(task))) {
            ;
        }
        ucc_ee_executor_task_finalize(task);
        return status;
    }

    /*
     * Reduces buf1 with a single buf2 vector and compares every element with
     * T::do_op applied to the host copies. Skipped when the memory type is not
     * available or the reduction reports UCC_ERR_NOT_SUPPORTED.
     * NOTE(review): on the NOT_SUPPORTED skip path (and when ASSERT_EQ on the
     * status fails) the executor is not freed - confirm whether that is
     * acceptable.
     */
    void test_reduce(ucc_memory_type_t mt)
    {
        ucc_status_t status;

        if (UCC_OK != ucc_mc_available(mt)) {
            GTEST_SKIP();
        }
        ASSERT_EQ(this->setup(mt, 1), UCC_OK);
        status = do_reduce(this->buf1, this->buf2, this->res, this->COUNT,
                           1, 0, T::dt, T::redop, false, 0);
        if (UCC_ERR_NOT_SUPPORTED == status) {
            GTEST_SKIP();
        }
        ASSERT_EQ(status, UCC_OK);
        if (executor) {
            free_executor();
        }
        if (mt != UCC_MEMORY_TYPE_HOST) {
            /* Bring the result back to the host buffer for checking */
            ucc_mc_memcpy(this->res_h, this->res_d,
                          this->COUNT * sizeof(*this->res_d),
                          UCC_MEMORY_TYPE_HOST, mt);
        }
        for (int i = 0; i < this->COUNT; i++) {
            T::assert_equal(T::do_op(this->buf1_h[i], this->buf2_h[i]),
                            this->res_h[i]);
        }
    };

    /*
     * Same as test_reduce but with num_vec second-source vectors laid out
     * COUNT elements apart; the stride passed to do_reduce is that distance
     * in bytes.
     */
    void test_reduce_multi(ucc_memory_type_t mt)
    {
        const int    num_vec = 3;
        ucc_status_t status;

        if (UCC_OK != ucc_mc_available(mt)) {
            GTEST_SKIP();
        }
        ASSERT_EQ(this->setup(mt, num_vec), UCC_OK);
        status = do_reduce(this->buf1, this->buf2, this->res, this->COUNT,
                           num_vec, this->COUNT * sizeof(*this->buf2),
                           T::dt, T::redop, false, 0);
        if (UCC_ERR_NOT_SUPPORTED == status) {
            GTEST_SKIP();
        }
        ASSERT_EQ(status, UCC_OK);
        if (executor) {
            free_executor();
        }
        if (mt != UCC_MEMORY_TYPE_HOST) {
            ucc_mc_memcpy(this->res_h, this->res_d,
                          this->COUNT * sizeof(*this->res_d),
                          UCC_MEMORY_TYPE_HOST, mt);
        }
        for (int i = 0; i < this->COUNT; i++) {
            typename T::type res =
T::do_op(this->buf1_h[i], this->buf2_h[i]);
            /* Fold in the remaining second-source vectors one by one */
            for (int j = 1; j < num_vec; j++) {
                res = T::do_op(this->buf2_h[i + j * this->COUNT], res);
            }
            T::assert_equal(res, this->res_h[i]);
        }
    };

    /*
     * Multi-vector reduction with the REDUCE_WITH_ALPHA flag: the expected
     * value is the folded result multiplied by alpha. For UCC_DT_BFLOAT16 the
     * multiplication goes through the bfloat16/float32 conversion helpers.
     */
    void test_reduce_multi_alpha(ucc_memory_type_t mt)
    {
        const int    num_vec = 20;
        const double alpha   = 0.7;
        ucc_status_t status;

        if (UCC_OK != ucc_mc_available(mt)) {
            GTEST_SKIP();
        }
        ASSERT_EQ(UCC_OK, this->setup(mt, num_vec));
        status = do_reduce(this->buf1, this->buf2, this->res, this->COUNT,
                           num_vec, this->COUNT * sizeof(*this->buf2),
                           T::dt, T::redop, true, alpha);
        if (UCC_ERR_NOT_SUPPORTED == status) {
            GTEST_SKIP();
        }
        ASSERT_EQ(status, UCC_OK);
        if (executor) {
            free_executor();
        }
        if (mt != UCC_MEMORY_TYPE_HOST) {
            ucc_mc_memcpy(this->res_h, this->res_d,
                          this->COUNT * sizeof(*this->res_d),
                          UCC_MEMORY_TYPE_HOST, mt);
        }
        for (int i = 0; i < this->COUNT; i++) {
            typename T::type res = T::do_op(this->buf1_h[i], this->buf2_h[i]);
            for (int j = 1; j < num_vec; j++) {
                res = T::do_op(this->buf2_h[i + j * this->COUNT], res);
            }
            if (T::dt == UCC_DT_BFLOAT16) {
                float32tobfloat16(bfloat16tofloat32(&res)*(float)alpha, &res);
            } else {
                res *= (typename T::type)alpha;
            }
            T::assert_equal(res, this->res_h[i]);
        }
    }

    /* Allocation headers returned by ucc_mc_alloc; _h are the host buffers,
       _d the buffers of the tested non-host memory type */
    ucc_mc_buffer_header_t *buf1_h_mc_header, *buf2_h_mc_header, *res_h_mc_header,
                           *buf1_d_mc_header, *buf2_d_mc_header, *res_d_mc_header;
    typename T::type *buf1_h;
    typename T::type *buf2_h;
    typename T::type *res_h;
    typename T::type *buf1_d;
    typename T::type *buf2_d;
    typename T::type *res_d;
    /* Working pointers: refer to either the _h or the _d set */
    typename T::type *buf1;
    typename T::type *buf2;
    typename T::type *res;
};

/*
 * NOTE(review): every TypeOpPair in the macro and the type lists below has
 * lost its template arguments in this copy, as has ::testing::Types for the
 * Int/Uint lists - confirm against the upstream file before relying on them.
 */
#define INT_OP_PAIRS(_TYPE) ARITHMETIC_OP_PAIRS(_TYPE), \
    TypeOpPair, \
    TypeOpPair, \
    TypeOpPair, \
    TypeOpPair, \
    TypeOpPair, \
    TypeOpPair

using TypeOpPairsInt = ::testing::Types;
using TypeOpPairsUint = ::testing::Types;
using TypeOpPairsFloat = ::testing::Types, TypeOpPair, TypeOpPair, TypeOpPair,
                                          TypeOpPair, TypeOpPair, TypeOpPair,
                                          TypeOpPair, TypeOpPair>;
using TypeOpPairsFloatCuda = ::testing::Types<
ARITHMETIC_OP_PAIRS(FLOAT64), ARITHMETIC_OP_PAIRS(BFLOAT16), TypeOpPair, TypeOpPair, TypeOpPair, TypeOpPair, TypeOpPair, TypeOpPair, TypeOpPair>; template class test_mc_reduce_int : public test_mc_reduce {}; TYPED_TEST_CASE(test_mc_reduce_int, TypeOpPairsInt); template class test_mc_reduce_uint : public test_mc_reduce {}; TYPED_TEST_CASE(test_mc_reduce_uint, TypeOpPairsUint); template class test_mc_reduce_float : public test_mc_reduce {}; TYPED_TEST_CASE(test_mc_reduce_float, TypeOpPairsFloat); #define DECLARE_REDUCE_TEST(_type, _mt) \ TYPED_TEST(test_mc_reduce_ ## _type, _mt) { \ this->test_reduce(UCC_MEMORY_TYPE_ ## _mt); \ } \ #define DECLARE_REDUCE_MULTI_TEST(_type, _mt) \ TYPED_TEST(test_mc_reduce_ ## _type, multi_ ## _mt) { \ this->test_reduce_multi(UCC_MEMORY_TYPE_ ## _mt); \ } \ #define DECLARE_REDUCE_MULTI_ALPHA_TEST(_type, _mt) \ TYPED_TEST(test_mc_reduce_ ## _type, multi_alpha_ ## _mt) { \ this->test_reduce_multi_alpha(UCC_MEMORY_TYPE_ ## _mt); \ } \ DECLARE_REDUCE_TEST(int, HOST); DECLARE_REDUCE_TEST(uint, HOST); DECLARE_REDUCE_TEST(float, HOST); DECLARE_REDUCE_MULTI_TEST(int, HOST); DECLARE_REDUCE_MULTI_TEST(uint, HOST); DECLARE_REDUCE_MULTI_TEST(float, HOST); DECLARE_REDUCE_MULTI_ALPHA_TEST(float, HOST); #ifdef HAVE_CUDA DECLARE_REDUCE_TEST(int, CUDA); DECLARE_REDUCE_TEST(uint, CUDA); DECLARE_REDUCE_TEST(float, CUDA); DECLARE_REDUCE_MULTI_TEST(int, CUDA); DECLARE_REDUCE_MULTI_TEST(uint, CUDA); DECLARE_REDUCE_MULTI_TEST(float, CUDA); DECLARE_REDUCE_MULTI_ALPHA_TEST(float, CUDA); template class test_mc_reduce_int_triggered : public test_mc_reduce {}; TYPED_TEST_CASE(test_mc_reduce_int_triggered, TypeOpPairsInt); template class test_mc_reduce_uint_triggered : public test_mc_reduce {}; TYPED_TEST_CASE(test_mc_reduce_uint_triggered, TypeOpPairsUint); template class test_mc_reduce_float_triggered : public test_mc_reduce {}; TYPED_TEST_CASE(test_mc_reduce_float_triggered, TypeOpPairsFloatCuda); DECLARE_REDUCE_TEST(int_triggered, CUDA); 
DECLARE_REDUCE_TEST(uint_triggered, CUDA);
DECLARE_REDUCE_TEST(float_triggered, CUDA);
DECLARE_REDUCE_MULTI_TEST(int_triggered, CUDA);
DECLARE_REDUCE_MULTI_TEST(uint_triggered, CUDA);
DECLARE_REDUCE_MULTI_TEST(float_triggered, CUDA);
DECLARE_REDUCE_MULTI_ALPHA_TEST(float_triggered, CUDA);
#endif
/* NOTE(review): the next run of text looks like an archive member header
   (path, mode, owner, size) rather than C++ source. */
openucx-ucc-ec0bc8a/test/gtest/core/test_mc_reduce.h0000664000175000017500000001674415133731560023109 0ustar alastairalastair/**
 * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * See file LICENSE for terms.
 */
/* NOTE(review): the include targets below are missing in this copy. */
extern "C" {
#include
#include
}
#include

/* Forward declaration of the (datatype, operation) pairing used by the typed
   reduction tests.
   NOTE(review): the template parameter list is truncated in this copy. */
template class op> struct TypeOpPair;

#define DECLARE_TYPE_OP_PAIR(_type, _TYPE, _EQ) \
    template