Urban Sound Classification

뉴욕 대학교 MARL(Music and Audio Research Lab)에서 2014 년에 공개한 UrbanSound8K 데이터를 이용하여 텐서플로우를 사용해 사운드를 분류 모델을 만들어 보았습니다. 텐서플로우 코드는 아퀴브 사이드(Aaqib Saeed)의 블로그를 참고하였습니다.

UrbanSound8K 데이터는 모두 10 가지 종류의 소리를 4 초 가량 녹음한 것으로 8 천개가 넘는 wav 파일들 입니다. 소리의 종류는 ‘air_conditioner’, ‘car_horn’, ‘children_playing’, ‘dog_bark’, ‘drilling’, ‘engine_idling’, ‘gun_shot’, ‘jackhammer’, ‘siren’, ‘street_music’ 입니다. 압축을 풀기 전의 데이터 사이즈가 5 GB 가 넘습니다. 이 데이터를 받기 위해서는 깃허브에서 다운로드 주소를 이용해 다운받을 수 있지만 사용 목적에 대해 알려달라고 합니다.

먼저 wav 파일을 이용해서 피처를 뽑아 내야합니다. 아퀴브 사이드가 사용한 방법은 사운드 분석 파이썬 라이브러리인 librosa 를 이용해서 특성을 추출하였습니다. 사용된 특성은 mfcc(Mel-frequency cepstral coefficients), chroma_stft(chromagram from a waveform or power spectrogram), melspectrogram(Mel-scaled power spectrogram), spectral_contrast(spectral contrast), tonnetz(tonal centroid features) 입니다. 이 데이터를 모두 행으로 늘여 놓으면 총 193 개의 입력 데이터가 만들어 집니다.

오디오 파일이 많다보니 이 데이터를 주피터 노트북에서 생성하고 모델을 만들기가 부담스러운 작업입니다. 별도의 파이썬 스크립트로 입력 데이터를 미리 가공하여 하나의 파일로 합쳐 놓은 urban_sound.npz 파일을 제 깃허브에 올려 놓았습니다. 이 파일은 입력 데이터로 만들어진 넘파이 배열을 파일로 저장해 놓은 것입니다. 특성을 추출하여 npz 데이터를 생성한 코드는 깃허브의 feature_extraction.py 와 feature_merge.py 파일을 참고하세요.

npz 파일을 사용하면 모델을 만들고 테스트할 때 오디오 파일을 매번 가공할 필요 없이 npz 파일로 부터 바로 입력 특성을 로드할 수 있습니다.

sound_data = np.load('urban_sound.npz')

모델 훈련은 20% 를 테스트 데이터로 떼어내고 20% 를 다시 밸리데이션 데이터로 분리하였습니다. 노트북에서 돌리다 보니 속도가 느려 크로스 밸리데이션은 사용하지 않았습니다. 밸리데이션 데이터로 하이퍼파라메타를 적절히 고르고 훈련 데이터와 밸리데이션 데이터를 합쳐서 최종 모델을 훈련시키고 테스트 데이터로 모델을 평가하였습니다.

사용한 뉴럴 네트워크는 완전 연결 뉴럴 네트워크로 3 개의 히든 레이어에 각각 뉴런을 300 개, 200 개, 100 개를 설정하였습니다.

n_hidden_units_one = 300
n_hidden_units_two = 200
n_hidden_units_three = 100

파라메타 초기 값은 입력 데이터의 크기인 193 의 제곱근의 역수를 사용했는데 각 레이어의 히든 유닛 수나 활성화 함수에 맞게 다르게 했으면 더 좋았을 것 같았습니다. 3개의 히든 레이어와 출력 레이어에 대한 텐서플로우 그래프를 아래와 같이 생성하였습니다. 첫번째 세번째의 활성화 함수는 시그모이드를 사용하였고 두번째 활성화 함수는 하이퍼볼릭 탄젠트 함수를 사용했습니다.

