Pipeline에서 캐싱을 사용하기

며칠전 ‘파이썬 라이브러리를 활용한 머신러닝‘의 원저자 안드리아스 뮐러가 한 유투브 채널에 나와 Scikit-Learn의 0.19버전에서 추가된 기능과 0.20에서 추가될 내용을 소개했습니다.

0.19에 새롭게 추가된 기능으로 대표적으로 언급한 것이 Pipeline 캐싱, SAGA solver, RepeatedKFold, QuantileTransformer, ClassifierChain, 다중 scoring 설정입니다. 또 0.20에 추가될 기능으로는 GradientBoosting의 early stopping, CategoricalEncoder, PowerTransfomer, ColumnTransformer, OpenML 데이터셋 로더, matplotlib 기반의 트리 그래프 등입니다. 시간날 때마다 차례대로 살펴 보도록 하겠습니다. 먼저 오늘은 파이프라인 캐싱입니다!

파이프라인 캐싱은 파이프라인의 단계마다 transformer의 fit 결과를 저장하여 다시 사용한다는 것입니다. 전처리 단계가 복잡한 그리드서치를 사용할 때 특히 유용할 수 있습니다. 어 그럼 지금까지는 캐싱없이 무식하게 반복적으로 전처리를 매번 다시 했다는 건가요? 네 맞습니다. ㅠ.ㅠ

캐싱을 사용하려면 Pipeline 클래스의 memory 매개변수에 캐싱에 사용할 디렉토리를 지정하기만 하면 됩니다. 간단한 예를 살펴 보겠습니다. 책의 예제와 비슷하게 보스턴 주택가격 데이터셋을 로드하여 훈련 데이터를 준비하고 매개변수 탐색을 위한 그리드를 정의합니다.

boston = load_boston()
X_train, X_test, y_train, y_test = train_test_split(boston.data, boston.target,
                                                    random_state=0)
param_grid = {'polynomialfeatures__degree': [1, 2, 3, 4, 5],
              'ridge__alpha': [0.001, 0.01, 0.1, 1, 10, 100]}

그 다음은 파이프라인을 만들고 GridSearchCV로 매개변수 탐색을 하면 됩니다. 파이프라인 클래스를 사용해서 Pipeline(memory=’…’)와 같이 사용해도 되고 파이프라인을 간단하게 만들어 주는 make_pipeline(memory=’…’) 함수를 사용할 수도 있습니다. 캐싱에 사용할 임시 디렉토리를 지정해 주기위해 mkdtemp와 rmtree 함수를 임포트하고 임시 디렉토리를 만듭니다.

from tempfile import mkdtemp
from shutil import rmtree
cache_dir = mkdtemp()

다음 make_pipeline으로 파이프라인 단계를 만들고 GridSearchCV로 매개변수 탐색을 합니다.

pipe = make_pipeline(StandardScaler(), PolynomialFeatures(), Ridge(),
                     memory=cache_dir)
grid = GridSearchCV(pipe, param_grid=param_grid, cv=5, n_jobs=-1)
grid.fit(X_train, y_train)

간단하죠? 작업이 끝나고 난 뒤에는 임시 디렉토리를 지워 줍니다.

rmtree(cache_dir)

전처리 단계가 복잡하고 많을 수록 캐싱의 효과는 큽니다. 이 샘플 코드를 캐싱을 사용하지 않고 실행했을 때와 비교하면 10.7초에서 6.57초로 30%이상 속도가 빨라 졌습니다!

이 포스트에 사용한 전체 코드는 깃허브(https://github.com/rickiepark/introduction_to_ml_with_python/blob/master/Pipeline-cache.ipynb)에서 확인할 수 있습니다.

답글 남기기

아래 항목을 채우거나 오른쪽 아이콘 중 하나를 클릭하여 로그 인 하세요:

WordPress.com 로고

WordPress.com의 계정을 사용하여 댓글을 남깁니다. 로그아웃 / 변경 )

Twitter 사진

Twitter의 계정을 사용하여 댓글을 남깁니다. 로그아웃 / 변경 )

Facebook 사진

Facebook의 계정을 사용하여 댓글을 남깁니다. 로그아웃 / 변경 )

Google+ photo

Google+의 계정을 사용하여 댓글을 남깁니다. 로그아웃 / 변경 )

%s에 연결하는 중