Snakeの活性化関数

目次

1. Snake活性化関数の概要
1.1 Snake活性化関数とは
1.2 Snake関数
2. 実験
2.1 ライブラリインポート
2.2 データ読み込み
2.3 データ加工
2.4 Snakeの活性化関数を作成
2.5 LSTMの活性化関数を作成
2.6 まとめ

関連記事:活性化関数のまとめ

 

1. Snake活性化関数の概要

1.1 Snake活性化関数とは

Snakeの活性化関数は、単純な周期関数の外挿を学習できないLSTMに用います。通常の活性化関数であるtanh、sigmoid、reluの弱点を改善するために使用します。LSTMベースのアクティベーションの優れた最適化特性を維持しながら、周期関数を学習するために必要な周期的誘導バイアスを実現する新しいアクティベーション、つまりx + sin2(x)のような学習が難しいタイプで用いられます。

 

tfa.activations.snake(
x: tfa.types.TensorLike,
frequency: tfa.types.Number = 1
) -> tf.Tensor

論文:Neural Networks Fail to Learn Periodic Functions and How to Fix It

https://arxiv.org/abs/2006.08195

TensorFlow: https://www.tensorflow.org/addons/api_docs/python/tfa/activations/snake

Pytorch: https://pypi.org/project/torch-snake/

 

2. 実験

データセット:3,235日の日光のデータセット

モデル:Snakeの活性化関数のモデル vs LSTMの活性化関数のモデル

モデル評価:MAE

 

2.1 ライブラリインポート

ライブラリのインストール

!pip install tensorflow-addons

ライブラリインポート

import os

import sys

import csv

import tensorflow as tf

import numpy as np

import urllib

import matplotlib.pyplot as plt

keras = tf.keras

 

from datetime import datetime

 

2.2 データ読み込み

データセットをロードして、正規化します。

# Load data

 

url = ‘https://storage.googleapis.com/download.tensorflow.org/data/Sunspots.csv’

urllib.request.urlretrieve(url, ‘sunspots.csv’)

 

time_step = []

sunspots = []

 

with open(‘sunspots.csv’) as csvfile:

reader = csv.reader(csvfile, delimiter=’,’)

next(reader)

for row in reader:

sunspots.append(float(row[2]))

time_step.append(int(row[0]))

 

# normalization function

 

series = np.array(sunspots)

min = np.min(series)

max = np.max(series)

series -= min

series /= max

time = np.array(time_step)

 

学習データとテストデーとを作成します。

# train test split

split_time = 3000

 

time_train = time[:split_time]

x_train = series[:split_time]

time_valid = time[split_time:]

x_valid = series[split_time:]

 

学習データを確認します。

def plot_series(time, series, format=”-“, start=0, end=None, label=None):

plt.plot(time[start:end], series[start:end], format, label=label)

plt.xlabel(“Time”)

plt.ylabel(“Value”)

if label:

plt.legend(fontsize=14)

plt.grid(True)

 

plt.figure(figsize=(10, 6))

plot_series(time_train, x_train)

plt.show()

テストデータを確認します。

plt.figure(figsize=(10, 6))

plot_series(time_valid, x_valid)

plt.show()

 

2.4 Snakeの活性化関数のネットワークを作成

# DENSE

import tensorflow_addons as tfa

 

# Parameter

window_size = 30

batch_size = 256

shuffle_buffer_size = 1000

epochs = 150

 

keras.backend.clear_session()

tf.random.set_seed(42)

np.random.seed(42)

 

dense_model = keras.models.Sequential([

keras.layers.Dense(100, batch_input_shape=[1, None, 1]),

keras.layers.Dense(100, activation=tfa.activations.snake,),

keras.layers.Dense(100, activation=tfa.activations.snake,),

keras.layers.Dense(100, activation=tfa.activations.snake,),

keras.layers.Dense(1),

keras.layers.Lambda(lambda x: x * 200.0)

])

 

optimizer = keras.optimizers.SGD(learning_rate=5e-7, momentum=0.9)

dense_model.compile(loss=keras.losses.Huber(),

optimizer=optimizer,

metrics=[“mae”])

 

reset_states = ResetStatesCallback()

checkpoint = “output/dense/my_checkpoint.h5”

dense_checkpoint = keras.callbacks.ModelCheckpoint(checkpoint, save_best_only=True)

early_stopping = keras.callbacks.EarlyStopping(patience=5)

 

dense_history = dense_model.fit(train_set,

epochs=epochs,

validation_data=valid_set,

callbacks=[early_stopping, model_checkpoint, reset_states])

Epoch 1/150

99/99 [==============================] – 2s 9ms/step – loss: 0.1945 – mae: 0.4653 – val_loss: 0.0918 – val_mae: 0.3777

Epoch 150/150

99/99 [==============================] – 1s 7ms/step – loss: 0.0025 – mae: 0.0523 – val_loss: 0.0019 – val_mae: 0.0465

 

MAEのモデル評価

モデルは徐々に学ぶことができます。

# Plot loss

