update for change of exp2 and add exp3
This commit is contained in:
114
exp2/modified/mnist.py
Normal file
114
exp2/modified/mnist.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from pathlib import Path
|
||||
import numpy
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision.transforms import v2 as tvtrans
|
||||
from torchvision import datasets
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class CNN(torch.nn.Module):
    """Convolutional neural network for MNIST digit classification (10 classes)."""

    def __init__(self):
        super(CNN, self).__init__()

        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(3, 3))
        self.pool1 = torch.nn.MaxPool2d(kernel_size=(2, 2))
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(3, 3))
        self.pool2 = torch.nn.MaxPool2d(kernel_size=(2, 2))
        self.conv3 = torch.nn.Conv2d(64, 64, kernel_size=(3, 3))
        self.flatten = torch.nn.Flatten()
        # 28x28 -> 26x26 after conv1, -> 13x13 after pool1,
        # -> 11x11 after conv2, -> 5x5 after pool2,
        # -> 3x3 after conv3; the last conv has 64 output channels,
        # hence the 64 * 3 * 3 flattened feature size below.
        self.fc1 = torch.nn.Linear(64 * 3 * 3, 64)
        torch.nn.init.xavier_normal_(self.fc1.weight)
        torch.nn.init.zeros_(self.fc1.bias)
        self.fc2 = torch.nn.Linear(64, 10)
        torch.nn.init.xavier_normal_(self.fc2.weight)
        torch.nn.init.zeros_(self.fc2.bias)

    def forward(self, x):
        """Map a batch of 1-channel 28x28 images to a (N, 10) probability tensor."""
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = F.relu(self.conv3(x))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # NOTE(review): softmax is applied here, yet the training code elsewhere
        # uses CrossEntropyLoss, which expects raw logits (it applies
        # log-softmax internally). Applying softmax before that loss degrades
        # gradients — confirm whether returning logits is intended instead.
        return F.softmax(x, dim=1)
|
||||
|
||||
|
||||
class MnistDataset(Dataset):
    """Custom ``Dataset`` serving MNIST images and labels held in memory."""

    # Number of samples in the dataset.
    shape: int
    # Transform applied to each raw image on access.
    transform: tvtrans.Transform
    # Raw images, one per row along axis 0.
    images_data: numpy.ndarray
    # Labels aligned with ``images_data``.
    labels_data: torch.Tensor

    def __init__(self, images: numpy.ndarray, labels: numpy.ndarray, transform: tvtrans.Transform):
        n_images, n_labels = images.shape[0], labels.shape[0]
        # Images and labels must pair up one-to-one.
        assert (n_images == n_labels)
        self.shape = n_images

        self.images_data = images
        self.labels_data = torch.from_numpy(labels)
        self.transform = transform

    def __getitem__(self, index):
        """Return the transformed image and its label at ``index``."""
        raw = self.images_data[index]
        return self.transform(raw), self.labels_data[index]

    def __len__(self):
        """Return the number of samples."""
        return self.shape
|
||||
|
||||
|
||||
class MnistDataSource:
    """Builds train/test ``DataLoader``s from the local MNIST ``.npz`` archive."""

    # Loader over the 60k training samples.
    train_loader: DataLoader
    # Loader over the 10k test samples.
    test_loader: DataLoader

    def __init__(self, batch_size: int):
        """Load ``../datasets/mnist.npz`` and wrap it into two DataLoaders.

        :param batch_size: batch size used by both loaders.
        """
        dataset_path = Path(__file__).resolve().parent.parent / 'datasets' / 'mnist.npz'
        dataset = numpy.load(dataset_path)

        # All images are white digits on a black background.
        # 60k training images: 60000x28x28; labels are 1-D.
        train_images: numpy.ndarray = dataset['x_train']
        train_labels: numpy.ndarray = dataset['y_train']
        # 10k test images: 10000x28x28; labels are 1-D.
        test_images: numpy.ndarray = dataset['x_test']
        test_labels: numpy.ndarray = dataset['y_test']

        # Transform pipeline applied to every image.
        trans = tvtrans.Compose([
            # Convert uint8 -> float32, scaled into the 0-1 range.
            # tvtrans.ToTensor(),
            tvtrans.ToImage(),
            tvtrans.ToDtype(torch.float32, scale=True),

            # To match the expected input channel layout, a new trailing
            # dimension would be added here.
            #tvtrans.Lambda(lambda x: x.unsqueeze(-1))

            # (0.1307, 0.3081) are the MNIST-specific normalisation constants:
            # the global mean and standard deviation of the MNIST training set.
            # Normalisation helps numeric stability and convergence.
            #tvtrans.Normalize((0.1307,), (0.3081,)),
        ])

        # Build the datasets.
        train_dataset = MnistDataset(train_images, train_labels, transform=trans)
        test_dataset = MnistDataset(test_images, test_labels, transform=trans)


        # Keep the loaders on the instance.
        # NOTE(review): shuffle=False on the *training* loader means batches
        # arrive in dataset order every epoch — confirm this is intentional.
        self.train_loader = DataLoader(dataset=train_dataset,
                                       batch_size=batch_size,
                                       shuffle=False)
        self.test_loader = DataLoader(dataset=test_dataset,
                                      batch_size=batch_size,
                                      shuffle=False)
