From 55ed60c643ebcf5fb8891d8eee0cd63c6a83a390 Mon Sep 17 00:00:00 2001 From: yyc12345 Date: Thu, 8 Jan 2026 20:25:33 +0800 Subject: [PATCH] fix bug --- .../Plugins/Engine/DirectX11Engine/main.cpp | 36 +++++++++---------- BasaltPresenter/Presenter/cmd_client.hpp | 2 +- BasaltPresenter/Presenter/dll_loader.cpp | 2 +- BasaltPresenter/Presenter/dll_loader.hpp | 10 +++--- BasaltPresenter/Presenter/main.cpp | 23 ++++++++---- BasaltTrainer/cmd_server.py | 30 +++++----------- BasaltTrainer/pyproject.toml | 2 +- 7 files changed, 50 insertions(+), 55 deletions(-) diff --git a/BasaltPresenter/Plugins/Engine/DirectX11Engine/main.cpp b/BasaltPresenter/Plugins/Engine/DirectX11Engine/main.cpp index 4ad42e5..fde2a53 100644 --- a/BasaltPresenter/Plugins/Engine/DirectX11Engine/main.cpp +++ b/BasaltPresenter/Plugins/Engine/DirectX11Engine/main.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include #include #include @@ -28,10 +28,9 @@ static LRESULT CALLBACK WndProc(HWND hwnd, UINT msg, WPARAM wParam, LPARAM lPara } // Create a render window for DirectX -static HWND CreateRenderWindow(std::uint32_t width, std::uint32_t height, const std::wstring_view& title) { +static HWND CreateRenderWindow(std::uint32_t width, std::uint32_t height) { static bool g_CLSREG = false; constexpr wchar_t class_name[] = L"DirectXRenderWindowClass"; - std::wstring c_title(title); if (!g_CLSREG) { WNDCLASSEXW wc = {0}; @@ -55,7 +54,7 @@ static HWND CreateRenderWindow(std::uint32_t width, std::uint32_t height, const HWND hwnd = CreateWindowExW(0, class_name, - c_title.c_str(), + L"DirectXRenderWindow", WS_OVERLAPPEDWINDOW ^ WS_THICKFRAME ^ WS_MAXIMIZEBOX, CW_USEDEFAULT, CW_USEDEFAULT, @@ -161,8 +160,9 @@ const char* g_PS = R"( } )"; -using ::Basalt::Shared::Kernel::EngineConfig; -using ::Basalt::Shared::Kernel::IEngine; +namespace engine = ::basalt::shared::engine; +using engine::EngineConfig; +using engine::IEngine; class DirectX11Engine : public IEngine { public: @@ -197,7 +197,7 @@ public: if (this->config.headless) { window = NULL; } else { - window = CreateRenderWindow(this->config.width, this->config.height, this->config.title); + window = CreateRenderWindow(this->config.width, this->config.height); ShowWindow(window, SW_SHOW); UpdateWindow(window); } @@ -314,6 +314,17 @@ public: // return; //} + // 设置管线 + context->OMSetRenderTargets(1, rtv.GetAddressOf(), dsv.Get()); + context->OMSetDepthStencilState(depth_state.Get(), 1); + context->RSSetState(rasterizer_state.Get()); + context->VSSetShader(vs.Get(), nullptr, 0); + context->PSSetShader(ps.Get(), nullptr, 0); + context->IASetInputLayout(input_layout.Get()); + UINT stride = sizeof(Vertex), offset = 0; + context->IASetVertexBuffers(0, 1, vertex_buffer.GetAddressOf(), &stride, &offset); + context->IASetPrimitiveTopology(D3D11_PRIMITIVE_TOPOLOGY_TRIANGLELIST); + // 缩放深度数据数组到指定大小 depth_data.resize(this->config.width * this->config.height * sizeof(float)); @@ -331,17 +342,6 @@ public: context->ClearRenderTargetView(rtv.Get(), clear_color); context->ClearDepthStencilView(dsv.Get(), D3D11_CLEAR_DEPTH, 1.0f, 0); - // 设置管线 - context->OMSetRenderTargets(1, rtv.GetAddressOf(), dsv.Get()); - context->OMSetDepthStencilState(depth_state.Get(), 1); - context->RSSetState(rasterizer_state.Get()); - context->VSSetShader(vs.Get(), nullptr, 0); - context->PSSetShader(ps.Get(), nullptr, 0); - context->IASetInputLayout(input_layout.Get()); - UINT stride = sizeof(Vertex), offset = 0; - context->IASetVertexBuffers(0, 1, vertex_buffer.GetAddressOf(), &stride, &offset); - context->IASetPrimitiveTopology(D3D11_PRIMITIVE_TOPOLOGY_TRIANGLELIST); - // 绘制立方体 context->Draw(sizeof(CubeVertices) / sizeof(Vertex), 0); // 自动计算顶点数 diff --git a/BasaltPresenter/Presenter/cmd_client.hpp b/BasaltPresenter/Presenter/cmd_client.hpp index e1e01e7..749a003 100644 --- a/BasaltPresenter/Presenter/cmd_client.hpp +++ b/BasaltPresenter/Presenter/cmd_client.hpp @@ -11,7 +11,7 @@ namespace basalt::presenter::cmd_client { DATA_READY = 0x01, ///< Presenter -> Trainer DATA_RECEIVED = 0x02, ///< Trainer -> Presenter ACTIVELY_STOP = 0x21, ///< Presenter-->Trainer - STOP_REQUEST = 0X71, ///< Presenter<--Trainer + STOP_REQUEST = 0x71, ///< Presenter<--Trainer STOP_RESPONSE = 0x72, ///< Presenter-->Trainer }; diff --git a/BasaltPresenter/Presenter/dll_loader.cpp b/BasaltPresenter/Presenter/dll_loader.cpp index 5c4c9a1..c8dd4d3 100644 --- a/BasaltPresenter/Presenter/dll_loader.cpp +++ b/BasaltPresenter/Presenter/dll_loader.cpp @@ -77,7 +77,7 @@ namespace basalt::presenter::dll_loader { } } - void *DllLoader::GetFunctionPointer(const char *name) { + void *DllLoader::get_function_pointer(const char *name) { if (!m_Handle) throw std::runtime_error("Can not fetch function pointer on not loaded dynamic library."); #if defined(BASALT_OS_WINDOWS) return (void *) GetProcAddress(m_Handle, name); diff --git a/BasaltPresenter/Presenter/dll_loader.hpp b/BasaltPresenter/Presenter/dll_loader.hpp index dd35d6d..0e8881d 100644 --- a/BasaltPresenter/Presenter/dll_loader.hpp +++ b/BasaltPresenter/Presenter/dll_loader.hpp @@ -28,21 +28,21 @@ namespace basalt::presenter::dll_loader { ~DllLoader(); private: - void* GetFunctionPointer(const char* name); + void* get_function_pointer(const char* name); public: template - T* CreateInstance() { + T* create_instance() { using Fct = T* (*) (); constexpr char EXPOSE_FUNC_NAME[] = "BSCreateInstance"; - auto fct = (Fct) GetFunctionPointer(EXPOSE_FUNC_NAME); + auto fct = (Fct) get_function_pointer(EXPOSE_FUNC_NAME); return fct(); } template - void DestroyInstance(T* instance) { + void destroy_instance(T* instance) { using Fct = void (*)(T*); constexpr char EXPOSE_FUNC_NAME[] = "BSDestroyInstance"; - auto fct = (Fct) GetFunctionPointer(EXPOSE_FUNC_NAME); + auto fct = (Fct) get_function_pointer(EXPOSE_FUNC_NAME); fct(instance); } diff --git a/BasaltPresenter/Presenter/main.cpp b/BasaltPresenter/Presenter/main.cpp index 3b947e7..c67becd 100644 --- a/BasaltPresenter/Presenter/main.cpp +++ b/BasaltPresenter/Presenter/main.cpp @@ -1,6 +1,15 @@ #include "dll_loader.hpp" #include "cmd_client.hpp" #include +#include +#include + +namespace engine = ::basalt::shared::engine; +using engine::EngineConfig; +using engine::IEngine; +namespace deliver = ::basalt::shared::deliver; +using deliver::DeliverConfig; +using deliver::IDeliver; namespace dll_loader = ::basalt::presenter::dll_loader; using dll_loader::DllKind; @@ -15,18 +24,18 @@ int main(int argc, char* argv[]) { auto client = CmdClient(); auto payload = client.wait_handshake(); - //auto* engine = engine_dll.CreateInstance(); - //auto* deliver = deliver_dll.CreateInstance(); + auto* engine = engine_dll.create_instance(); + //auto* deliver = deliver_dll.create_instance(); - //Kernel::EngineConfig engine_config{.headless = false, .title = BSTEXT("Fuck You"), .width = payload.width, .height = payload.height}; - //engine->startup(std::move(engine_config)); + EngineConfig engine_config{.headless = false, .width = payload.width, .height = payload.height}; + engine->startup(std::move(engine_config)); while (true) { - auto req_stop = false; //engine->tick(); + auto req_stop = engine->tick(); auto can_stop = client.tick(req_stop); if (can_stop) break; } - //engine->shutdown(); - //engine_dll.DestroyInstance(engine); + engine->shutdown(); + engine_dll.destroy_instance(engine); } diff --git a/BasaltTrainer/cmd_server.py b/BasaltTrainer/cmd_server.py index 61bd318..5da2b9d 100644 --- a/BasaltTrainer/cmd_server.py +++ b/BasaltTrainer/cmd_server.py @@ -92,10 +92,10 @@ class CmdServer: if self.__status != ServerStatus.Running: raise RuntimeError("unexpected server status") - # If there is stop requested, we post it first and return + # If there is stop requested from us, + # we order Presenter exit and enter next step. if request_stop: - self.__wait_stop() - return True + self.__pipe_operator.write(CODE_PACKER.pack(ProtocolCode.STOP_REQUEST)) while True: # Wait for code from Presenter @@ -115,26 +115,12 @@ class CmdServer: case ProtocolCode.ACTIVELY_STOP: # Presenter requested stop. # Agree with it, send code and wait response - self.__wait_stop() - return True - case _: - raise RuntimeError("unexpected protocol code when running") - - def __wait_stop(self) -> None: - # Send stop request code - self.__pipe_operator.write(CODE_PACKER.pack(ProtocolCode.STOP_REQUEST)) - - # Wait stop response code - while True: - # Accept code - code_bytes = self.__pipe_operator.read(CODE_PACKER.size) - (code,) = CODE_PACKER.unpack(code_bytes) - - # Check whether it is stop response - match ProtocolCode(code): + self.__pipe_operator.write( + CODE_PACKER.pack(ProtocolCode.STOP_REQUEST) + ) case ProtocolCode.STOP_RESPONSE: # Set self status and return self.__status = ServerStatus.Stop - return + return True case _: - raise RuntimeError("unexpected protocol code when waiting quit") + raise RuntimeError("unexpected protocol code when running") diff --git a/BasaltTrainer/pyproject.toml b/BasaltTrainer/pyproject.toml index bddcb84..e648160 100644 --- a/BasaltTrainer/pyproject.toml +++ b/BasaltTrainer/pyproject.toml @@ -27,4 +27,4 @@ url = "https://download.pytorch.org/whl/cu126" explicit = true [too.ruff] -line-length = 88 +line-length = 130