yeon's

[PyTorch] Class nn.Module 본문

Life hack

[PyTorch] Class nn.Module

yeonjins 2023. 1. 18. 19:07

nn.Module의 subclass를 신경망 모델로 활용하기 위한 메소드

1. __init__(self) : 신경망 모델에 활용할 모듈, 활성화 함수 등을 정의하고 초기화하는 메소드

torch.nn 모듈을 사용하려면 super.__init__()을 꼭 사용해야한다.

그렇지 않으면 AttributeError: cannot assign module before Module.__init__() call 이런 에러가 남

nn.Module=super

이 코드에서 torch.nn 모듈을 활용하지 않으면 에러가 나지 않는다.

따라서 super.__init__() 변수들을 상속받아 사용할 수 있도록 해준다.

super()는 부모클래스, __init__()은 부모클래스의 생성자를 부른다는 의미이다.

super()안에 파생클래스를 적어주는건 기능의 차이는 없고, 그냥 명확하게 적어두는 용이다.

 

 

2. forward(self, x) : init에서 정의한 것들을 연결하는 메소드, 모델에서 실행되어야 하는 계산

forward를 정의하고, backward()는 이용하면 알아서 backward 계산을 해주기 때문에 forward만 정의해두면 된다.

input을 넣으면 어떤 과정을 거쳐 output을 나오게 할지를 정의해주는 느낌이다.

 

 

model = Test() 로 인스턴스화 한 후 input을 넣어주면 모델이 동작한다.

 

보통 클래스에서 함수를 사용하는 방법은 model.forward() 처럼 호출해주는 것인데,

forward는 그냥 model 객체를 데이터와 함께 호출하면 자동으로 실행된다. (이유가 있음)

Comments