X = tf.placeholder(tf.float32,[None,n_dim])
Y = tf.placeholder(tf.float32,[None,n_classes])

W_1 = tf.Variable(tf.random_normal([n_dim, n_hidden_units_one], mean=0, stddev=sd), name="w1")
b_1 = tf.Variable(tf.random_normal([n_hidden_units_one], mean=0, stddev=sd), name="b1")
h_1 = tf.nn.sigmoid(tf.matmul(X, W_1) + b_1)

W_2 = tf.Variable(tf.random_normal([n_hidden_units_one, n_hidden_units_two], mean=0, stddev=sd), name="w2")
b_2 = tf.Variable(tf.random_normal([n_hidden_units_two], mean=0, stddev=sd), name="b2")
h_2 = tf.nn.tanh(tf.matmul(h_1, W_2) + b_2)

W_3 = tf.Variable(tf.random_normal([n_hidden_units_two, n_hidden_units_three], mean=0, stddev=sd), name="w3")
b_3 = tf.Variable(tf.random_normal([n_hidden_units_three], mean=0, stddev=sd), name="b3")
h_3 = tf.nn.sigmoid(tf.matmul(h_2, W_3) + b_3)

W = tf.Variable(tf.random_normal([n_hidden_units_three, n_classes], mean=0, stddev=sd), name="w")
b = tf.Variable(tf.random_normal([n_classes], mean = 0, stddev=sd), name="b")
y_ = tf.nn.softmax(tf.matmul(h_3, W) + b)

비용 함수는 크로스 엔트로피를 사용하고 그래디언트 디센트 옵티마이저를 사용하였습니다. 모델의 반복은 총 6000 회를 진행했습니다. 최종 테스트 결과는 대략 86.7% 의 정확도를 내었습니다.(아래 코드에 Validation accuracy 라고 되어 있는데 최종 테스트 정확도를 계산하면서 수정하지 못했습니다. Test accuracy 라고 이해해 주세요.)

cost_history = np.empty(shape=[1],dtype=float)
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(training_epochs):
        _,cost = sess.run([optimizer, cost_function], feed_dict={X: X_sub, Y: y_sub})
        cost_history = np.append(cost_history,cost)
 
    print('Validation accuracy: ',round(sess.run(accuracy, feed_dict={X: X_test, Y: y_test}) , 3))
    saver.save(sess, "model_321.ckpt")
Validation accuracy:  0.867

비용 값의 감소 그래프를 그려 보면 비교적 요동이 심한 것으로 보입니다. 이 요동의 원인을 잡아낼 수 있다면 좀 더 정확도를 높일 수 있지 않을까 생각됩니다.

urban-sound-learning

다른 곳에서 구한 사운드 데이터로 테스트해 보니 강아지 소리는 잘 구분해 냈는데 드릴 소리는 절반 정도를 구분하지 못했습니다. 드릴 소리를 스트리트 뮤직이라고 판단하는 경우가 많았습니다.

만들어진 모델 파라메타는 model_321.ckpt 파일에 저장하여 깃허브에 올려 놓았습니다. 필요하신 분이 있다면 자유롭게 사용하시고 혹 더 좋은 팁이 있다면 피드백 부탁 드립니다.

 

답글 남기기

아래 항목을 채우거나 오른쪽 아이콘 중 하나를 클릭하여 로그 인 하세요:

WordPress.com 로고

WordPress.com의 계정을 사용하여 댓글을 남깁니다. 로그아웃 / 변경 )

Twitter 사진

Twitter의 계정을 사용하여 댓글을 남깁니다. 로그아웃 / 변경 )

Facebook 사진

Facebook의 계정을 사용하여 댓글을 남깁니다. 로그아웃 / 변경 )

Google+ photo

Google+의 계정을 사용하여 댓글을 남깁니다. 로그아웃 / 변경 )

%s에 연결하는 중