programing

Model.train()은 PyTorch에서 무엇을 합니까?

css3 2023. 6. 19. 21:54

Model.train()은 PyTorch에서 무엇을 합니까?

전화가 왔습니까?forward()nn.Module우리가 모델을 부를 때,forward메서드가 사용되고 있습니다.train()을 지정해야 하는 이유는 무엇입니까?

model.train()에서는 모델을 교육하고 있음을 알립니다.이는 교육 및 평가 중 다르게 동작하도록 설계된 Dropout 및 BatchNorm과 같은 계층에 정보를 제공하는 데 도움이 됩니다.예를 들어 교육 모드에서는 BatchNorm이 새 배치마다 이동 평균을 업데이트하는 반면 평가 모드에서는 이러한 업데이트가 동결됩니다.

자세한 정보:model.train()모드를 훈련으로 설정합니다(소스 코드 참조).다음 중 하나로 전화할 수 있습니다.model.eval()또는model.train(mode=False)당신이 테스트하고 있다고 말하는 것.예상하는 것은 다소 직관적입니다.train모델을 훈련시키는 기능을 하지만 그렇게 하지는 않습니다.그냥 모드를 설정합니다.

다음의 코드는 다음과 같습니다.

def train(self, mode=True):
        r"""Sets the module in training mode."""      
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

다음의 코드는 다음과 같습니다.

def eval(self):
        r"""Sets the module in evaluation mode."""
        return self.train(False)

기본적으로,self.training플래그가 다음으로 설정됨True즉, 모듈은 기본적으로 열차 모드에 있습니다.언제self.training이라False모듈이 반대 상태인 평가 모드에 있습니다.

가장 일반적으로 사용되는 레이어 중에서 해당 플래그에만 관심이 있습니다.

model.train() model.eval()
모델을 교육 모드로 설정합니다.

BatchNorm도면층: 도면층별 통계량 사용
Dropout활성화된 레이어 등
모델을 평가(추측) 모드로 설정합니다.

BatchNorm도면층:실행 중인 통계 사용
Dropout비활성화된 레이어 등
와 동등한model.train(False).

참고: 이러한 함수 호출은 전진/후진 패스를 실행하지 않습니다.그들은 모델에게 실행할 때 행동하는 방법을 알려줍니다.

이는 일부 모듈(레이어)(예:Dropout,BatchNorm)는 훈련 중과 추론 중에 다르게 동작하도록 설계되었으며, 따라서 모델이 잘못된 모드에서 실행되면 예상치 못한 결과를 생성합니다.

모형에 사용자의 의도를 알리는 두 가지 방법이 있습니다. 즉, 모형을 훈련시킬 것인가 아니면 모형을 사용하여 평가할 것인가입니다.의 경우model.train()모델은 레이어와 사용 시기를 학습해야 한다는 것을 알고 있습니다.model.eval()이는 모델이 새로운 것을 학습할 수 없음을 나타내며 모델은 테스트에 사용됩니다. model.eval()만약 우리가 배치 표준을 사용하고 있다면, 그리고 우리가 단지 하나의 이미지를 통과하기를 원한다면, pytorch는 또한 필요하기 때문에, pytorch는 다음과 같은 오류를 던집니다.model.eval()지정되지 않았습니다.

다음 모델을 고려합니다.

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GraphNet(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GraphNet, self).__init__()
        self.conv1 = GCNConv(num_node_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.dropout(x, training=self.training) #Look here
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

여기서, 의 기능은dropout작동 모드에 따라 다릅니다.보다시피, 그것은 오직 다음과 같은 경우에만 작동합니다.self.training==True그서래를 하면 할때력입.model.train()을 수행하지 예: " 모의전함드수수롭을웃행아다니합는진델수▁when다행▁the▁will니합을say웃▁otherwise▁drop모'out드아롭▁perform"). 그렇지 않으면 수행되지 않습니다(예: 언제).model.eval()또는model.train(mode=False)).

현재 공식 문서에는 다음이 명시되어 있습니다.

이는 특정 모듈에만 [sic] 영향을 미칩니다.교육/평가 모드에서 해당 모듈의 동작에 대한 자세한 내용은 특정 모듈의 문서를 참조하십시오(예: 탈락, 배치 기준 등).

언급URL : https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch