18 lines
420 B
Python
18 lines
420 B
Python
import torch
|
||
|
||
|
||
def print_gpu_availability() -> None:
    """Print whether PyTorch can use a CUDA GPU.

    When ``torch.cuda.is_available()`` is true, the name reported by
    ``torch.cuda.get_device_name(0)`` is printed; otherwise a message saying
    the GPU is unavailable is printed. The printed messages are in Chinese.
    """
    if torch.cuda.is_available():
        # Index 0 is hard-coded: only the first CUDA device's name is shown.
        print(f"GPU可用:{torch.cuda.get_device_name(0)}")
    else:
        print("GPU不可用")
|
||
|
||
|
||
def get_gpu_device() -> torch.device:
    """Return the CUDA device for PyTorch.

    Returns:
        ``torch.device("cuda")`` when ``torch.cuda.is_available()`` is true.

    Raises:
        RuntimeError: If CUDA is not available. ``RuntimeError`` is a
            subclass of ``Exception``, so callers that catch ``Exception``
            continue to work.
    """
    # Guard clause: fail fast when no CUDA runtime/device is usable.
    if not torch.cuda.is_available():
        raise RuntimeError("找不到CUDA!")
    return torch.device("cuda")
|