diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt index db72b09..1417e48 100644 --- a/benchmark/CMakeLists.txt +++ b/benchmark/CMakeLists.txt @@ -6,6 +6,8 @@ PRIVATE main.cpp yycc/string/op.cpp + + yycc/carton/fft.cpp ) # target_sources(YYCCBenchmark # PRIVATE diff --git a/benchmark/yycc/carton/fft.cpp b/benchmark/yycc/carton/fft.cpp new file mode 100644 index 0000000..0e7bcc2 --- /dev/null +++ b/benchmark/yycc/carton/fft.cpp @@ -0,0 +1,40 @@ +#include +#include +#include +#include +#include + +#define FFT ::yycc::carton::fft + +namespace yyccbench::carton::fft { + + using TIndex = size_t; + using TFloat = float; + using TComplex = std::complex; + template + using TFft = FFT::Fft; + + constexpr TIndex FFT_POINTS = 1024u; + + static void BM_FftCompute(benchmark::State& state) { + // prepare random buffer + constexpr TIndex RND_BUF_CNT = 8u; + std::random_device rnd_device; + std::default_random_engine rnd_engine(rnd_device()); + std::uniform_real_distribution rnd_dist(0.0f, 1.0f); + std::vector> buffer_collection(RND_BUF_CNT); + for (auto& buf : buffer_collection) { + buf.resize(FFT_POINTS); + std::generate(buf.begin(), buf.end(), [&rnd_engine, &rnd_dist]() mutable -> TComplex { return TComplex(rnd_dist(rnd_engine)); }); + } + + // prepare FFT engine + TFft fft; + // do benchmark + for (auto _ : state) { + fft.compute(buffer_collection[state.iterations() % RND_BUF_CNT].data()); + } + } + BENCHMARK(BM_FftCompute)->Name("FftCompute"); + +} diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 41a729c..f18a32f 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -95,6 +95,7 @@ FILES yycc/carton/clap/summary.hpp yycc/carton/clap/application.hpp yycc/carton/clap/manual.hpp + yycc/carton/fft.hpp ) # Setup header infomations target_include_directories(YYCCommonplace diff --git a/src/yycc/carton/fft.hpp b/src/yycc/carton/fft.hpp new file mode 100644 index 0000000..52f37b6 --- /dev/null +++ b/src/yycc/carton/fft.hpp @@ -0,0 +1,307 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace yycc::carton::fft { + + /// @private + /// @brief Meta-programming utilities for FFT modules. + namespace util { + + template + inline constexpr TFloat tau_v = static_cast(2) * std::numbers::pi_v; + + // NOTE: + // We use std::has_single_bit() to check whether given number is an integral power of 2. + // And use (std::bit_width() - 1) to get the exponent of given number based on 2. + + template + struct validate_args { + private: + static constexpr bool is_unsigned_int = std::is_unsigned_v && std::is_integral_v; + static constexpr bool is_float_point = std::is_floating_point_v; + static constexpr bool n_is_pow_2 = std::has_single_bit(static_cast(N)) && N >= static_cast(2); + + public: + static constexpr bool value = is_unsigned_int && is_float_point && n_is_pow_2; + }; + + template + inline constexpr bool validate_args_v = validate_args::value; + + } // namespace util + +#pragma region Window + + enum class WindowType { HanningWindow }; + + template + requires util::validate_args_v + class Window { + private: + static constexpr TIndex N = N; + + public: + Window(WindowType win_type) : window_type(win_type), window_data(nullptr) { + // Pre-compute window data + // Allocate window buffer + window_data = std::make_unique(N); + // Assign window data + switch (win_type) { + case WindowType::HanningWindow: + for (TIndex i = 0u; i < N; ++i) { + window_data[i] = static_cast(0.5) + * (static_cast(1) + - std::cos(util::tau_v + * static_cast(i) / static_cast(N - static_cast(1)))); + } + break; + default: + throw std::invalid_argument("invalid window function type"); + } + } + + private: + WindowType window_type; + std::unique_ptr window_data; + + public: + /** + * @brief Apply window function to given data sequence. + * @param[in,out] data + * The float-point data sequence for applying window function. + * The length of this sequence must be N. + */ + void apply_window(TFloat* data) const { + if (data == nullptr) [[unlikely]] { + throw std::invalid_argument("nullptr data is not allowed for applying window."); + } + for (TIndex i = static_cast(0); i < N; ++i) { + data[i] *= window_data[i]; + } + } + /** + * @brief Get underlying window function data for custom applying. + * @return + * The pointer to the start address of underlying window function data sequence. + * The length of this sequence is N. + */ + const TFloat* get_window_data() const { return window_data.get(); } + }; + +#pragma endregion + +#pragma region FFT + + template + requires util::validate_args_v + struct FftProperties { + public: + using TComplex = std::complex; + static constexpr TIndex N = static_cast(N); + static constexpr TIndex M = static_cast(std::bit_width(N) - 1); + static constexpr TIndex HALF_POINT = N >> static_cast(1); + }; + + /** + * @brief The core FFT class. + * @details The core class implementing FFT algorithm (base-2 version). + * @tparam TIndex + * @tparam TFloat + * @tparam N + */ + template + requires util::validate_args_v + class Fft { + private: + using TProperties = FftProperties; + using TComplex = TProperties::TComplex; + static constexpr TIndex N = TProperties::N; + static constexpr TIndex M = TProperties::M; + static constexpr TIndex HALF_POINT = TProperties::HALF_POINT; + + public: + Fft() : wnp_cache(nullptr) { + // Generate WNP cache + wnp_cache = std::make_unique(N); + for (TIndex P = static_cast(0); P < N; ++P) { + TFloat angle = util::tau_v * static_cast(P) / static_cast(N); + // e^(-jx) = cosx - j sinx + wnp_cache[P] = TComplex(std::cos(angle), -std::sin(angle)); + } + } + + private: + std::unique_ptr wnp_cache; + + public: + /** + * @brief Compute FFT for given complex sequence. + * @details + * This is FFT core compute function but not suit for common user + * because it order that you have enough FFT knowledge to understand what is input data and what is output data. + * For convenient use, see also easy_compute(). + * @param[in,out] data + * The complex sequence for computing. + * The length of this sequence must be N. + */ + void compute(TComplex* data) const { + if (data == nullptr) [[unlikely]] { + throw std::invalid_argument("nullptr data is not allowed for FFT computing."); + } + + TIndex LH, J, K, B, P; + LH = J = HALF_POINT; + + // Construct butterfly structure + for (TIndex I = static_cast(1); I <= N - static_cast(2); ++I) { + if (I < J) std::swap(data[I], data[J]); + + K = LH; + while (J >= K) { + J -= K; + K >>= static_cast(1); + } + J += K; + } + + // Calculate butterfly + TComplex temp, temp2; + for (TIndex L = static_cast(1); L <= M; ++L) { + B = static_cast(1u) << (L - static_cast(1)); + for (J = static_cast(0); J <= B - static_cast(1); ++J) { + P = J * (static_cast(1) << (M - L)); + + // Use pre-computed cache instead of real-time computing + for (TIndex KK = J; KK <= N - static_cast(1); KK += (static_cast(1) << L)) { + temp2 = (data[KK + B] * this->wnp_cache[P]); + temp = temp2 + data[KK]; + data[KK + B] = data[KK] - temp2; + data[KK] = temp; + } + } + } + } + }; + + /** + * @brief User friendly FFT computation class. + * @details + * @tparam TIndex + * @tparam TFloat + * @tparam N + * @warning This class is \b NOT thread safe. Please use different instance in different thread. + */ + template + requires util::validate_args_v + class FriendlyFft { + private: + using UnderlyingFft = Fft; + using TProperties = FftProperties; + using TComplex = TProperties::TComplex; + static constexpr TIndex N = TProperties::N; + static constexpr TIndex M = TProperties::M; + static constexpr TIndex HALF_POINT = TProperties::HALF_POINT; + + public: + FriendlyFft() : compute_cache(N) { + // Initialize computation used buffer. + compute_cache = std::vector(); + } + + private: + UnderlyingFft underlying_fft; + std::vector compute_cache; + + public: + /** + * @brief Get the maximum frequency by given sample rate. + * @param[in] sample_rate + * The sample rate of input stream. + * Unit is Hz or SPS (sample point per second). + * @return + * The last data in computed FFT drequency data represented frequency. + * Unit is Hz. + */ + TFloat get_max_freq(TFloat sample_rate) { + // Following sample priniciple + return sample_rate / static_cast(2); + } + + /** + * @brief Compute FFT for given time scope data. + * @details + * This is convenient FFT compute function, comparing with compute(). + * This function accepts time scope data and output frequency scope data automatically. + * Additionally, it order a window function instance to apply to time scope data before computing. + * @param[in] time_scope The length of this data must be N. + * For the time order of data, the first data should be the oldest data and the last data should be the newest data. + * @param[out] freq_scope The length of this data must be N / 2. + * The first data is 0Hz and the frequency of last data is decided by sample rate which can be computed by get_max_freq() function in this class. + * @param[in] window The window instance applied to data. + * @warnings + * This function is \b NOT thread-safe. + * Please do NOT call this function in different thread for one instance. + */ + void easy_compute(const TFloat* time_scope, TFloat* freq_scope, const Window& window) { + if (time_scope == nullptr || freq_scope == nullptr) [[unlikely]] { + throw std::invalid_argument("nullptr data is not allowed for easy FFT computing."); + } + + // First, we copy time scope data into cache with reversed order. + // because FFT order the first item should be the latest data. + // At the same time we multiple it with window function. + std::generate(compute_cache.begin(), + compute_cache.end(), + [data = &(time_scope[N]), win_data = window.get_window_data()]() mutable -> TComplex { + return TComplex(*(data--) * *(win_data++)); + }); + + // Do FFT compute + underlying_fft.compute(compute_cache.data()); + + // Compute amplitude + for (TIndex i = static_cast(0); i < HALF_POINT; ++i) { + freq_scope[i] = static_cast(10) * std::log10(std::abs(compute_cache[i + HALF_POINT])); + } + } + }; + +#pragma endregion + +#pragma region Pre-defined FFT Types + + using Fft4F = Fft; + using Fft8F = Fft; + using Fft16F = Fft; + using Fft32F = Fft; + using Fft64F = Fft; + using Fft128F = Fft; + using Fft256F = Fft; + using Fft512F = Fft; + using Fft1024F = Fft; + using Fft2048F = Fft; + + using Fft4 = Fft; + using Fft8 = Fft; + using Fft16 = Fft; + using Fft32 = Fft; + using Fft64 = Fft; + using Fft128 = Fft; + using Fft256 = Fft; + using Fft512 = Fft; + using Fft1024 = Fft; + using Fft2048 = Fft; + +#pragma endregion + +} // namespace yycc::carton::fft diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 9470dcb..f0e848c 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -41,6 +41,7 @@ PRIVATE yycc/carton/wcwidth.cpp yycc/carton/tabulate.cpp yycc/carton/clap.cpp + yycc/carton/fft.cpp ) target_sources(YYCCTest PRIVATE diff --git a/test/yycc/carton/fft.cpp b/test/yycc/carton/fft.cpp new file mode 100644 index 0000000..2a7c6ea --- /dev/null +++ b/test/yycc/carton/fft.cpp @@ -0,0 +1,116 @@ +#include +#include +#include +#include + +#define FFT ::yycc::carton::fft + +namespace yycctest::carton::fft { + + using TIndex = size_t; + using TFloat = float; + using TComplex = std::complex; + template + using TFft = FFT::Fft; + + // YYC MARK: + // It seems that default epsilon can not fulfill our test (too small). + constexpr TFloat TOLERANCE = static_cast(0.0003); + //constexpr TFloat tolerance = std::numeric_limits::epsilon(); + + template + static void test_fft(const std::vector& real_src, const std::vector& dst) { + // check given data size + ASSERT_EQ(real_src.size(), N); + ASSERT_EQ(dst.size(), N); + + // convert real-number source into complex-number source + std::vector src(real_src.size()); + std::generate(src.begin(), src.end(), [data = real_src.begin()]() mutable -> TComplex { return TComplex(*data++); }); + + // create FFT instance and compute data + TFft fft; + fft.compute(src.data()); + + // check result with tolerance + for (TIndex i = 0u; i < src.size(); ++i) { + EXPECT_NEAR(src[i].real(), dst[i].real(), TOLERANCE); + EXPECT_NEAR(src[i].imag(), dst[i].imag(), TOLERANCE); + } + } + + TEST(CartonFft, Test1) { + std::vector src = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + std::vector expected = {{+3.6000e+01f, +0.0000e+00f}, + {-4.0000e+00f, +9.6569e+00f}, + {-4.0000e+00f, +4.0000e+00f}, + {-4.0000e+00f, +1.6569e+00f}, + {-4.0000e+00f, +0.0000e+00f}, + {-4.0000e+00f, -1.6569e+00f}, + {-4.0000e+00f, -4.0000e+00f}, + {-4.0000e+00f, -9.6569e+00f}}; + test_fft<8>(src, expected); + } + + TEST(CartonFft, Test2) { + std::vector src = {6.0f, 1.0f, 7.0f, 2.0f, 7.0f, 4.0f, 8.0f, 7.0f}; + std::vector expected = {{+4.2000e+01f, +0.0000e+00f}, + {+4.1421e-01f, +6.6569e+00f}, + {-2.0000e+00f, +4.0000e+00f}, + {-2.4142e+00f, +4.6569e+00f}, + {+1.4000e+01f, +0.0000e+00f}, + {-2.4142e+00f, -4.6569e+00f}, + {-2.0000e+00f, -4.0000e+00f}, + {+4.1421e-01f, -6.6569e+00f}}; + test_fft<8>(src, expected); + } + + TEST(CartonFft, Test3) { + std::vector src = {1.0f, 2.0f, 3.0f, 4.0f}; + std::vector expected = {{+1.0000e+01f, +0.0000e+00f}, + {-2.0000e+00f, +2.0000e+00f}, + {-2.0000e+00f, +0.0000e+00f}, + {-2.0000e+00f, -2.0000e+00f}}; + test_fft<4>(src, expected); + } + + TEST(CartonFft, Test4) { + std::vector src = {6.0f, 1.0f, 7.0f, 2.0f}; + std::vector expected = {{+1.6000e+01f, +0.0000e+00f}, + {-1.0000e+00f, +1.0000e+00f}, + {+1.0000e+01f, +0.0000e+00f}, + {-1.0000e+00f, -1.0000e+00f}}; + test_fft<4>(src, expected); + } + + TEST(CartonFft, Test5) { + std::vector src = {4.0f, 4.0f, 4.0f, 4.0f}; + std::vector expected = {{+1.6000e+01f, +0.0000e+00f}, + {+0.0000e+00f, +0.0000e+00f}, + {+0.0000e+00f, +0.0000e+00f}, + {+0.0000e+00f, +0.0000e+00f}}; + test_fft<4>(src, expected); + } + + TEST(CartonFft, Test6) { + std::vector src = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f}; + std::vector expected = {{+1.3600e+02f, +0.0000e+00f}, + {-8.0000e+00f, +4.0219e+01f}, + {-8.0000e+00f, +1.9314e+01f}, + {-8.0000e+00f, +1.1973e+01f}, + {-8.0000e+00f, +8.0000e+00f}, + {-8.0000e+00f, +5.3454e+00f}, + {-8.0000e+00f, +3.3137e+00f}, + {-8.0000e+00f, +1.5913e+00f}, + {-8.0000e+00f, +0.0000e+00f}, + {-8.0000e+00f, -1.5913e+00f}, + {-8.0000e+00f, -3.3137e+00f}, + {-8.0000e+00f, -5.3454e+00f}, + {-8.0000e+00f, -8.0000e+00f}, + {-8.0000e+00f, -1.1973e+01f}, + {-8.0000e+00f, -1.9314e+01f}, + {-8.0000e+00f, -4.0219e+01f}}; + test_fft<16>(src, expected); + } + +} // namespace yycctest::carton::fft