1
0
Files
ai-school/exp2/modified/sketchpad.py

191 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from pathlib import Path
import sys
import typing
import tkinter as tk
from predict import PredictResult, Predictor
# Append the directory two levels above this file's own directory to the
# module search path. This must run before the import below.
# NOTE(review): presumably pytorch_gpu_utils lives in that directory — confirm
# against the repository layout.
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import pytorch_gpu_utils
class SketchpadApp:
    """Tkinter sketchpad that hands a hand-drawn cell grid to a predictor.

    The window stacks three areas from top to bottom: a black canvas divided
    into ``IMAGE_HW`` x ``IMAGE_HW`` cells that the user paints white with the
    left mouse button, a two-row table that displays the predictor's result,
    and a row holding the execute and reset buttons.
    """

    # Number of cells along each side of the drawing grid.
    IMAGE_HW: typing.ClassVar[int] = 28
    # Side length of one grid cell on the canvas, in canvas units.
    PIXEL_HW: typing.ClassVar[int] = 15

    root: tk.Tk

    def __init__(self, root: tk.Tk, predictor: Predictor):
        """Build all widgets inside ``root`` and remember the predictor.

        Args:
            root: Tk root window that hosts every widget of the app.
            predictor: Object whose ``predict_sketchpad`` method is called
                with the grid state when the execute button is pressed.
        """
        self.root = root
        self.root.title("看图说数")

        # The frames are packed in creation order, so the order of these
        # calls fixes the top-to-bottom layout: canvas, table, buttons.
        self._build_canvas()
        self._build_table()
        self._build_buttons()

        # Predictor used by execute().
        self.predictor = predictor

    # region: canvas
    canvas: tk.Canvas
    # canvas_data[row][col] is True once the cell has been painted white and
    # False while it still shows the black background.
    canvas_data: list[list[bool]]
    canvas_width: int
    canvas_height: int
    canvas_pixel_count: int
    canvas_pixel_size: int

    def _build_canvas(self) -> None:
        """Create the canvas, its backing grid state, and the mouse bindings."""
        canvas_frame = tk.Frame(self.root)
        canvas_frame.pack(pady=10)

        self.canvas_pixel_count = self.IMAGE_HW
        self.canvas_pixel_size = self.PIXEL_HW
        canvas_hw = self.canvas_pixel_count * self.canvas_pixel_size
        self.canvas_width = canvas_hw
        self.canvas_height = canvas_hw

        self.canvas = tk.Canvas(
            canvas_frame,
            width=self.canvas_width,
            height=self.canvas_height,
            bg='black',
        )
        self.canvas.pack()

        self.canvas_data = self._blank_canvas_data()

        # Paint both on a single click and while dragging with button 1 held.
        self.canvas.bind("<B1-Motion>", self.paint)
        self.canvas.bind("<Button-1>", self.paint)

        self.draw_grid()

    def _blank_canvas_data(self) -> list[list[bool]]:
        """Return a new square grid of ``False`` values, one per canvas cell."""
        side = self.canvas_pixel_count
        return [[False] * side for _ in range(side)]

    def draw_grid(self) -> None:
        """Draw the light-gray cell boundary lines across the whole canvas."""
        for i in range(self.canvas_pixel_count + 1):
            offset = i * self.canvas_pixel_size
            # Vertical line at x == offset.
            self.canvas.create_line(
                offset, 0,
                offset, self.canvas_height,
                fill='lightgray',
            )
            # Horizontal line at y == offset.
            self.canvas.create_line(
                0, offset,
                self.canvas_width, offset,
                fill='lightgray',
            )

    def paint(self, event: tk.Event) -> None:
        """Mark the cell under the mouse pointer as painted and fill it white.

        Args:
            event: Mouse event; its ``x``/``y`` coordinates select the cell.
        """
        col = event.x // self.canvas_pixel_size
        row = event.y // self.canvas_pixel_size

        # Ignore coordinates that fall outside the grid.
        if not (0 <= col < self.canvas_pixel_count
                and 0 <= row < self.canvas_pixel_count):
            return

        # A cell that is already painted needs no second rectangle.
        if self.canvas_data[row][col]:
            return
        self.canvas_data[row][col] = True

        # Fill the cell with a white, outline-free rectangle.
        x1 = col * self.canvas_pixel_size
        y1 = row * self.canvas_pixel_size
        x2 = x1 + self.canvas_pixel_size
        y2 = y1 + self.canvas_pixel_size
        self.canvas.create_rectangle(x1, y1, x2, y2, fill='white', outline='')
    # endregion

    # region: table
    # One label per table column; index 0 is the "guessed number" column and
    # indices 1..10 sit under the per-digit headers for 0..9.
    value_labels: list[tk.Label]

    def _build_table(self) -> None:
        """Create the header row and the value row of the result table."""
        table_frame = tk.Frame(self.root)
        table_frame.pack(pady=10)

        header_words = ("猜测的数字", ) + tuple(f'{i}的概率' for i in range(10))

        # Header row.
        for col, header in enumerate(header_words):
            header_label = tk.Label(
                table_frame,
                text=header,
                relief="solid",
                borderwidth=1,
                width=12,
                height=2,
                bg="lightblue",
            )
            header_label.grid(row=0, column=col, sticky="nsew")

        # Value row. The placeholder text is written by clear_table() below so
        # that it is defined in exactly one place.
        self.value_labels = []
        for col in range(len(header_words)):
            value_label = tk.Label(
                table_frame,
                relief="solid",
                borderwidth=1,
                width=12,
                height=2,
                bg="white",
            )
            value_label.grid(row=1, column=col, sticky="nsew")
            self.value_labels.append(value_label)

        # The first column (guessed number) gets its own background colour.
        self.value_labels[0].config(bg="lightyellow")

        self.clear_table()

    def show_in_table(self, result: PredictResult) -> None:
        """Write a prediction result into the value row of the table.

        Args:
            result: Prediction whose ``chosen_number()`` fills the first
                column and whose ``number_possibilities()`` is indexed by
                position to fill the remaining columns, formatted with four
                decimal places.
        """
        self.value_labels[0].config(text=str(result.chosen_number()))
        number_possibilities = result.number_possibilities()
        for index, label in enumerate(self.value_labels[1:]):
            label.config(text=f'{number_possibilities[index]:.4f}')

    def clear_table(self) -> None:
        """Reset every cell of the value row to the placeholder text."""
        for label in self.value_labels:
            label.config(text='N/A')
    # endregion

    # region: buttons
    predictor: Predictor

    def _build_buttons(self) -> None:
        """Create the execute and reset buttons side by side."""
        button_frame = tk.Frame(self.root)
        button_frame.pack(pady=10)

        execute_button = tk.Button(
            button_frame,
            text="执行",
            command=self.execute,
            bg='lightgreen',
            width=10,
        )
        execute_button.pack(side=tk.LEFT, padx=5)

        reset_button = tk.Button(
            button_frame,
            text="重置",
            command=self.reset,
            bg='lightcoral',
            width=10,
        )
        reset_button.pack(side=tk.LEFT, padx=5)

    def execute(self) -> None:
        """Execute button: pass the grid to the predictor and show the result."""
        prediction = self.predictor.predict_sketchpad(self.canvas_data)
        self.show_in_table(prediction)

    def reset(self) -> None:
        """Reset button: wipe the canvas, the grid state, and the table."""
        self.canvas.delete("all")
        # Rebuild with booleans (not ints) so the grid keeps the same element
        # type that __init__ created and that canvas_data is annotated with.
        self.canvas_data = self._blank_canvas_data()
        # delete("all") also removed the grid lines, so draw them again.
        self.draw_grid()
        self.clear_table()
    # endregion
def main() -> None:
    """Run the sketchpad as a script.

    Calls ``pytorch_gpu_utils.print_gpu_availability()``, creates the
    predictor and the Tk root window, builds the app, and enters the Tk
    event loop until the window is closed.
    """
    pytorch_gpu_utils.print_gpu_availability()
    predictor = Predictor()
    root = tk.Tk()
    # Bound to a local so the app object is referenced for as long as the
    # event loop below is running.
    app = SketchpadApp(root, predictor)
    root.mainloop()


if __name__ == "__main__":
    main()