first commit
This commit is contained in:
17
gpu_utils.py
Normal file
17
gpu_utils.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import torch
|
||||
|
||||
|
||||
def print_gpu_availability():
|
||||
"""打印PyTorch的GPU可用性"""
|
||||
if torch.cuda.is_available():
|
||||
print(f"GPU可用:{torch.cuda.get_device_name(0)}")
|
||||
else:
|
||||
print("GPU不可用")
|
||||
|
||||
|
||||
def get_gpu_device() -> torch.device:
|
||||
"""获取PyTorch的GPU设备"""
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
else:
|
||||
raise Exception("找不到CUDA!")
|
||||
Reference in New Issue
Block a user