1
0

use ignite for exp2

This commit is contained in:
2025-12-02 23:07:27 +08:00
parent 43b807679f
commit 65c56e938c
15 changed files with 246 additions and 794 deletions

View File

@@ -6,7 +6,7 @@ import matplotlib.pyplot as plt
import torch.nn.functional as F
sys.path.append(str(Path(__file__).resolve().parent.parent))
import pytorch_gpu_utils
import gpu_utils
class CurveKind(IntEnum):
@@ -56,7 +56,7 @@ class Net(torch.nn.Module):
def main():
device = pytorch_gpu_utils.get_gpu_device()
device = gpu_utils.get_gpu_device()
test_data = DataSource(device, CurveKind.Polynomials)
net = Net(n_feature=1, n_hidden=20, n_output=1).to(device)
@@ -86,5 +86,5 @@ def main():
if __name__ == "__main__":
pytorch_gpu_utils.print_gpu_availability()
gpu_utils.print_gpu_availability()
main()