fix exp2 pytorch rewrite fatal train issue
This commit is contained in:
@@ -6,7 +6,7 @@ import matplotlib.pyplot as plt
|
||||
import torch.nn.functional as F
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
import gpu_utils
|
||||
import pytorch_gpu_utils
|
||||
|
||||
|
||||
class CurveKind(IntEnum):
|
||||
@@ -56,7 +56,7 @@ class Net(torch.nn.Module):
|
||||
|
||||
|
||||
def main():
|
||||
device = gpu_utils.get_gpu_device()
|
||||
device = pytorch_gpu_utils.get_gpu_device()
|
||||
test_data = DataSource(device, CurveKind.Polynomials)
|
||||
net = Net(n_feature=1, n_hidden=20, n_output=1).to(device)
|
||||
|
||||
@@ -86,5 +86,5 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
gpu_utils.print_gpu_availability()
|
||||
pytorch_gpu_utils.print_gpu_availability()
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user