plt.xlabel(‘Epoch’)

plt.ylabel(‘mae’)

plt.plot(dense_history.history[‘val_loss’])

モデル評価

dense_forecast = dense_model.predict(series[np.newaxis, :, np.newaxis])

dense_forecast = dense_forecast[0, split_time – 1:-1, 0]

 

print(‘MAE: ‘ + str(keras.metrics.mean_absolute_error(x_valid, dense_forecast).numpy()))

MAE: 0.044753935

 

モデルを推論します。

過学習になる傾向が知られています。

def plot_series(time, series, format=”-“, start=0, end=None, label=None):

plt.plot(time[start:end], series[start:end], format, label=label)

plt.xlabel(“Time”)

plt.ylabel(“Value”)

if label:

plt.legend(fontsize=14)

plt.grid(True)

 

plt.figure(figsize=(10, 6))

plot_series(time_valid, x_valid)

plot_series(time_valid, dense_forecast)

# plt.savefig(‘output/rnn/predict_vs_valid.png’)

plt.show()

 

2.5 LSTMの活性化関数を作成

# LSTM

 

# Parameter

window_size = 30

batch_size = 256

shuffle_buffer_size = 1000

epochs = 100

 

def sequential_window_dataset(series, window_size):

series = tf.expand_dims(series, axis=-1)

ds = tf.data.Dataset.from_tensor_slices(series)

ds = ds.window(window_size + 1, shift=window_size, drop_remainder=True)

ds = ds.flat_map(lambda window: window.batch(window_size + 1))

ds = ds.map(lambda window: (window[:-1], window[1:]))

return ds.batch(1).prefetch(1)

 

class ResetStatesCallback(keras.callbacks.Callback):

def on_epoch_begin(self, epoch, logs):

self.model.reset_states()

 

keras.backend.clear_session()

tf.random.set_seed(42)

np.random.seed(42)

 

window_size = 30

train_set = sequential_window_dataset(x_train, window_size)

valid_set = sequential_window_dataset(x_valid, window_size)

 

lstm_model = keras.models.Sequential([

keras.layers.LSTM(100, return_sequences=True, stateful=True,

batch_input_shape=[1, None, 1]),

keras.layers.LSTM(100, return_sequences=True, stateful=True),

keras.layers.Dense(1),

keras.layers.Lambda(lambda x: x * 200.0)

])

 

optimizer = keras.optimizers.SGD(learning_rate=5e-7, momentum=0.9)

lstm_model.compile(loss=keras.losses.Huber(),

optimizer=optimizer,

metrics=[“mae”])

 

reset_states = ResetStatesCallback()

checkpoint = “output/lstm/my_checkpoint.h5”

model_checkpoint = keras.callbacks.ModelCheckpoint(checkpoint, save_best_only=True)

early_stopping = keras.callbacks.EarlyStopping(patience=3)

 

lstm_history = lstm_model.fit(train_set,

epochs=epochs,

validation_data=valid_set,

callbacks=[early_stopping, model_checkpoint, reset_states])

Epoch 1/100

99/99 [==============================] – 6s 23ms/step – loss: 1.5726 – mae: 2.0261 – val_loss: 1.0051 – val_mae: 1.4602

Epoch 100/100

99/99 [==============================] – 1s 14ms/step – loss: 0.0026 – mae: 0.0540 – val_loss: 0.0025 – val_mae: 0.0495

 

MAEのモデル評価

# Plot loss

plt.xlabel(‘Epoch’)

plt.ylabel(‘mae’)

plt.plot(lstm_history.history[‘loss’])

plt.savefig(‘output/lstm/loss.png’)

# Forecast

 

lstm_forecast = lstm_model.predict(series[np.newaxis, :, np.newaxis])

lstm_forecast = lstm_forecast[0, split_time – 1:-1, 0]

 

print(‘MAE: ‘ + str(keras.metrics.mean_absolute_error(x_valid, lstm_forecast).numpy()))

MAE: 0.042137478

 

モデルを推論します。

def plot_series(time, series, format=”-“, start=0, end=None, label=None):

plt.plot(time[start:end], series[start:end], format, label=label)

plt.xlabel(“Time”)

plt.ylabel(“Value”)

if label:

plt.legend(fontsize=14)

plt.grid(True)

 

plt.figure(figsize=(10, 6))

plot_series(time_valid, x_valid)

plot_series(time_valid, lstm_forecast)

# plt.savefig(‘output/lstm/predict_vs_valid.png’)

plt.show()

 

2.6 まとめ

日光のデータセットで、Snakeの活性化関数のモデル と LSTMのモデルを作成しました。モデルの精度は同じくらいですが、上手く周期部分をSnakeが捉えられていることがわかります。

 

担当者:HM

香川県高松市出身 データ分析にて、博士(理学)を取得後、自動車メーカー会社にてデータ分析に関わる。その後コンサルティングファームでデータ分析プロジェクトを歴任後独立 気が付けばデータ分析プロジェクトだけで50以上担当

理化学研究所にて研究員を拝命中 応用数理学会所属