191 lines
6.1 KiB
Python
191 lines
6.1 KiB
Python
from pathlib import Path
|
||
import sys
|
||
import typing
|
||
import tkinter as tk
|
||
from tkinter import messagebox
|
||
from predict import PredictResult, Predictor
|
||
|
||
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
|
||
import gpu_utils
|
||
|
||
|
||
class SketchpadApp:
    """Tk window for drawing a digit on a cell grid and showing the model's guess.

    The sketchpad is an ``IMAGE_HW`` x ``IMAGE_HW`` grid of cells, each drawn as a
    ``PIXEL_HW``-pixel square on screen. Clicking or dragging with the left mouse
    button turns cells white. The "执行" (run) button passes the cell grid to the
    predictor and fills the result table; "重置" (reset) clears the drawing and
    the table.
    """

    # Side length of the drawing grid, in cells.
    IMAGE_HW: typing.ClassVar[int] = 28
    # On-screen side length of one cell, in screen pixels.
    PIXEL_HW: typing.ClassVar[int] = 15

    root: tk.Tk

    def __init__(self, root: tk.Tk, predictor: Predictor):
        """Build the sketchpad, the result table and the buttons inside ``root``.

        Args:
            root: Top-level Tk window that hosts all widgets.
            predictor: Backend whose ``predict_sketchpad`` is called by ``execute``.
        """
        self.root = root
        self.root.title("看图说数")
        # Store the backend before wiring any callback that may use it.
        self.predictor = predictor

        # Widgets are packed top to bottom in this order.
        self._build_canvas(root)
        self._build_table(root)
        self._build_buttons(root)

    # region: sketchpad

    canvas: tk.Canvas
    # canvas_data[row][col] is True once that cell has been painted (white),
    # False while it is still blank (black).
    canvas_data: list[list[bool]]
    canvas_pixel_count: int
    canvas_pixel_size: int
    canvas_width: int
    canvas_height: int

    def _build_canvas(self, parent: tk.Misc) -> None:
        """Create the black drawing canvas, bind mouse painting and draw the grid."""
        canvas_frame = tk.Frame(parent)
        canvas_frame.pack(pady=10)

        self.canvas_pixel_count = SketchpadApp.IMAGE_HW
        self.canvas_pixel_size = SketchpadApp.PIXEL_HW  # screen pixels per cell
        side = self.canvas_pixel_count * self.canvas_pixel_size
        self.canvas_width = side
        self.canvas_height = side

        self.canvas = tk.Canvas(
            canvas_frame,
            width=self.canvas_width,
            height=self.canvas_height,
            bg='black',
        )
        self.canvas.pack()

        self.canvas_data = self._empty_canvas_data()

        # Paint on a single click as well as while dragging with button 1 held.
        self.canvas.bind("<B1-Motion>", self.paint)
        self.canvas.bind("<Button-1>", self.paint)

        self.draw_grid()

    def _empty_canvas_data(self) -> list[list[bool]]:
        """Return a fresh all-False (nothing drawn) grid; every row is a distinct list."""
        count = self.canvas_pixel_count
        return [[False] * count for _ in range(count)]

    def draw_grid(self) -> None:
        """Draw the light-gray cell boundaries (count + 1 lines in each direction)."""
        for i in range(self.canvas_pixel_count + 1):
            offset = i * self.canvas_pixel_size
            # Vertical line.
            self.canvas.create_line(offset, 0, offset, self.canvas_height, fill='lightgray')
            # Horizontal line.
            self.canvas.create_line(0, offset, self.canvas_width, offset, fill='lightgray')

    def paint(self, event: tk.Event) -> None:
        """Mark the cell under the mouse pointer as drawn and fill it white.

        Events outside the grid (e.g. when dragging past the canvas edge) are
        ignored. A cell that is already drawn is skipped, so repeated motion
        events over the same cell do not stack extra rectangles on the canvas.
        """
        col = event.x // self.canvas_pixel_size
        row = event.y // self.canvas_pixel_size

        in_bounds = 0 <= col < self.canvas_pixel_count and 0 <= row < self.canvas_pixel_count
        # Check bounds first: negative indices would otherwise wrap around the list.
        if not in_bounds or self.canvas_data[row][col]:
            return

        self.canvas_data[row][col] = True

        # White, borderless square covering exactly this cell.
        x1 = col * self.canvas_pixel_size
        y1 = row * self.canvas_pixel_size
        self.canvas.create_rectangle(
            x1, y1,
            x1 + self.canvas_pixel_size, y1 + self.canvas_pixel_size,
            fill='white', outline='',
        )

    # endregion

    # region: result table

    value_labels: list[tk.Label]

    @staticmethod
    def _make_cell(parent: tk.Misc, text: str, bg: str) -> tk.Label:
        """Create one bordered table cell; header and value cells share this look."""
        return tk.Label(
            parent,
            text=text,
            relief="solid",
            borderwidth=1,
            width=12,
            height=2,
            bg=bg,
        )

    def _build_table(self, parent: tk.Misc) -> None:
        """Create the two-row table: header labels, then the value labels."""
        table_frame = tk.Frame(parent)
        table_frame.pack(pady=10)

        # Column 0 is the guessed digit; columns 1..10 hold the probability for digits 0..9.
        header_words = ("猜测的数字", ) + tuple(f'为{i}的概率' for i in range(10))

        for col, header in enumerate(header_words):
            header_label = self._make_cell(table_frame, header, "lightblue")
            header_label.grid(row=0, column=col, sticky="nsew")

        self.value_labels = []
        for col in range(len(header_words)):
            # Highlight the guessed-digit cell to set it apart from the probabilities.
            bg = "lightyellow" if col == 0 else "white"
            value_label = self._make_cell(table_frame, "N/A", bg)
            value_label.grid(row=1, column=col, sticky="nsew")
            self.value_labels.append(value_label)

    def show_in_table(self, result: PredictResult) -> None:
        """Show the chosen digit and the per-digit values from ``result``.

        Assumes ``result.number_possibilities()`` is indexable with at least one
        entry per probability column (10) — TODO confirm against ``predict``.
        """
        self.value_labels[0].config(text=str(result.chosen_number()))

        number_possibilities = result.number_possibilities()
        for index, label in enumerate(self.value_labels[1:]):
            label.config(text=f'{number_possibilities[index]:.4f}')

    def clear_table(self) -> None:
        """Reset every value cell to "N/A"; background colours are left unchanged."""
        for label in self.value_labels:
            label.config(text='N/A')

    # endregion

    # region: buttons

    predictor: Predictor

    def _build_buttons(self, parent: tk.Misc) -> None:
        """Create the run and reset buttons, left to right."""
        button_frame = tk.Frame(parent)
        button_frame.pack(pady=10)

        for text, command, bg in (
            ("执行", self.execute, 'lightgreen'),
            ("重置", self.reset, 'lightcoral'),
        ):
            button = tk.Button(button_frame, text=text, command=command, bg=bg, width=10)
            button.pack(side=tk.LEFT, padx=5)

    def execute(self) -> None:
        """Run the predictor on the current drawing and display the result.

        A failure in the backend is reported in an error dialog (the traceback
        still goes through Tk's normal callback error reporting), and the table
        is cleared so stale numbers from a previous run are not shown.
        """
        try:
            prediction = self.predictor.predict_sketchpad(self.canvas_data)
        except Exception as err:  # GUI callback boundary: surface, don't crash silently
            self.root.report_callback_exception(type(err), err, err.__traceback__)
            self.clear_table()
            messagebox.showerror("预测失败", str(err), parent=self.root)
            return
        self.show_in_table(prediction)

    def reset(self) -> None:
        """Clear the drawing (restoring the empty grid) and the result table."""
        # Deleting "all" also removes the grid lines, so they are redrawn below.
        self.canvas.delete("all")
        # Use bools, matching the initial state and the canvas_data annotation.
        self.canvas_data = self._empty_canvas_data()
        self.draw_grid()
        self.clear_table()

    # endregion
|
||
|
||
if __name__ == "__main__":
    # Shared project helper from gpu_utils; called first, before the model is built.
    gpu_utils.print_gpu_availability()

    # The backend is constructed before the Tk window is created.
    digit_predictor = Predictor()

    main_window = tk.Tk()
    # Keep a reference to the app object for the lifetime of the event loop.
    app = SketchpadApp(main_window, digit_predictor)
    main_window.mainloop()
|