from pathlib import Path
import sys
import typing
import tkinter as tk
from tkinter import messagebox

from predict import PredictResult, Predictor

# Make the directory three levels above this file importable so that the
# shared ``gpu_utils`` module can be found.
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import gpu_utils


class SketchpadApp:
    """Tkinter sketchpad: draw on a pixel grid and show the predictor's result.

    The window contains a square canvas of ``IMAGE_HW`` x ``IMAGE_HW`` cells,
    a one-row result table (guessed number plus one probability per digit
    0-9) and "execute" / "reset" buttons.
    """

    # Number of grid cells along each side of the canvas.
    IMAGE_HW: typing.ClassVar[int] = 28
    # On-screen size (in screen pixels) of a single grid cell.
    PIXEL_HW: typing.ClassVar[int] = 15

    def __init__(self, root: tk.Tk, predictor: Predictor):
        """Build the canvas, the result table and the buttons inside *root*.

        Args:
            root: Top-level Tk window hosting all widgets.
            predictor: Object whose ``predict_sketchpad`` is called with the
                grid state when the "execute" button is pressed.
        """
        self.root = root
        self.root.title("看图说数")

        # --- Sketchpad canvas ---
        canvas_frame = tk.Frame(root)
        canvas_frame.pack(pady=10)

        self.canvas_pixel_count = SketchpadApp.IMAGE_HW
        self.canvas_pixel_size = SketchpadApp.PIXEL_HW  # size of each cell
        canvas_hw = self.canvas_pixel_count * self.canvas_pixel_size
        self.canvas_width = canvas_hw
        self.canvas_height = canvas_hw
        self.canvas = tk.Canvas(
            canvas_frame,
            width=self.canvas_width,
            height=self.canvas_height,
            bg='black',
        )
        self.canvas.pack()

        # Drawing state: False = cell not drawn (black), True = drawn (white).
        self.canvas_data = self._blank_canvas_data()

        # Paint on a single left click and while dragging with the left button
        # held. Tk needs explicit event sequences; an empty one binds nothing.
        self.canvas.bind("<Button-1>", self.paint)
        self.canvas.bind("<B1-Motion>", self.paint)

        self.draw_grid()

        # --- Result table ---
        table_frame = tk.Frame(root)
        table_frame.pack(pady=10)

        # Column headers: the guessed number, then one probability per digit.
        header_words = ("猜测的数字", ) + tuple(f'为{i}的概率' for i in range(10))

        for col, header in enumerate(header_words):
            header_label = tk.Label(
                table_frame,
                text=header,
                relief="solid",
                borderwidth=1,
                width=12,
                height=2,
                bg="lightblue",
            )
            header_label.grid(row=0, column=col, sticky="nsew")

        # Second row holds the values; the first cell (guessed number) is
        # highlighted with a different background.
        self.value_labels = []
        for col in range(len(header_words)):
            value_label = tk.Label(
                table_frame,
                relief="solid",
                borderwidth=1,
                width=12,
                height=2,
                bg="lightyellow" if col == 0 else "white",
            )
            value_label.grid(row=1, column=col, sticky="nsew")
            self.value_labels.append(value_label)

        # Every value cell starts out as "N/A" until a prediction is shown.
        self.clear_table()

        # --- Buttons ---
        button_frame = tk.Frame(root)
        button_frame.pack(pady=10)

        execute_button = tk.Button(
            button_frame,
            text="执行",
            command=self.execute,
            bg='lightgreen',
            width=10,
        )
        execute_button.pack(side=tk.LEFT, padx=5)

        reset_button = tk.Button(
            button_frame,
            text="重置",
            command=self.reset,
            bg='lightcoral',
            width=10,
        )
        reset_button.pack(side=tk.LEFT, padx=5)

        # Backend used by the "execute" button.
        self.predictor = predictor

    # region: sketchpad
    canvas: tk.Canvas
    canvas_data: list[list[bool]]
    canvas_width: int
    canvas_height: int

    def _blank_canvas_data(self) -> list[list[bool]]:
        """Return a fresh all-False grid of ``canvas_pixel_count`` rows/columns."""
        size = self.canvas_pixel_count
        return [[False] * size for _ in range(size)]

    def draw_grid(self) -> None:
        """Draw light-gray grid lines separating the cells of the canvas."""
        for i in range(self.canvas_pixel_count + 1):
            offset = i * self.canvas_pixel_size
            # Vertical line.
            self.canvas.create_line(
                offset, 0, offset, self.canvas_height, fill='lightgray'
            )
            # Horizontal line.
            self.canvas.create_line(
                0, offset, self.canvas_width, offset, fill='lightgray'
            )

    def paint(self, event: tk.Event) -> None:
        """Mark the grid cell under the mouse pointer as drawn and paint it white."""
        # Convert screen coordinates to grid coordinates.
        col = event.x // self.canvas_pixel_size
        row = event.y // self.canvas_pixel_size

        # Ignore events whose coordinates fall outside the grid.
        in_range = (
            0 <= col < self.canvas_pixel_count
            and 0 <= row < self.canvas_pixel_count
        )
        if not in_range:
            return
        # Already-drawn cells are skipped so repeated events don't stack
        # duplicate rectangles on the canvas.
        if self.canvas_data[row][col]:
            return

        self.canvas_data[row][col] = True
        # Fill the cell with a white rectangle.
        x1 = col * self.canvas_pixel_size
        y1 = row * self.canvas_pixel_size
        x2 = x1 + self.canvas_pixel_size
        y2 = y1 + self.canvas_pixel_size
        self.canvas.create_rectangle(x1, y1, x2, y2, fill='white', outline='')
    # endregion

    # region: result table
    value_labels: list[tk.Label]

    def show_in_table(self, result: PredictResult) -> None:
        """Display *result*'s chosen number and per-digit probabilities.

        NOTE(review): assumes ``result.number_possibilities()`` is indexable
        with at least one entry per probability column (digits 0-9).
        """
        self.value_labels[0].config(text=str(result.chosen_number()))
        number_possibilities = result.number_possibilities()
        for index, label in enumerate(self.value_labels[1:]):
            label.config(text=f'{number_possibilities[index]:.4f}')

    def clear_table(self) -> None:
        """Reset every value cell of the table to "N/A"."""
        for label in self.value_labels:
            label.config(text='N/A')
    # endregion

    # region: buttons
    predictor: Predictor

    def execute(self) -> None:
        """"Execute" button: send the grid state to the predictor and show the result."""
        prediction = self.predictor.predict_sketchpad(self.canvas_data)
        self.show_in_table(prediction)

    def reset(self) -> None:
        """"Reset" button: clear the canvas, the grid state and the table."""
        self.canvas.delete("all")
        # Use the same all-False layout as __init__ (not 0s) to keep the
        # declared list[list[bool]] type consistent.
        self.canvas_data = self._blank_canvas_data()
        self.draw_grid()
        self.clear_table()
    # endregion


if __name__ == "__main__":
    gpu_utils.print_gpu_availability()
    predictor = Predictor()
    root = tk.Tk()
    app = SketchpadApp(root, predictor)
    root.mainloop()