Big Ben
Big Ben's Log
Big Ben
전체 방문자
오늘
어제
  • 전체 글 (80)
    • 파이썬 (23)
      • 파이썬 기초 (5)
      • 클래스 (6)
      • 자료구조 (4)
      • Tensorflow (3)
      • PyTorch (2)
      • konlpy (1)
      • anaconda (1)
    • 머신러닝 (3)
      • 선형회귀 (1)
      • Tree 기반 (1)
    • 딥러닝 (6)
      • NLP (2)
      • VISION (2)
      • TABULAR (0)
      • 딥러닝 서버 구축 (2)
    • 그래프 이론 (1)
      • 그래프마이닝 (1)
      • GNN (0)
    • 강화학습 (3)
      • 강화학습 기본 (3)
    • 인공지능 (5)
    • 추천시스템 (2)
      • 추천시스템 기초 (2)
    • Competitions (1)
    • 빅데이터 (8)
      • 하둡 (3)
      • 스파크 (4)
      • 클라우드 (1)
    • SQL (7)
      • MariaDB (2)
    • 논문 리뷰 (2)
    • 대학원 (0)
      • 데이터 사이언스 (0)
      • 경제학 (0)
    • 선형대수학 (7)
      • 선형대수 ICE BREAKING (1)
      • 벡터 (5)
      • 고윳값 (1)
    • 개인프로젝트 (0)
      • 포트폴리오 대시보드 + AI기반 주식 자동매매 (0)
    • 재테크 (1)
    • 자동차 (0)
    • 알고리즘 (11)

블로그 메뉴

  • 홈
  • 태그
  • 미디어로그
  • 위치로그
  • 방명록

공지사항

인기 글

태그

  • AI
  • 데이터사이언스
  • 백준
  • pytorch
  • 프로그래밍
  • 딥러닝
  • 객체
  • 데이터
  • 객체지향
  • 하둡
  • TensorFlow
  • sql
  • 데이터베이스
  • 머신러닝
  • Baekjoon
  • 파이썬
  • 파이썬기초
  • 인공지능
  • 빅데이터
  • 선형대수학
  • PYTHON
  • 코테
  • class
  • mysql
  • 선형대수
  • 프로그래머스
  • 코딩테스트
  • MariaDB
  • 자료구조
  • 알고리즘

최근 댓글

최근 글

티스토리

hELLO · Designed By 정상우.
Big Ben

Big Ben's Log

[pytorch]  torch.tensor.detach() 의 기능
파이썬/PyTorch

[pytorch] torch.tensor.detach() 의 기능

2023. 4. 6. 23:05
반응형

파이토치에서 Tensor 객체의 detach() 메소드는 현재 Tensor 객체와 동일한 데이터를 가지지만 연산 그래프(Computational Graph)에서 분리된 새로운 Tensor 객체를 생성합니다.

 

이 메소드는 일반적으로 Tensor 객체를 다른 Tensor 객체로 변환하고자 할 때 사용됩니다.

 

예를 들어, 주어진 Tensor 객체에 대한 연산의 결과로 생성된 새로운 Tensor 객체가 있을 때, 이 새로운 Tensor 객체를 사용하여 추가적인 계산을 수행하고자 할 때, 기존 Tensor 객체의 연산 그래프와의 의존성을 제거하여 메모리 사용량을 줄이고 계산 속도를 향상시키는 데 유용합니다.

 

detach() 메소드는 requires_grad 속성을 False로 설정하여 기존 Tensor 객체와 다르게 자동 미분 기능에서 제외됩니다.

 

따라서 detach() 메소드를 사용하여 생성된 Tensor 객체는 그라디언트(gradient) 계산에 사용되지 않으며, 그라디언트를 계산하지 않는 모델에서 중간 출력 값을 얻을 때 특히 유용합니다.

 

예를 들어, 다음과 같이 Tensor 객체 a와 b가 있을 때, detach() 메소드를 사용하여 Tensor 객체 b를 새로운 Tensor 객체 c로 분리하고자 할 때 다음과 같이 작성할 수 있습니다.

 

import torch 

a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) 
b = 2 * a 
c = b.detach()

위 코드에서, Tensor 객체 b는 Tensor 객체 a를 사용하여 생성되었으며, requires_grad 속성이 True로 설정되어 있습니다.

 

그러나 detach() 메소드를 사용하여 생성된 Tensor 객체 c는 requires_grad 속성이 False로 설정되어 있습니다.

따라서 c에 대한 연산은 a와의 의존성을 가지지 않으며, c로부터 추가적인 Tensor 객체를 생성할 때 메모리 사용량이 줄어들게 됩니다.

 

좀더 직관적인 예제를 한번 살펴보겠습니다.

 

위 그림은 아래의 코드를 통해 이해하면 쉽습니다.

import torch
import torch.nn as nn

class Test(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 10)
        self.layer2 = nn.Linear(10, 10)
        
    def forward(self, x):
        out1 = self.layer1(x)
        out2 = self.layer2(out1.detach())   # detach 사용
        return out2
        
model = Test()

 

forward 함수를 살펴보면, layer1에서 나온 output이 detach되는 것을 볼 수 있습니다. 

이 경우 역전파 때 gradient가 이전 layer인 layer1으로 흘러가지 않습니다.

반응형
저작자표시

'파이썬 > PyTorch' 카테고리의 다른 글

[pytorch] 파이토치 opencv, mxnet, torchmetrics 설치 시 gpu 인식 불가 이슈 해결  (0) 2023.04.06
    '파이썬/PyTorch' 카테고리의 다른 글
    • [pytorch] 파이토치 opencv, mxnet, torchmetrics 설치 시 gpu 인식 불가 이슈 해결
    Big Ben
    Big Ben
    Data Scientist

    티스토리툴바