18 lines
420 B
Python
18 lines
420 B
Python
|
|
import torch
|
|||
|
|
|
|||
|
|
|
|||
|
|
def print_gpu_availability() -> None:
    """Print whether PyTorch can see a CUDA GPU.

    When CUDA is available, the message includes the name that
    ``torch.cuda.get_device_name(0)`` returns for device index 0;
    otherwise a fixed "unavailable" message is printed.
    """
    # Guard clause: handle the no-GPU case first so the happy path is flat.
    if not torch.cuda.is_available():
        print("GPU不可用")
        return

    device_name = torch.cuda.get_device_name(0)
    print(f"GPU可用:{device_name}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_gpu_device() -> torch.device:
    """Return the CUDA device for PyTorch.

    Returns:
        ``torch.device("cuda")`` when ``torch.cuda.is_available()`` is true.

    Raises:
        RuntimeError: If CUDA is not available. ``RuntimeError`` is a
            subclass of ``Exception``, so callers that previously caught
            ``Exception`` continue to work.
    """
    if not torch.cuda.is_available():
        # Raise a specific exception type rather than bare ``Exception`` so
        # callers can handle this failure without a catch-all handler.
        raise RuntimeError("找不到CUDA!")
    return torch.device("cuda")
|