|
||||
@@ -1,7 +1,10 @@
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import torch
|
||||
from train import CNN
|
||||
import numpy
|
||||
from PIL import Image, ImageFile
|
||||
import matplotlib.pyplot as plt
|
||||
from mnist import CNN
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
|
||||
import gpu_utils
|
||||
@@ -36,7 +39,7 @@ class Predictor:
|
||||
file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth'
|
||||
self.cnn.load_state_dict(torch.load(file_path))
|
||||
|
||||
def predict(self, image: list[list[bool]]) -> PredictResult:
|
||||
def predict_sketchpad(self, image: list[list[bool]]) -> PredictResult:
|
||||
input = torch.Tensor(image).float().to(self.device)
|
||||
assert(input.dim() == 2)
|
||||
assert(input.size(0) == 28)
|
||||
@@ -51,4 +54,42 @@ class Predictor:
|
||||
with torch.no_grad():
|
||||
output = self.cnn(input)
|
||||
return PredictResult(output)
|
||||
|
||||
|
||||
def predict_image(self, image: ImageFile.ImageFile) -> PredictResult:
    """Run the CNN on a 28x28 PIL image and return the wrapped prediction.

    :param image: source image; assumed to be exactly 28x28 pixels —
        TODO confirm callers guarantee this, reshape raises otherwise.
    :return: the network output wrapped in a ``PredictResult``.
    """
    # Ensure the image is greyscale, then convert it to a numpy array.
    # The array view over a PIL image is read-only, so request a copy.
    # NOTE(review): the ``copy`` keyword of numpy.reshape requires
    # NumPy >= 2.1 — confirm the pinned NumPy version.
    grayscale_image = image.convert('L')
    numpy_data = numpy.reshape(grayscale_image, (28, 28), copy=True)
    # Move to a float tensor on the compute device.
    data = torch.from_numpy(numpy_data).float().to(self.device)
    # Scale to 0-1, then invert: the input is black digits on a white
    # background, the opposite of the training data.
    data.div_(255.0).sub_(1).mul_(-1)

    # Add batch and channel dimensions, then predict without gradients.
    input = data.unsqueeze(0).unsqueeze(0)
    with torch.no_grad():
        output = self.cnn(input)
    return PredictResult(output)
|
||||
|
||||
def main():
    """Predict every PNG in the sibling ``test_images`` directory and plot each result."""
    predictor = Predictor()

    # Walk all PNG files in the test-image directory and run each one
    # through the model, showing the image with its predicted digit.
    test_dir = Path(__file__).resolve().parent.parent / 'test_images'
    for image_path in test_dir.glob('*.png'):
        if not image_path.is_file():
            continue
        print(f'Predicting {image_path} ...')
        image = Image.open(image_path)
        rv = predictor.predict_image(image)

        print(f'Predict digit: {rv.chosen_number()}')
        plt.figure(f'Image - {image_path}')
        plt.imshow(image)
        plt.axis('on')
        plt.title(f'Predict digit: {rv.chosen_number()}')
        plt.show()


if __name__ == "__main__":
    main()
