1
0

fix exp2 pytorch rewrite fatal train issue

This commit is contained in:
2025-11-30 22:01:56 +08:00
parent 48fcdfcc80
commit 43b807679f
13 changed files with 738 additions and 112 deletions

View File

@@ -6,7 +6,7 @@ import matplotlib.pyplot as plt
import torch.nn.functional as F
sys.path.append(str(Path(__file__).resolve().parent.parent))
import gpu_utils
import pytorch_gpu_utils
class CurveKind(IntEnum):
@@ -56,7 +56,7 @@ class Net(torch.nn.Module):
def main():
device = gpu_utils.get_gpu_device()
device = pytorch_gpu_utils.get_gpu_device()
test_data = DataSource(device, CurveKind.Polynomials)
net = Net(n_feature=1, n_hidden=20, n_output=1).to(device)
@@ -86,5 +86,5 @@ def main():
if __name__ == "__main__":
gpu_utils.print_gpu_availability()
pytorch_gpu_utils.print_gpu_availability()
main()