1. PyTorch란?
PyTorch는 Facebook AI Research(FAIR)에서 개발한 Python 기반 오픈소스 딥러닝 프레임워크입니다.
- NumPy와 유사한 텐서 연산 지원
- GPU 가속(CUDA)로 빠른 연산 가능
- 동적 연산 그래프(Dynamic Computation Graph) 사용
- 사용이 직관적이고 디버깅이 쉬움
- 딥러닝, 자연어 처리(NLP), 컴퓨터 비전(CV) 등에 폭넓게 활용됨
2. PyTorch 주요 개념
(1) 텐서(Tensor)
PyTorch에서 가장 기본적인 자료형은 Tensor입니다.
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([[1, 2], [3, 4]])
c = torch.randn(2, 3, 4)
print(c.shape)
(2) NumPy 변환
import numpy as np
t = torch.tensor([1, 2, 3])
n = t.numpy()
print(n, type(n))
n = np.array([4, 5, 6])
t = torch.from_numpy(n)
print(t, type(t))
(3) GPU 연산
device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(3, 3).to(device)
print(x.device)
(4) 텐서 연산
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
print(x + y)
3. PyTorch 모델 만들기
(1) torch.nn을 이용한 모델 정의
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
model = MyModel()
print(model)
(2) 손실 함수 및 옵티마이저
import torch.optim as optim
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
(3) 학습 과정
for epoch in range(10):
optimizer.zero_grad()
x = torch.randn(5, 10)
y = torch.randn(5, 1)
pred = model(x)
loss = loss_fn(pred, y)
loss.backward()
optimizer.step()
print(f"Epoch {epoch+1}, Loss: {loss.item()}")
4. 데이터 로딩 (torch.utils.data)
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
def __init__(self):
self.data = torch.randn(100, 10)
self.labels = torch.randint(0, 2, (100,))
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
dataset = MyDataset()
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
5. PyTorch 주요 기능 요약
| 기능 | 설명 |
|---|---|
| 텐서 연산 | torch.tensor(), +, -, * 등 다양한 연산 |
| GPU 연산 | .to('cuda')를 이용해 GPU에서 연산 가능 |
| 신경망 정의 | nn.Module을 상속받아 모델 생성 |
| 손실 함수 | nn.MSELoss(), nn.CrossEntropyLoss() 등 다양한 손실 함수 |
| 옵티마이저 | optim.SGD(), optim.Adam() 등 최적화 알고리즘 |
| 역전파 학습 | zero_grad(), loss.backward(), step() 사용 |
| 데이터 로딩 | Dataset, DataLoader를 활용하여 데이터 배치 로딩 |
'[Python] Code' 카테고리의 다른 글
| [Python] 자료구조 (0) | 2025.02.10 |
|---|---|
| [Python] Numpy 소개 (0) | 2025.02.09 |
| [Python] Einops (0) | 2025.02.09 |
| [Python] Einsum (0) | 2025.02.09 |