|
||||
|
||||
|
||||
@@ -169,7 +169,7 @@ class SketchpadApp:
|
||||
|
||||
def execute(self):
|
||||
"""执行按钮功能 - 将画板数据传递给后端"""
|
||||
prediction = self.predictor.predict(self.canvas_data)
|
||||
prediction = self.predictor.predict_sketchpad(self.canvas_data)
|
||||
self.show_in_table(prediction)
|
||||
|
||||
def reset(self):
|
||||
|
||||
@@ -1,145 +1,57 @@
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import typing
|
||||
import numpy
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision.transforms import v2 as tvtrans
|
||||
import matplotlib.pyplot as plt
|
||||
import torch.nn.functional as F
|
||||
import torchinfo
|
||||
import ignite.engine
|
||||
import ignite.metrics
|
||||
from ignite.engine import Engine, Events
|
||||
from ignite.handlers.tqdm_logger import ProgressBar
|
||||
from mnist import CNN, MnistDataSource
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
|
||||
import gpu_utils
|
||||
|
||||
|
||||
class CNN(torch.nn.Module):
|
||||
"""卷积神经网络模型"""
|
||||
|
||||
def __init__(self):
|
||||
super(CNN, self).__init__()
|
||||
|
||||
self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(3, 3))
|
||||
self.pool1 = torch.nn.MaxPool2d(kernel_size=(2, 2))
|
||||
self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(3, 3))
|
||||
self.pool2 = torch.nn.MaxPool2d(kernel_size=(2, 2))
|
||||
self.conv3 = torch.nn.Conv2d(64, 64, kernel_size=(3, 3))
|
||||
self.flatten = torch.nn.Flatten()
|
||||
# 28x28过第一轮卷积后变为26x26,过第一轮池化后变为13x13
|
||||
# 过第二轮卷积后变为11x11,过第二轮池化后变为5x5
|
||||
# 过第三轮卷积后变为3x3。
|
||||
# 最后一轮卷积核个数为64。
|
||||
self.fc1 = torch.nn.Linear(64 * 3 * 3, 64)
|
||||
self.fc2 = torch.nn.Linear(64, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.conv1(x))
|
||||
x = self.pool1(x)
|
||||
x = F.relu(self.conv2(x))
|
||||
x = self.pool2(x)
|
||||
x = F.relu(self.conv3(x))
|
||||
x = self.flatten(x)
|
||||
x = F.relu(self.fc1(x))
|
||||
x = self.fc2(x)
|
||||
return F.softmax(x, dim=1)
|
||||
|
||||
|
||||
class MnistDataset(Dataset):
|
||||
"""用于加载Mnist的自定义数据集"""
|
||||
|
||||
shape: int
|
||||
transform: tvtrans.Transform
|
||||
images_data: numpy.ndarray
|
||||
labels_data: torch.Tensor
|
||||
|
||||
def __init__(self, images: numpy.ndarray, labels: numpy.ndarray, transform: tvtrans.Transform):
|
||||
images_len: int = images.size(0)
|
||||
labels_len: int = labels.size(0)
|
||||
assert (images_len == labels_len)
|
||||
self.shape = images_len
|
||||
|
||||
self.images_data = images
|
||||
self.labels_data = torch.from_numpy(labels)
|
||||
self.transform = transform
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.transform(self.images_data[index]), self.labels_data[index]
|
||||
|
||||
def __len__(self):
|
||||
return self.shape
|
||||
|
||||
|
||||
class DataSource:
|
||||
"""用于读取MNIST数据的数据读取器"""
|
||||
|
||||
train_data: DataLoader
|
||||
test_data: DataLoader
|
||||
|
||||
def __init__(self, batch_size: int):
|
||||
datasets_path = Path(__file__).resolve().parent.parent / 'datasets' / 'mnist.npz'
|
||||
datasets = numpy.load(datasets_path)
|
||||
|
||||
# 所有图片均为黑底白字
|
||||
# 6万张训练图片:60000x28x28。标签只有第一维。
|
||||
train_images = datasets['x_train']
|
||||
train_labels = datasets['y_train']
|
||||
# 1万张测试图片:10000x28x28。标签只有第一维。
|
||||
test_images = datasets['x_test']
|
||||
test_labels = datasets['y_test']
|
||||
|
||||
# 定义数据转换器
|
||||
trans = tvtrans.Compose([
|
||||
# 从uint8转换为float32并自动归一化到0-1区间
|
||||
# tvtrans.ToTensor(),
|
||||
tvtrans.ToImage(),
|
||||
tvtrans.ToDtype(torch.float32, scale=True),
|
||||
# 为了符合后面图像的输入颜色通道条件,要在最后挤出一个新的维度
|
||||
#tvtrans.Lambda(lambda x: x.unsqueeze(-1))
|
||||
])
|
||||
|
||||
# 创建数据集
|
||||
train_dataset = MnistDataset(train_images,
|
||||
train_labels,
|
||||
transform=trans)
|
||||
test_dataset = MnistDataset(test_images, test_labels, transform=trans)
|
||||
|
||||
# 赋值到自身
|
||||
self.train_data = DataLoader(dataset=train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False)
|
||||
self.test_data = DataLoader(dataset=test_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False)
|
||||
|
||||
|
||||
class Trainer:
|
||||
N_EPOCH: typing.ClassVar[int] = 5
|
||||
N_BATCH_SIZE: typing.ClassVar[int] = 1000
|
||||
|
||||
device: torch.device
|
||||
data_source: DataSource
|
||||
cnn: CNN
|
||||
data_source: MnistDataSource
|
||||
model: CNN
|
||||
|
||||
def __init__(self):
|
||||
self.device = gpu_utils.get_gpu_device()
|
||||
self.data_source = DataSource(Trainer.N_BATCH_SIZE)
|
||||
self.cnn = CNN().to(self.device)
|
||||
self.data_source = MnistDataSource(Trainer.N_BATCH_SIZE)
|
||||
self.model = CNN().to(self.device)
|
||||
# 展示模型结构。批次为指定批次数量,通道只有一个灰度通道,大小28x28。
|
||||
torchinfo.summary(self.model, (Trainer.N_BATCH_SIZE, 1, 28, 28))
|
||||
|
||||
def train(self):
|
||||
optimizer = torch.optim.Adam(self.cnn.parameters())
|
||||
optimizer = torch.optim.Adam(self.model.parameters(), eps=1e-7)
|
||||
# optimizer = torch.optim.AdamW(
|
||||
# self.model.parameters(),
|
||||
# lr=0.001, # 两者默认学习率都是 0.001
|
||||
# betas=(0.9, 0.999), # 两者默认值相同
|
||||
# eps=1e-07, # 【关键】匹配 TensorFlow 的默认 epsilon
|
||||
# weight_decay=0.0, # 两者默认都是 0
|
||||
# amsgrad=False # 两者默认都是 False
|
||||
# )
|
||||
loss_func = torch.nn.CrossEntropyLoss()
|
||||
|
||||
for epoch in range(Trainer.N_EPOCH):
|
||||
self.cnn.train()
|
||||
self.model.train()
|
||||
|
||||
batch_images: torch.Tensor
|
||||
batch_labels: torch.Tensor
|
||||
for batch_index, (batch_images, batch_labels) in enumerate(self.data_source.train_data):
|
||||
for batch_index, (batch_images, batch_labels) in enumerate(self.data_source.train_loader):
|
||||
gpu_images = batch_images.to(self.device)
|
||||
gpu_labels = batch_labels.to(self.device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
prediction: torch.Tensor = self.cnn(gpu_images)
|
||||
prediction: torch.Tensor = self.model(gpu_images)
|
||||
loss: torch.Tensor = loss_func(prediction, gpu_labels)
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
@@ -151,20 +63,20 @@ class Trainer:
|
||||
file_dir_path = Path(__file__).resolve().parent.parent / 'models'
|
||||
file_dir_path.mkdir(parents=True, exist_ok=True)
|
||||
file_path = file_dir_path / 'cnn.pth'
|
||||
torch.save(self.cnn.state_dict(), file_path)
|
||||
torch.save(self.model.state_dict(), file_path)
|
||||
print(f'模型已保存至:{file_path}')
|
||||
|
||||
def test(self):
|
||||
self.cnn.eval()
|
||||
self.model.eval()
|
||||
correct_sum = 0
|
||||
total_sum = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for batch_images, batch_labels in self.data_source.test_data:
|
||||
for batch_images, batch_labels in self.data_source.test_loader:
|
||||
gpu_images = batch_images.to(self.device)
|
||||
gpu_labels = batch_labels.to(self.device)
|
||||
|
||||
possibilities: torch.Tensor = self.cnn(gpu_images)
|
||||
possibilities: torch.Tensor = self.model(gpu_images)
|
||||
# 输出出来是10个数字各自的可能性,所以要选取最高可能性的那个对比
|
||||
# 在dim=1上找最大的那个,就选那个。dim=0是批次所以不管他。
|
||||
_, prediction = possibilities.max(1)
|
||||
@@ -175,13 +87,105 @@ class Trainer:
|
||||
test_acc = 100. * correct_sum / total_sum
|
||||
print(f"准确率: {test_acc:.4f}%,共测试了{total_sum}张图片")
|
||||
|
||||
|
||||
def main():
|
||||
trainer = Trainer()
|
||||
trainer.train()
|
||||
trainer.save()
|
||||
trainer.test()
|
||||
|
||||
# N_EPOCH: int = 5
|
||||
# N_BATCH_SIZE: int = 1000
|
||||
# N_LOG_INTERVAL: int = 10
|
||||
|
||||
# class Trainer:
|
||||
|
||||
# device: torch.device
|
||||
# data_source: MnistDataSource
|
||||
# model: CNN
|
||||
|
||||
# trainer: Engine
|
||||
# train_evaluator: Engine
|
||||
# test_evaluator: Engine
|
||||
|
||||
# def __init__(self):
|
||||
# self.device = gpu_utils.get_gpu_device()
|
||||
# self.model = CNN().to(self.device)
|
||||
# self.data_source = MnistDataSource(batch_size=N_BATCH_SIZE)
|
||||
# # 展示模型结构。批次为指定批次数量,通道只有一个灰度通道,大小28x28。
|
||||
# torchinfo.summary(self.model, (N_BATCH_SIZE, 1, 28, 28))
|
||||
|
||||
# #optimizer = torch.optim.Adam(self.model.parameters(), eps=1e-7)
|
||||
# optimizer = torch.optim.AdamW(
|
||||
# self.model.parameters(),
|
||||
# lr=0.001, # 两者默认学习率都是 0.001
|
||||
# betas=(0.9, 0.999), # 两者默认值相同
|
||||
# eps=1e-07, # 【关键】匹配 TensorFlow 的默认 epsilon
|
||||
# weight_decay=0.0, # 两者默认都是 0
|
||||
# amsgrad=False # 两者默认都是 False
|
||||
# )
|
||||
# criterion = torch.nn.CrossEntropyLoss()
|
||||
# self.trainer = ignite.engine.create_supervised_trainer(
|
||||
# self.model, optimizer, criterion, self.device
|
||||
# )
|
||||
|
||||
# eval_metrics = {
|
||||
# "accuracy": ignite.metrics.Accuracy(device=self.device),
|
||||
# "loss": ignite.metrics.Loss(criterion, device=self.device)
|
||||
# }
|
||||
# self.train_evaluator = ignite.engine.create_supervised_evaluator(
|
||||
# self.model, metrics=eval_metrics, device=self.device)
|
||||
# self.test_evaluator = ignite.engine.create_supervised_evaluator(
|
||||
# self.model, metrics=eval_metrics, device=self.device)
|
||||
|
||||
# self.trainer.add_event_handler(
|
||||
# Events.ITERATION_COMPLETED(every=N_LOG_INTERVAL),
|
||||
# lambda engine: self.log_intrain_loss(engine)
|
||||
# )
|
||||
# self.trainer.add_event_handler(
|
||||
# Events.EPOCH_COMPLETED,
|
||||
# lambda trainer: self.log_train_results(trainer)
|
||||
# )
|
||||
# self.trainer.add_event_handler(
|
||||
# Events.COMPLETED,
|
||||
# lambda _: self.log_test_results()
|
||||
# )
|
||||
# self.trainer.add_event_handler(
|
||||
# Events.COMPLETED,
|
||||
# lambda _: self.save_model()
|
||||
# )
|
||||
|
||||
# progressbar = ProgressBar()
|
||||
# progressbar.attach(self.trainer)
|
||||
|
||||
# def run(self):
|
||||
# self.trainer.run(self.data_source.train_loader, max_epochs=N_EPOCH)
|
||||
|
||||
# def log_intrain_loss(self, engine: Engine):
|
||||
# print(f"Epoch: {engine.state.epoch}, Loss: {engine.state.output:.4f}\r", end="")
|
||||
|
||||
# def log_train_results(self, trainer: Engine):
|
||||
# self.train_evaluator.run(self.data_source.train_loader)
|
||||
# metrics = self.train_evaluator.state.metrics
|
||||
# print()
|
||||
# print(f"Training - Epoch: {trainer.state.epoch}, Avg Accuracy: {metrics['accuracy']:.4f}, Avg Loss: {metrics['loss']:.4f}")
|
||||
|
||||
# def log_test_results(self):
|
||||
# self.test_evaluator.run(self.data_source.test_loader)
|
||||
# metrics = self.test_evaluator.state.metrics
|
||||
# print(f"Test - Avg Accuracy: {metrics['accuracy']:.4f} Avg Loss: {metrics['loss']:.4f}")
|
||||
|
||||
# def save_model(self):
|
||||
# file_dir_path = Path(__file__).resolve().parent.parent / 'models'
|
||||
# file_dir_path.mkdir(parents=True, exist_ok=True)
|
||||
# file_path = file_dir_path / 'cnn.pth'
|
||||
# torch.save(self.model.state_dict(), file_path)
|
||||
# print(f'Model was saved into: {file_path}')
|
||||
|
||||
# def main():
|
||||
# trainer = Trainer()
|
||||
# trainer.run()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
gpu_utils.print_gpu_availability()
|
||||
|
||||
2
exp2/test_images/.gitignore
vendored
Normal file
2
exp2/test_images/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
# Ignore all test images
|
||||
*.png
|
||||
2
exp3/datasets/.gitignore
vendored
Normal file
2
exp3/datasets/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
# Ignore datasets and processed datasets
|
||||
*.txt
|
||||
2
exp3/models/.gitignore
vendored
Normal file
2
exp3/models/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
# Ignore every saved model files
|
||||
*.pth
|
||||
171
exp3/modified/dataset.py
Normal file
171
exp3/modified/dataset.py
Normal file
@@ -0,0 +1,171 @@
|
||||
from pathlib import Path
|
||||
import typing
|
||||
import pickle
|
||||
from collections import Counter
|
||||
import numpy
|
||||
|
||||
class Tokenizer:
    """Character-level tokenizer backed by a fixed vocabulary."""

    # Token -> id mapping.
    token_dict: dict[str, int]
    # Id -> token mapping (inverse of ``token_dict``).
    token_dict_rev: dict[int, str]
    # Vocabulary size.
    vocab_size: int

    def __init__(self, token_dict: dict[str, int]):
        self.token_dict = token_dict
        self.token_dict_rev = dict(zip(token_dict.values(), token_dict.keys()))
        self.vocab_size = len(token_dict)

    def id_to_token(self, token_id: int) -> str:
        """
        Return the vocabulary token with the given id.

        :param token_id: id of the token to look up
        :return: the corresponding token
        """
        return self.token_dict_rev[token_id]

    def token_to_id(self, token: str):
        """
        Return the id of ``token`` in the vocabulary.
        Unknown tokens fall back to the id of ``[UNK]``.

        :param token: token to look up
        :return: the token's id
        """
        return self.token_dict.get(token, self.token_dict['[UNK]'])

    def encode(self, tokens: str) -> list[int]:
        """
        Encode a string into ids, wrapped in the ``[CLS]``/``[SEP]`` start and end markers.

        :param tokens: string to encode
        :return: list of token ids
        """
        body = [self.token_to_id(ch) for ch in tokens]
        return [self.token_to_id('[CLS]'), *body, self.token_to_id('[SEP]')]

    def decode(self, token_ids: typing.Iterable[int]) -> str:
        """
        Decode a sequence of ids back into a string, dropping start/end markers.

        :param token_ids: ids to decode
        :return: the decoded string
        """
        # The start/end markers are skipped during decoding.
        specials = {'[CLS]', '[SEP]'}
        pieces = (self.id_to_token(tid) for tid in token_ids)
        return ''.join(piece for piece in pieces if piece not in specials)
|
||||
|
||||
|
||||
class PoetryDataset:
    """Loader for the Tang-poetry dataset: cleans the raw corpus, caches it, and builds a tokenizer."""

    BAD_WORDS: typing.ClassVar[list[str]] = ['(', ')', '(', ')', '__', '《', '》', '【', '】', '[', ']']
    """Forbidden words: poems containing any of these characters are skipped."""
    MAX_SEG_LEN: typing.ClassVar[int] = 64
    """Maximum sentence length (the body must fit within this once the two marker tokens are added)."""
    MIN_WORD_FREQ: typing.ClassVar[int] = 8
    """Minimum character frequency kept in the vocabulary."""

    tokenizer: Tokenizer
    """Tokenizer built from the cleaned corpus."""
    poetry: list[str]
    """The cleaned poetry corpus."""

    def __init__(self, force_reclean: bool = False):
        """Load the poems, then build a tokenizer from their character frequencies.

        :param force_reclean: when True, always re-clean from the raw text file.
        """
        self.poetry = self.load_poetry(force_reclean)
        self.tokenizer = self.build_tokenizer(self.poetry)

    def load_poetry(self, force_reclean: bool = False) -> list[str]:
        """Load the poetry corpus, using the cleaned pickle cache when available."""
        if force_reclean or (not self.get_clean_dataset_path().is_file()):
            return self.load_poetry_from_dirty()
        else:
            return self.load_poetry_from_clean()

    def load_poetry_from_clean(self) -> list[str]:
        """Read the already-cleaned corpus directly from the pickle cache."""
        with open(self.get_clean_dataset_path(), 'rb') as f:
            return pickle.load(f)

    def load_poetry_from_dirty(self) -> list[str]:
        """Clean the raw data, write the pickle cache, and return the cleaned corpus."""
        with open(self.get_dirty_dataset_path(), 'r', encoding='utf-8') as f:
            lines = f.readlines()
        # Normalise both colon variants to the full-width form.
        lines = [line.replace(':', ':') for line in lines]

        # Cleaned corpus.
        poetry: list[str] = []
        # Process the raw lines one by one.
        for line in lines:
            # Drop surrounding whitespace.
            line = line.strip()
            # Exactly one colon must separate the title from the body.
            if line.count(':') != 1: continue
            # Keep only the body (everything after the colon).
            _, last_part = line.split(':')
            # The body must fit once the two marker tokens are added.
            if len(last_part) > PoetryDataset.MAX_SEG_LEN - 2:
                continue
            # The body must not contain any forbidden word.
            for bad_word in PoetryDataset.BAD_WORDS:
                if bad_word in last_part:
                    break
            else:
                # Loop finished normally: no bad words, keep the poem.
                poetry.append(last_part)

        # Cleaning done.
        # Persist the cleaned corpus to the cache.
        with open(self.get_clean_dataset_path(), 'wb') as f:
            pickle.dump(poetry, f)
        # Return the result.
        return poetry

    def get_clean_dataset_path(self) -> Path:
        """Return the path of the cleaned (pickled) dataset cache."""
        return Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.pickle'

    def get_dirty_dataset_path(self) -> Path:
        """Return the path of the raw text dataset."""
        return Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.txt'

    def build_tokenizer(self, poetry: list[str]) -> Tokenizer:
        """Count character frequencies and build the tokenizer."""
        # Count character frequencies.
        counter: Counter[str] = Counter()
        for line in poetry:
            counter.update(line)
        # Drop low-frequency characters.
        tokens = ((token, count) for token, count in counter.items() if count >= PoetryDataset.MIN_WORD_FREQ)
        # Sort by descending frequency.
        tokens = sorted(tokens, key=lambda x: -x[1])
        # Drop the counts, keeping only the token list.
        tokens = list(token for token, _ in tokens)

        # Prepend the special tokens to the corpus tokens.
        tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + tokens
        # Build the token -> id mapping.
        token_id_dict = dict(zip(tokens, range(len(tokens))))
        # Build a tokenizer over the new vocabulary.
        tokenizer = Tokenizer(token_id_dict)
        # No shuffling is needed here; return directly.
        return tokenizer
|
||||
|
||||
12
exp3/modified/settings.py
Normal file
12
exp3/modified/settings.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from pathlib import Path
|
||||
|
||||
BATCH_SIZE: int = 16
"""Training batch size."""


def get_saved_model_path() -> Path:
    """
    Return the path where the trained model's parameters are saved.

    :return: path to the saved model file.
    """
    models_dir = Path(__file__).resolve().parent.parent / 'models'
    return models_dir / 'rnn.pth'
|
||||
0
exp3/modified/utils.py
Normal file
0
exp3/modified/utils.py
Normal file
199
exp3/source/dataset.py
Normal file
199
exp3/source/dataset.py
Normal file
@@ -0,0 +1,199 @@
|
||||
#ANLI College of Artificial Intelligence
|
||||
|
||||
from collections import Counter
|
||||
import math
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import settings
|
||||
|
||||
|
||||
class Tokenizer:
    """
    Tokenizer mapping characters to vocabulary ids and back.
    """

    def __init__(self, token_dict):
        # token -> id mapping
        self.token_dict = token_dict
        # id -> token mapping
        self.token_dict_rev = {v: k for k, v in token_dict.items()}
        # vocabulary size
        self.vocab_size = len(self.token_dict)

    def id_to_token(self, token_id):
        """
        Look up the vocabulary token for a given id.
        :param token_id: id of the token to look up
        :return: the token with that id
        """
        return self.token_dict_rev[token_id]

    def token_to_id(self, token):
        """
        Look up the id of a token in the vocabulary.
        Unknown tokens map to the id of [UNK].
        :param token: token to look up
        :return: the token's id
        """
        unknown_id = self.token_dict['[UNK]']
        return self.token_dict.get(token, unknown_id)

    def encode(self, tokens):
        """
        Encode a string: wrap it in the start/end marker tokens and
        convert every character to its id.
        :param tokens: string to encode
        :return: list of token ids
        """
        ids = [self.token_to_id('[CLS]')]
        ids.extend(self.token_to_id(ch) for ch in tokens)
        ids.append(self.token_to_id('[SEP]'))
        return ids

    def decode(self, token_ids):
        """
        Decode a sequence of ids back into a string, skipping the
        start/end marker tokens.
        :param token_ids: ids to decode
        :return: the decoded string
        """
        skip = {'[CLS]', '[SEP]'}
        chars = []
        for tid in token_ids:
            tok = self.id_to_token(tid)
            if tok not in skip:
                chars.append(tok)
        return ''.join(chars)
|
||||
|
||||
|
||||
# Forbidden words: poems containing any of these are skipped
disallowed_words = settings.DISALLOWED_WORDS
# Maximum sentence length
max_len = settings.MAX_LEN
# Minimum word frequency
min_word_frequency = settings.MIN_WORD_FREQUENCY
# Mini-batch size
batch_size = settings.BATCH_SIZE

# Load the dataset
with open(settings.DATASET_PATH, 'r', encoding='utf-8') as f:
    lines = f.readlines()
# Normalise both colon variants to the full-width form
lines = [line.replace(':', ':') for line in lines]
# Corpus list
poetry = []
# Process the raw lines one by one
for line in lines:
    # Exactly one colon must separate the title from the body
    if line.count(':') != 1:
        continue
    # The body must not contain any forbidden word
    __, last_part = line.split(':')
    ignore_flag = False
    for dis_word in disallowed_words:
        if dis_word in last_part:
            ignore_flag = True
            break
    if ignore_flag:
        continue
    # The body must not exceed the maximum length
    if len(last_part) > max_len - 2:
        continue
    poetry.append(last_part.replace('\n', ''))

# Count character frequencies
counter = Counter()
for line in poetry:
    counter.update(line)
# Drop low-frequency characters
_tokens = [(token, count) for token, count in counter.items() if count >= min_word_frequency]
# Sort by descending frequency
_tokens = sorted(_tokens, key=lambda x: -x[1])
# Drop the counts, keeping only the token list
_tokens = [token for token, count in _tokens]

# Prepend the special tokens to the corpus tokens
_tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + _tokens
# Build the token -> id mapping
token_id_dict = dict(zip(_tokens, range(len(_tokens))))
# Build a tokenizer over the new vocabulary
tokenizer = Tokenizer(token_id_dict)
# Shuffle the corpus
np.random.shuffle(poetry)
|
||||
|
||||
|
||||
class PoetryDataGenerator:
    """
    Batch generator over the poetry corpus.
    """

    def __init__(self, data, random=False):
        # Corpus
        self.data = data
        # Batch size (taken from the module-level setting)
        self.batch_size = batch_size
        # Steps per epoch.
        # NOTE(review): this floors, but __iter__ also yields a final short
        # batch when len(data) is not a multiple of batch_size — confirm the
        # intended steps_per_epoch.
        self.steps = int(math.floor(len(self.data) / self.batch_size))
        # Whether to shuffle at the start of each epoch
        self.random = random

    def sequence_padding(self, data, length=None, padding=None):
        """
        Pad the given sequences to a common length.
        :param data: sequences to pad
        :param length: target length; defaults to the longest sequence in data
        :param padding: padding value; defaults to the id of [PAD]
        :return: the padded data
        """
        # Work out the target length
        if length is None:
            length = max(map(len, data))
        # Work out the padding value
        if padding is None:
            padding = tokenizer.token_to_id('[PAD]')
        # Pad or truncate each sequence
        outputs = []
        for line in data:
            padding_length = length - len(line)
            # Too short: pad at the end
            if padding_length > 0:
                outputs.append(np.concatenate([line, [padding] * padding_length]))
            # Too long: truncate
            else:
                outputs.append(line[:length])
        return np.array(outputs)

    def __len__(self):
        return self.steps

    def __iter__(self):
        total = len(self.data)
        # Optionally shuffle (in place) at the start of the epoch
        if self.random:
            np.random.shuffle(self.data)
        # Iterate one epoch, yielding one batch at a time
        for start in range(0, total, self.batch_size):
            end = min(start + self.batch_size, total)
            batch_data = []
            # Encode each poem in the batch
            for single_data in self.data[start:end]:
                batch_data.append(tokenizer.encode(single_data))
            # Pad to a common length
            batch_data = self.sequence_padding(batch_data)
            # yield x (all tokens but the last) and y (one-hot of all tokens but the first)
            yield batch_data[:, :-1], tf.one_hot(batch_data[:, 1:], tokenizer.vocab_size)
            del batch_data

    def for_fit(self):
        """
        Create an endless generator for model fitting.
        """
        # Loop forever: once an epoch of data is exhausted, start over
        while True:
            # Delegate to the epoch iterator
            yield from self.__iter__()
|
||||
16
exp3/source/eval.py
Normal file
16
exp3/source/eval.py
Normal file
@@ -0,0 +1,16 @@
|
||||
#ANLI College of Artificial Intelligence
# Evaluation script: loads the trained model and prints sample generations.


import tensorflow as tf
from dataset import tokenizer
import settings
import utils

# Load the trained model
model = tf.keras.models.load_model(settings.BEST_MODEL_PATH)
# Generate one random poem
print(utils.generate_random_poetry(tokenizer, model))
# Given a partial first line, randomly generate the remainder
print(utils.generate_random_poetry(tokenizer, model, s='床前明月光,'))
# Generate an acrostic poem
print(utils.generate_acrostic(tokenizer, model, head='好好学习天天向上'))
|
||||
19
exp3/source/settings.py
Normal file
19
exp3/source/settings.py
Normal file
@@ -0,0 +1,19 @@
|
||||
#ANLI College of Artificial Intelligence
# Project-wide configuration constants.


# Forbidden words: poems containing any of these characters are ignored
DISALLOWED_WORDS = ['(', ')', '(', ')', '__', '《', '》', '【', '】', '[', ']']
# Maximum sentence length
MAX_LEN = 64
# Minimum word frequency
MIN_WORD_FREQUENCY = 8
# Training batch size
BATCH_SIZE = 16
# Dataset path
DATASET_PATH = './poetry.txt'
# After each epoch, generate SHOW_NUM random poems for display
SHOW_NUM = 5
# Total number of training epochs
TRAIN_EPOCHS = 10
# Path where the best weights are saved
BEST_MODEL_PATH = './best_model.h5'
|
||||
35
exp3/source/train.py
Normal file
35
exp3/source/train.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import tensorflow as tf
from dataset import PoetryDataGenerator, tokenizer, poetry
import settings
import utils

# LSTM language model over the poetry vocabulary:
# embedding -> two stacked LSTMs -> per-timestep softmax over the vocabulary.
model = tf.keras.Sequential([
    tf.keras.layers.Input((None,)),
    tf.keras.layers.Embedding(input_dim=tokenizer.vocab_size, output_dim=128),
    tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
    tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
    tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(tokenizer.vocab_size, activation='softmax')),
])
model.summary()
model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.categorical_crossentropy)


