finish exp2
This commit is contained in:
190
exp2/modified/sketchpad.py
Normal file
190
exp2/modified/sketchpad.py
Normal file
@@ -0,0 +1,190 @@
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import typing
|
||||
import tkinter as tk
|
||||
from tkinter import messagebox
|
||||
from predict import PredictResult, Predictor
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
|
||||
import gpu_utils
|
||||
|
||||
|
||||
class SketchpadApp:
|
||||
IMAGE_HW: typing.ClassVar[int] = 28
|
||||
PIXEL_HW: typing.ClassVar[int] = 15
|
||||
|
||||
def __init__(self, root: tk.Tk, predictor: Predictor):
|
||||
self.root = root
|
||||
self.root.title("看图说数")
|
||||
|
||||
# 创建画板框架
|
||||
canvas_frame = tk.Frame(root)
|
||||
canvas_frame.pack(pady=10)
|
||||
# 创建图像大小的画板
|
||||
self.canvas_pixel_count = SketchpadApp.IMAGE_HW
|
||||
self.canvas_pixel_size = SketchpadApp.PIXEL_HW # 每个像素的大小
|
||||
canvas_hw = self.canvas_pixel_count * self.canvas_pixel_size
|
||||
self.canvas_width = canvas_hw
|
||||
self.canvas_height = canvas_hw
|
||||
self.canvas = tk.Canvas(
|
||||
canvas_frame,
|
||||
width=self.canvas_width,
|
||||
height=self.canvas_height,
|
||||
bg='black'
|
||||
)
|
||||
self.canvas.pack()
|
||||
# 存储画板状态。False表示没有画(黑色),True表示画了(白色)。
|
||||
self.canvas_data = [[False for _ in range(self.canvas_pixel_count)] for _ in range(self.canvas_pixel_count)]
|
||||
# 绑定鼠标事件
|
||||
self.canvas.bind("<B1-Motion>", self.paint)
|
||||
self.canvas.bind("<Button-1>", self.paint)
|
||||
# 绘制初始网格
|
||||
self.draw_grid()
|
||||
|
||||
# 创建表格框架
|
||||
table_frame = tk.Frame(root)
|
||||
table_frame.pack(pady=10)
|
||||
# 表头数据
|
||||
header_words = ("猜测的数字", ) + tuple(f'为{i}的概率' for i in range(10))
|
||||
# 创建表头
|
||||
for col, header in enumerate(header_words):
|
||||
header_label = tk.Label(
|
||||
table_frame,
|
||||
text=header,
|
||||
relief="solid",
|
||||
borderwidth=1,
|
||||
width=12,
|
||||
height=2,
|
||||
bg="lightblue"
|
||||
)
|
||||
header_label.grid(row=0, column=col, sticky="nsew")
|
||||
# 创建第二行(显示数值的行)
|
||||
self.value_labels = []
|
||||
for col in range(len(header_words)):
|
||||
value_label = tk.Label(
|
||||
table_frame,
|
||||
text="0.00", # 默认显示0.00
|
||||
relief="solid",
|
||||
borderwidth=1,
|
||||
width=12,
|
||||
height=2,
|
||||
bg="white"
|
||||
)
|
||||
value_label.grid(row=1, column=col, sticky="nsew")
|
||||
self.value_labels.append(value_label)
|
||||
# 设置第一列的特殊样式(猜测的数字)
|
||||
self.value_labels[0].config(text="N/A", bg="lightyellow")
|
||||
# 清空样式
|
||||
self.clear_table()
|
||||
|
||||
# 创建按钮框架
|
||||
button_frame = tk.Frame(root)
|
||||
button_frame.pack(pady=10)
|
||||
# 执行按钮
|
||||
execute_button = tk.Button(
|
||||
button_frame,
|
||||
text="执行",
|
||||
command=self.execute,
|
||||
bg='lightgreen',
|
||||
width=10
|
||||
)
|
||||
execute_button.pack(side=tk.LEFT, padx=5)
|
||||
# 重置按钮
|
||||
reset_button = tk.Button(
|
||||
button_frame,
|
||||
text="重置",
|
||||
command=self.reset,
|
||||
bg='lightcoral',
|
||||
width=10
|
||||
)
|
||||
reset_button.pack(side=tk.LEFT, padx=5)
|
||||
# 设置用于执行的predictor
|
||||
self.predictor = predictor
|
||||
|
||||
# region: 画板部分
|
||||
|
||||
canvas: tk.Canvas
|
||||
canvas_data: list[list[bool]]
|
||||
canvas_width: int
|
||||
canvas_height: int
|
||||
|
||||
def draw_grid(self):
|
||||
"""绘制网格线"""
|
||||
for i in range(self.canvas_pixel_count + 1):
|
||||
# 垂直线
|
||||
self.canvas.create_line(
|
||||
i * self.canvas_pixel_size, 0,
|
||||
i * self.canvas_pixel_size, self.canvas_height,
|
||||
fill='lightgray'
|
||||
)
|
||||
# 水平线
|
||||
self.canvas.create_line(
|
||||
0, i * self.canvas_pixel_size,
|
||||
self.canvas_width, i * self.canvas_pixel_size,
|
||||
fill='lightgray'
|
||||
)
|
||||
|
||||
def paint(self, event):
|
||||
"""处理鼠标绘制事件"""
|
||||
# 计算点击的网格坐标
|
||||
col = event.x // self.canvas_pixel_size
|
||||
row = event.y // self.canvas_pixel_size
|
||||
|
||||
# 确保坐标在有效范围内
|
||||
if 0 <= col < self.canvas_pixel_count and 0 <= row < self.canvas_pixel_count:
|
||||
# 更新网格状态
|
||||
if self.canvas_data[row][col] != True:
|
||||
self.canvas_data[row][col] = True
|
||||
|
||||
# 绘制黑色矩形
|
||||
x1 = col * self.canvas_pixel_size
|
||||
y1 = row * self.canvas_pixel_size
|
||||
x2 = x1 + self.canvas_pixel_size
|
||||
y2 = y1 + self.canvas_pixel_size
|
||||
|
||||
self.canvas.create_rectangle(x1, y1, x2, y2, fill='white', outline='')
|
||||
|
||||
# endregion
|
||||
|
||||
# region: 表格部分
|
||||
|
||||
value_labels: list[tk.Label]
|
||||
|
||||
def show_in_table(self, result: PredictResult):
|
||||
self.value_labels[0].config(text=str(result.chosen_number()))
|
||||
|
||||
number_possibilities = result.number_possibilities()
|
||||
for index, label in enumerate(self.value_labels[1:]):
|
||||
label.config(text=f'{number_possibilities[index]:.4f}')
|
||||
|
||||
def clear_table(self):
|
||||
for label in self.value_labels:
|
||||
label.config(text='N/A')
|
||||
|
||||
# endregion
|
||||
|
||||
# region: 按钮部分
|
||||
|
||||
predictor: Predictor
|
||||
|
||||
def execute(self):
|
||||
"""执行按钮功能 - 将画板数据传递给后端"""
|
||||
prediction = self.predictor.predict(self.canvas_data)
|
||||
self.show_in_table(prediction)
|
||||
|
||||
def reset(self):
|
||||
"""重置按钮功能 - 清空画板"""
|
||||
self.canvas.delete("all")
|
||||
self.canvas_data = [[0 for _ in range(self.canvas_pixel_count)] for _ in range(self.canvas_pixel_count)]
|
||||
self.draw_grid()
|
||||
self.clear_table()
|
||||
|
||||
# endregion
|
||||
|
||||
if __name__ == "__main__":
|
||||
gpu_utils.print_gpu_availability()
|
||||
predictor = Predictor()
|
||||
|
||||
root = tk.Tk()
|
||||
app = SketchpadApp(root, predictor)
|
||||
root.mainloop()
|
||||
Reference in New Issue
Block a user