Source code for ASTROMER.core.classifier

import tensorflow as tf
import logging
import json
import os

from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from tensorflow.keras.layers import LSTM, Dense, Input, LayerNormalization
from tensorflow.keras import Model
from ASTROMER.core.astromer import get_ASTROMER
from ASTROMER.core.metrics import custom_acc
from ASTROMER.core.tboard import save_scalar
from ASTROMER.core.losses import custom_bce
from ASTROMER.core.output import SauceLayer
from tqdm import tqdm

from ASTROMER.core.data import standardize
logging.getLogger('tensorflow').setLevel(logging.ERROR)  # suppress warnings

def get_fc_attention(units, num_classes, weights):
    """FC + ATT: frozen ASTROMER encoder followed by a fully connected head."""
    # note: `units` is unused here; kept for signature parity with the LSTM builders
    conf_file = os.path.join(weights, 'conf.json')
    with open(conf_file, 'r') as handle:
        conf = json.load(handle)

    # Rebuild the pre-trained ASTROMER model from its stored configuration
    model = get_ASTROMER(num_layers=conf['layers'],
                         d_model=conf['head_dim'],
                         num_heads=conf['heads'],
                         dff=conf['dff'],
                         base=conf['base'],
                         dropout=conf['dropout'],
                         maxlen=conf['max_obs'],
                         use_leak=conf['use_leak'])
    model.load_weights(os.path.join(weights, 'weights'))

    # Freeze the encoder; only the classification head is trained
    encoder = model.get_layer('encoder')
    encoder.trainable = False

    # Masked average over the sequence dimension (mask_in flags padded steps)
    mask = 1. - encoder.input['mask_in']
    x = encoder(encoder.input)
    x = x * mask
    x = tf.reduce_sum(x, 1) / tf.reduce_sum(mask, 1)

    x = Dense(1024, name='FCN1')(x)
    x = Dense(512, name='FCN2')(x)
    x = Dense(num_classes, name='FCN3')(x)
    return Model(inputs=encoder.input, outputs=x, name='FCATT')

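# Usage sketch: building the FC + ATT classifier from a pre-trained ASTROMER
# checkpoint directory (one holding 'conf.json' and a 'weights' checkpoint).
# The directory path and class count below are illustrative placeholders,
# not values shipped with the library:
#
#     clf = get_fc_attention(units=256, num_classes=5,
#                            weights='./pretrained/macho')
#     clf.summary()  # frozen encoder + FCN1/FCN2/FCN3 head
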
def get_lstm_no_attention(units, num_classes, maxlen, dropout=0.5):
    """LSTM + LSTM + FC: recurrent baseline without the ASTROMER encoder."""
    serie  = Input(shape=(maxlen, 1), batch_size=None, name='input')
    times  = Input(shape=(maxlen, 1), batch_size=None, name='times')
    mask   = Input(shape=(maxlen, 1), batch_size=None, name='mask')
    length = Input(shape=(maxlen,), batch_size=None, dtype=tf.int32, name='length')
    placeholder = {'input': serie, 'mask_in': mask, 'times': times, 'length': length}

    # mask_in is 1 on padded steps, so invert it to get a Keras-style boolean mask
    bool_mask = tf.logical_not(tf.cast(placeholder['mask_in'], tf.bool))

    # Feed (time, magnitude) pairs to the recurrent stack
    x = tf.concat([placeholder['times'], placeholder['input']], 2)
    x = LSTM(units, return_sequences=True, dropout=dropout, name='RNN_0')(x, mask=bool_mask)
    x = LayerNormalization(axis=1)(x)
    x = LSTM(units, return_sequences=True, dropout=dropout, name='RNN_1')(x, mask=bool_mask)
    x = LayerNormalization(axis=1)(x)
    x = Dense(num_classes, name='FCN')(x)
    return Model(inputs=placeholder, outputs=x, name='RNNCLF')

def get_lstm_attention(units, num_classes, weights, dropout=0.5):
    """ATT + LSTM + LSTM + FC: frozen ASTROMER encoder followed by a recurrent head."""
    conf_file = os.path.join(weights, 'conf.json')
    with open(conf_file, 'r') as handle:
        conf = json.load(handle)

    # Rebuild the pre-trained ASTROMER model from its stored configuration
    model = get_ASTROMER(num_layers=conf['layers'],
                         d_model=conf['head_dim'],
                         num_heads=conf['heads'],
                         dff=conf['dff'],
                         base=conf['base'],
                         dropout=conf['dropout'],
                         maxlen=conf['max_obs'],
                         use_leak=conf['use_leak'])
    model.load_weights(os.path.join(weights, 'weights'))

    # Freeze the encoder; only the recurrent head is trained
    encoder = model.get_layer('encoder')
    encoder.trainable = False

    bool_mask = tf.logical_not(tf.cast(encoder.input['mask_in'], tf.bool))

    x = encoder(encoder.input)
    # Restore the (batch, time, features) shape expected by the LSTM layers
    x = tf.reshape(x, [-1, conf['max_obs'], encoder.output.shape[-1]])
    x = LayerNormalization()(x)
    x = LSTM(units, return_sequences=True, dropout=dropout, name='RNN_0')(x, mask=bool_mask)
    x = LayerNormalization()(x)
    x = LSTM(units, return_sequences=True, dropout=dropout, name='RNN_1')(x, mask=bool_mask)
    x = LayerNormalization()(x)
    x = Dense(num_classes, name='FCN')(x)
    return Model(inputs=encoder.input, outputs=x, name='RNNCLF')

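# The three builders share the same classification target and differ only in
# how the light curve is encoded. A hedged sketch (paths and sizes are
# placeholders):
#
#     clf_att = get_lstm_attention(units=256, num_classes=5,
#                                  weights='./pretrained/macho')  # ATT + LSTM
#     clf_rnn = get_lstm_no_attention(units=256, num_classes=5,
#                                     maxlen=200)                 # LSTM only
#
# All three return uncompiled tf.keras Models; optimization is handled by the
# custom loop in train() below.
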
@tf.function
def train_step(model, batch, opt):
    with tf.GradientTape() as tape:
        y_pred = model(batch)
        ce = custom_bce(y_true=batch['label'], y_pred=y_pred)
        acc = custom_acc(batch['label'], y_pred)
    grads = tape.gradient(ce, model.trainable_weights)
    opt.apply_gradients(zip(grads, model.trainable_weights))
    return acc, ce

@tf.function
def valid_step(model, batch, return_pred=False):
    # No gradients are needed at validation time, so no GradientTape here
    y_pred = model(batch, training=False)
    ce = custom_bce(y_true=batch['label'], y_pred=y_pred)
    acc = custom_acc(batch['label'], y_pred)
    if return_pred:
        return acc, ce, y_pred, batch['label']
    return acc, ce

def train(model, train_batches, valid_batches, patience=20,
          exp_path='./experiments/test', epochs=1, lr=1e-3, verbose=1):
    # Tensorboard writers
    train_writer = tf.summary.create_file_writer(
        os.path.join(exp_path, 'logs', 'train'))
    valid_writer = tf.summary.create_file_writer(
        os.path.join(exp_path, 'logs', 'valid'))

    # Optimizer
    optimizer = tf.keras.optimizers.Adam(lr)

    # Running means of the per-batch metrics
    train_bce = tf.keras.metrics.Mean(name='train_bce')
    valid_bce = tf.keras.metrics.Mean(name='valid_bce')
    train_acc = tf.keras.metrics.Mean(name='train_acc')
    valid_acc = tf.keras.metrics.Mean(name='valid_acc')

    # ==============================
    # ======= Training Loop ========
    # ==============================
    best_loss = float('inf')
    es_count = 0
    pbar = tqdm(range(epochs), desc='epoch')
    for epoch in pbar:
        for train_batch in train_batches:
            acc, bce = train_step(model, train_batch, optimizer)
            train_acc.update_state(acc)
            train_bce.update_state(bce)

        for valid_batch in valid_batches:
            acc, bce = valid_step(model, valid_batch)
            valid_acc.update_state(acc)
            valid_bce.update_state(bce)

        save_scalar(train_writer, train_acc, epoch, name='accuracy')
        save_scalar(valid_writer, valid_acc, epoch, name='accuracy')
        save_scalar(train_writer, train_bce, epoch, name='xentropy')
        save_scalar(valid_writer, valid_bce, epoch, name='xentropy')

        # Early stopping on the validation loss; keep only the best weights
        if valid_bce.result() < best_loss:
            best_loss = valid_bce.result()
            es_count = 0
            model.save_weights(os.path.join(exp_path, 'weights'))
        else:
            es_count += 1
            if es_count == patience:
                print('[INFO] Early Stopping Triggered')
                break

        msg = 'EPOCH {} - ES COUNT: {}/{} Train acc: {:.4f} - Val acc: {:.4f} - Train CE: {:.2f} - Val CE: {:.2f}'.format(
            epoch, es_count, patience,
            train_acc.result(), valid_acc.result(),
            train_bce.result(), valid_bce.result())
        pbar.set_description(msg)

        train_bce.reset_states()
        valid_bce.reset_states()
        train_acc.reset_states()
        valid_acc.reset_states()

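# Usage sketch for the training loop. train_batches/valid_batches are assumed
# to be tf.data.Dataset objects yielding dictionaries with the keys used above
# ('input', 'times', 'mask_in', 'length', 'label'); the experiment path is a
# placeholder:
#
#     train(clf, train_batches, valid_batches,
#           patience=20, epochs=100, lr=1e-3,
#           exp_path='./experiments/macho_clf')
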
def get_conf(path):
    conf_file = os.path.join(path, 'conf.json')
    with open(conf_file, 'r') as handle:
        conf = json.load(handle)
    return conf

def load_weights(model, weights):
    model.load_weights(os.path.join(weights, 'weights'))
    return model

def predict(model, test_batches):
    predictions = []
    true_labels = []
    for batch in tqdm(test_batches, desc='test'):
        acc, ce, y_pred, y_true = valid_step(model, batch, return_pred=True)
        # Sequence models predict per step; keep only the last step's logits
        if len(y_pred.shape) > 2:
            predictions.append(y_pred[:, -1, :])
        else:
            predictions.append(y_pred)
        true_labels.append(y_true)

    y_pred = tf.concat(predictions, 0)
    y_true = tf.concat(true_labels, 0)
    pred_labels = tf.argmax(y_pred, 1)

    precision, recall, f1, _ = precision_recall_fscore_support(y_true,
                                                               pred_labels,
                                                               average='macro')
    acc = accuracy_score(y_true, pred_labels)
    results = {'f1': f1,
               'recall': recall,
               'precision': precision,
               'accuracy': acc,
               'y_true': y_true,
               'y_pred': pred_labels}
    return results

def predict_from_path(path, test_batches, mode=0, save=False):
    conf_rnn = get_conf(path)
    if mode == 0:
        clf = get_lstm_attention(conf_rnn['units'], conf_rnn['num_classes'],
                                 conf_rnn['w'], conf_rnn['dropout'])
    elif mode == 1:
        clf = get_fc_attention(conf_rnn['units'], conf_rnn['num_classes'],
                               conf_rnn['w'])
    elif mode == 2:
        clf = get_lstm_no_attention(conf_rnn['units'], conf_rnn['num_classes'],
                                    conf_rnn['max_obs'], conf_rnn['dropout'])
    clf = load_weights(clf, path)
    results = predict(clf, test_batches)
    return results
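
# End-to-end inference sketch: mode selects the architecture (0: ATT + LSTM,
# 1: FC + ATT, 2: LSTM only) and must match the configuration stored in the
# experiment's conf.json. './experiments/macho_clf' and test_batches are
# placeholders:
#
#     results = predict_from_path('./experiments/macho_clf', test_batches, mode=0)
#     print(results['f1'], results['accuracy'])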