class Evaluate(tf.keras.callbacks.Callback):
    """Callback that saves the best (lowest-loss) model and prints sample poems each epoch."""

    def __init__(self):
        super().__init__()
        # Lowest training loss seen so far; starts effectively at +infinity.
        self.lowest = 1e10

    def on_epoch_end(self, epoch, logs=None):
        # Save the model whenever the training loss improves.
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            model.save(settings.BEST_MODEL_PATH)
        # Show a few randomly generated poems after every epoch.
        print()
        for i in range(settings.SHOW_NUM):
            print(utils.generate_random_poetry(tokenizer, model))


data_generator = PoetryDataGenerator(poetry, random=False)
# Model.fit accepts Python generators directly; Model.fit_generator was
# deprecated in TF 2.1 and later removed from tf.keras.
model.fit(data_generator.for_fit(),
          steps_per_epoch=data_generator.steps,
          epochs=settings.TRAIN_EPOCHS,
          callbacks=[Evaluate()])
|
||||
86
exp3/source/utils.py
Normal file
86
exp3/source/utils.py
Normal file
@@ -0,0 +1,86 @@
|
||||
|
||||
import numpy as np
|
||||
import settings
|
||||
|
||||
|
||||
def generate_random_poetry(tokenizer, model, s=''):
    """
    Randomly generate one poem.
    :param tokenizer: the tokenizer
    :param model: the model used to generate the poem
    :param s: seed string the poem starts from; defaults to the empty string
    :return: a string holding one poem
    """
    # Encode the seed and drop the trailing [SEP] end-marker.
    token_ids = tokenizer.encode(s)[:-1]
    while len(token_ids) < settings.MAX_LEN:
        # Predict; keep only the last token's distribution of the single
        # sample, skipping the [PAD]/[UNK]/[CLS] entries (first 3 ids).
        prediction = model(np.array([token_ids, ], dtype=np.int32))
        probas = prediction.numpy()[0, -1, 3:]
        del prediction
        # Top-100 token indices, most probable first.
        ranked = probas.argsort()[::-1][:100]
        top_p = probas[ranked]
        # Renormalize, then sample the next token from that distribution.
        top_p = top_p / sum(top_p)
        picked = np.random.choice(len(top_p), p=top_p)
        next_id = ranked[picked] + 3  # undo the 3-token offset
        token_ids.append(next_id)
        # Id 3 is the end marker — presumably [SEP]; stop there.
        if next_id == 3:
            break
    return tokenizer.decode(token_ids)
|
||||
|
||||
|
||||
def generate_acrostic(tokenizer, model, head):
    """
    Randomly generate an acrostic poem.
    :param tokenizer: the tokenizer
    :param model: the model used to generate the poem
    :param head: the characters the acrostic lines start with
    :return: a string holding one poem
    """
    # Start from the empty string so token_ids begins with [CLS];
    # drop the trailing [SEP].
    token_ids = tokenizer.encode('')[:-1]
    # Only the comma and the full stop count as punctuation here.
    punctuations = [',', '。']
    punctuation_ids = {tokenizer.token_to_id(token) for token in punctuations}
    # Accumulates the generated poem, character by character.
    poetry = []
    # For each head character, generate one short clause.
    for ch in head:
        # Record the head character itself, then feed its token id in.
        poetry.append(ch)
        token_ids.append(tokenizer.token_to_id(ch))
        # Generate until the model emits a punctuation token.
        while True:
            # Predict; keep only the last token's distribution of the single
            # sample, skipping the [PAD]/[UNK]/[CLS] entries (first 3 ids).
            prediction = model(np.array([token_ids, ], dtype=np.int32))
            probas = prediction.numpy()[0, -1, 3:]
            del prediction
            # Top-100 token indices, most probable first.
            ranked = probas.argsort()[::-1][:100]
            top_p = probas[ranked]
            # Renormalize, then sample the next token from that distribution.
            top_p = top_p / sum(top_p)
            picked = np.random.choice(len(top_p), p=top_p)
            next_id = ranked[picked] + 3  # undo the 3-token offset
            token_ids.append(next_id)
            # Only non-special tokens (> 3) are kept in the poem text.
            if next_id > 3:
                poetry.append(tokenizer.id_to_token(next_id))
            if next_id in punctuation_ids:
                break
    return ''.join(poetry)
|
||||
@@ -8,8 +8,11 @@ dependencies = [
|
||||
"datasets>=4.3.0",
|
||||
"matplotlib>=3.10.7",
|
||||
"numpy>=2.3.4",
|
||||
"torch>=2.9.0",
|
||||
"torchvision>=0.24.0",
|
||||
"pillow>=12.0.0",
|
||||
"pytorch-ignite>=0.5.3",
|
||||
"torch>=2.9.0",
|
||||
"torchinfo>=1.8.0",
|
||||
"torchvision>=0.24.0",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
|
||||
29
uv.lock
generated
29
uv.lock
generated
@@ -393,8 +393,11 @@ dependencies = [
|
||||
{ name = "datasets" },
|
||||
{ name = "matplotlib" },
|
||||
{ name = "numpy" },
|
||||
{ name = "pillow" },
|
||||
{ name = "pytorch-ignite" },
|
||||
{ name = "torch", version = "2.9.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
|
||||
{ name = "torch", version = "2.9.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "torchinfo" },
|
||||
{ name = "torchvision", version = "0.24.0", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" },
|
||||
{ name = "torchvision", version = "0.24.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
|
||||
{ name = "torchvision", version = "0.24.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or sys_platform == 'win32'" },
|
||||
@@ -405,8 +408,11 @@ requires-dist = [
|
||||
{ name = "datasets", specifier = ">=4.3.0" },
|
||||
{ name = "matplotlib", specifier = ">=3.10.7" },
|
||||
{ name = "numpy", specifier = ">=2.3.4" },
|
||||
{ name = "pillow", specifier = ">=12.0.0" },
|
||||
{ name = "pytorch-ignite", specifier = ">=0.5.3" },
|
||||
{ name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'win32'", specifier = ">=2.9.0" },
|
||||
{ name = "torch", marker = "sys_platform == 'linux' or sys_platform == 'win32'", specifier = ">=2.9.0", index = "https://download.pytorch.org/whl/cu126" },
|
||||
{ name = "torchinfo", specifier = ">=1.8.0" },
|
||||
{ name = "torchvision", marker = "sys_platform != 'linux' and sys_platform != 'win32'", specifier = ">=0.24.0" },
|
||||
{ name = "torchvision", marker = "sys_platform == 'linux' or sys_platform == 'win32'", specifier = ">=0.24.0", index = "https://download.pytorch.org/whl/cu126" },
|
||||
]
|
||||
@@ -1639,6 +1645,20 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytorch-ignite"
|
||||
version = "0.5.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "packaging" },
|
||||
{ name = "torch", version = "2.9.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
|
||||
{ name = "torch", version = "2.9.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/5a/e5/7fe880b24de30b4eadc8d997ea8d3c4a8f507b1a34dcdced08d88f665ee3/pytorch_ignite-0.5.3.tar.gz", hash = "sha256:75c645f02fea66cc80c1998ade3f8402e0e6b6d73f3f4ad727c171f6e93874f4", size = 7506607, upload-time = "2025-10-16T00:42:05.142Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d2/ea/f6d5ee7433a5a1c1e4746e2a4e9a222eab545fdbe04b66754ffdab479ee8/pytorch_ignite-0.5.3-py3-none-any.whl", hash = "sha256:4ced7539c690a3b6f3116da7878389954dff787c33669f83b38221f3746bc63e", size = 343802, upload-time = "2025-10-16T00:41:55.738Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytz"
|
||||
version = "2025.2"
|
||||
@@ -1848,6 +1868,15 @@ wheels = [
|
||||
{ url = "https://download.pytorch.org/whl/cu126/torch-2.9.0%2Bcu126-cp314-cp314t-win_amd64.whl", hash = "sha256:d8fdfc45ba30cb5c23971b35ae72c6fe246596022b574bd37dd0c775958f70b1" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "torchinfo"
|
||||
version = "1.8.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/53/d9/2b811d1c0812e9ef23e6cf2dbe022becbe6c5ab065e33fd80ee05c0cd996/torchinfo-1.8.0.tar.gz", hash = "sha256:72e94b0e9a3e64dc583a8e5b7940b8938a1ac0f033f795457f27e6f4e7afa2e9", size = 25880, upload-time = "2023-05-14T19:23:26.377Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/72/25/973bd6128381951b23cdcd8a9870c6dcfc5606cb864df8eabd82e529f9c1/torchinfo-1.8.0-py3-none-any.whl", hash = "sha256:2e911c2918603f945c26ff21a3a838d12709223dc4ccf243407bce8b6e897b46", size = 23377, upload-time = "2023-05-14T19:23:24.141Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "torchvision"
|
||||
version = "0.24.0"
|
||||
|
||||
Reference in New Issue
Block a user