# Source code for finetune.target_models.mtl

import logging

import tensorflow as tf
from finetune.base import BaseModel
from finetune.input_pipeline import BasePipeline
from finetune.errors import FinetuneError

from finetune.target_models.sequence_labeling import SequenceLabeler
from finetune.encoding.sequence_encoder import indico_to_finetune_sequence

LOGGER = logging.getLogger('finetune')

def get_input_fns(task_id, input_fn, validation_fn):
    """
    Wrap a task's train and validation dataset factories so that every
    example carries the id of the task it belongs to.

    :param task_id: Identifier of the task; stored under the "task_id" key of each feature dict.
    :param input_fn: Zero-argument callable returning the training dataset (must expose ``.map``).
    :param validation_fn: Zero-argument callable returning the validation dataset (must expose ``.map``).
    :return: A pair ``(train_fn, validation_fn)`` of zero-argument callables. Each builds the
        corresponding dataset lazily and maps the normalising function over it.
    """
    def fn(x, y):
        # Every task is normalised to the same feature structure so the per-task
        # datasets can be mixed into a single training stream.
        # NOTE(review): the dict braces and the trailing ``y`` were missing from the
        # source as received; the (features, targets) pair is reconstructed from the
        # ``fn(x, y)`` signature -- confirm against the upstream file.
        return (
            {
                "tokens": reshape_to_rank_4(x["tokens"]),
                "mask": x["mask"],
                "task_id": task_id,
            },
            y,
        )

    return lambda: input_fn().map(fn), lambda: validation_fn().map(fn)

def reshape_to_rank_4(t):
    """
    Reshape ``t`` into a rank-4 tensor.

    The first dimension and the last two dimensions of ``t`` are kept; the
    second dimension of the result is inferred by ``tf.reshape`` (``-1``).

    :param t: Tensor to reshape.
    :return: The reshaped rank-4 tensor.
    """
    dims = tf.shape(t)
    leading, penultimate, last = dims[0], dims[-2], dims[-1]
    return tf.reshape(t, [leading, -1, penultimate, last])

def get_train_eval_dataset(input_fn, val_size):
    """
    Build a dataset factory that yields only the first ``val_size`` elements
    of the dataset produced by ``input_fn``.

    :param input_fn: Zero-argument callable returning a dataset exposing ``.take``.
    :param val_size: Number of elements to take from the dataset.
    :return: Zero-argument callable; ``input_fn`` is only invoked when it is called.
    """
    def truncated_dataset():
        dataset = input_fn()
        return dataset.take(val_size)

    return truncated_dataset

class MultiTaskPipeline(BasePipeline):
    """
    Input pipeline for multi-task training.

    Builds one input pipeline per configured task, tags each task's examples
    with the task id, and exposes a single training dataset factory that mixes
    the per-task datasets together.
    """

    def __init__(self, *args, **kwargs):
        super(MultiTaskPipeline, self).__init__(*args, **kwargs)
        # Running total of the sizes of all task datasets seen so far.
        self.dataset_size_ = 0
        self.loss_weights = None
        self.target_dim = -1
        self.input_pipelines = None

    def dataset_size(self):
        """Return the accumulated size of all task datasets."""
        return self.dataset_size_

    def get_train_input_fns(self, Xs, Y=None, batch_size=None, val_size=None):
        """
        Build the training and validation dataset factories for every task.

        :param Xs: Dictionary mapping task name to that task's inputs.
        :param Y: Dictionary mapping task name to that task's targets, or None.
        :param batch_size: Forwarded to each task's input pipeline.
        :param val_size: Forwarded to each task's input pipeline.
        :return: ``(val_funcs, train_dataset, val_sizes, val_intervals)``. ``val_funcs`` maps
            ``task_name`` and ``task_name + "_train"`` to dataset factories, ``train_dataset``
            is a zero-argument callable building the mixed training dataset, and the last
            two are dictionaries keyed by task name.
        """
        val_funcs = {}
        val_sizes = {}
        val_intervals = {}
        input_pipelines = {}
        frequencies = []
        input_funcs = []

        for task_name in self.config.tasks:
            input_pipelines[task_name] = self.config.tasks[task_name]._get_input_pipeline(self)
            # NOTE(review): the arguments of this call were missing from the source as
            # received; they are reconstructed from this method's own parameters --
            # confirm against the upstream file.
            task_tuple = input_pipelines[task_name].get_train_input_fns(
                Xs[task_name],
                Y[task_name] if Y is not None else None,
                batch_size=batch_size,
                val_size=val_size,
            )
            # Presumably the task pipeline has just written its own size to
            # config.dataset_size -- TODO confirm.
            self.dataset_size_ += self.config.dataset_size

            (val_func, input_func, val_sizes[task_name], val_intervals[task_name]) = task_tuple
            task_id = self.config.task_name_to_id[task_name]

            input_func_normalised, val_func_normalised = get_input_fns(task_id, input_func, val_func)
            # NOTE(review): these two appends were absent from the source as received,
            # which left both lists permanently empty; using the task's dataset size as
            # its sampling frequency is a reconstruction -- confirm.
            input_funcs.append(input_func_normalised)
            frequencies.append(self.config.dataset_size)
            val_funcs[task_name] = val_func_normalised
            val_funcs[task_name + "_train"] = get_train_eval_dataset(input_func_normalised, val_sizes[task_name])

        # Normalise the per-task frequencies into sampling probabilities.
        sum_frequencies = sum(frequencies)
        weights = [float(w) / sum_frequencies for w in frequencies]
        # NOTE(review): the function applied to (datasets, weights) was lost from the
        # source as received; sample_from_datasets matches the argument shape and the
        # TF 1.x API used elsewhere in this file -- confirm.
        train_dataset = lambda: tf.contrib.data.sample_from_datasets([f() for f in input_funcs], weights)

        self.config.task_input_pipelines = input_pipelines
        self.config.dataset_size = self.dataset_size_
        return val_funcs, train_dataset, val_sizes, val_intervals

    def _target_encoder(self):
        """Always raises FinetuneError; this pipeline defines no target encoder of its own."""
        raise FinetuneError("This should never be used??")

def get_loss_logits_fn(task, featurizer_state, config, targets_i, train, reuse, task_id_i):
    """
    Build the (predicate, branch) pair for one task.

    :param task: Name of the task; also used to name the variable scope ``target_model_<task>``.
    :param featurizer_state: Featurizer output passed to the task's target model.
    :param config: Config object exposing ``tasks`` and ``task_name_to_id``.
    :param targets_i: Targets passed to the task's target model.
    :param train: Training flag passed to the task's target model.
    :param reuse: Variable reuse flag passed to the task's target model.
    :param task_id_i: Tensor holding the id of the task the current batch belongs to.
    :return: ``(pred, loss_logits)`` where ``pred`` is a boolean tensor that is true when
        ``task_id_i`` equals this task's id, and ``loss_logits`` is a zero-argument callable
        returning ``(losses, logits)`` from the task's target model.
    """
    def loss_logits():
        with tf.variable_scope("target_model_{}".format(task)):
            # NOTE(review): the argument list of this call was missing from the source
            # as received. The keywords follow the ``_target_model`` signature used in
            # this module; taking ``n_outputs`` from the task's input pipeline is an
            # assumption -- confirm against the upstream file.
            target_model_out = config.tasks[task]._target_model(
                config=config,
                featurizer_state=featurizer_state,
                targets=targets_i,
                n_outputs=config.task_input_pipelines[task].target_dim,
                train=train,
                reuse=reuse,
            )
            logits = target_model_out["logits"]
            return target_model_out["losses"], logits

    return tf.equal(task_id_i, config.task_name_to_id[task]), loss_logits

class MultiTask(BaseModel):
    r"""
    Target model for multi task learning. The approach used is to sample mini-batches from each task proportional
    to the size of the task for each dataset.

    :param tasks: A dictionary of pairs mapping string task names to model classes.
        eg. `{"sst": Classifier, "ner": SequenceLabeler}`
    :param \**kwargs: key-value pairs of config items to override. Note: The same config is used for each base task.
    """

    def __init__(self, tasks, **kwargs):
        super().__init__(**kwargs)
        self.config.tasks = tasks
        # Assign each task a stable integer id in the iteration order of ``tasks``.
        self.config.task_name_to_id = {name: idx for idx, name in enumerate(self.config.tasks)}

    def _get_input_pipeline(self):
        return MultiTaskPipeline(self.config)

    def cached_predict(self):
        """
        Context manager that prevents the recreation of the tensorflow graph on every call to BaseModel.predict().

        Not supported for MultiTask.

        :raises FinetuneError: always.
        """
        raise FinetuneError("cached_predict is not supported yet for MTL")

    def _build_task_model(self, name, model_class):
        """
        Instantiate the single-task model for ``name`` and point it at this model's config,
        the task's input pipeline and the trained variables.

        :param name: Task name, as used in ``config.tasks``.
        :param model_class: Model class registered for that task.
        :return: An initialised instance of ``model_class``.
        """
        task_model = model_class()
        task_model.config = self.config
        task_model.input_pipeline = self.config.task_input_pipelines[name]
        task_model._initialize()
        # Variable names have the "/target_model_<name>" scope component removed
        # before being handed to the single-task model's saver.
        scope = "/target_model_{}".format(name)
        task_model.saver.variables = {
            k.replace(scope, ""): v for k, v in self.saver.variables.items()
        }
        return task_model

    def featurize(self, X):
        """
        Runs featurization on the trained model for any of the tasks the model was trained for. Input and output
        formats are the same as for each of the individual tasks.

        :param X: A dictionary mapping from task name to data, in the format required by the task type.
        :return: A dictionary mapping from task name to the features for that task.
        """
        features = {}
        for name, model_class in self.config.tasks.items():
            if name not in X:
                continue
            features[name] = self._build_task_model(name, model_class).featurize(X[name])
        return features

    def predict(self, X):
        """
        Runs inference on the trained model for any of the tasks the model was trained for. Input and output
        formats are the same as for each of the individual tasks.

        :param X: A dictionary mapping from task name to data, in the format required by the task type.
        :return: A dictionary mapping from task name to the predictions for that task.
        """
        predictions = {}
        for name, model_class in self.config.tasks.items():
            if name not in X:
                continue
            predictions[name] = self._build_task_model(name, model_class).predict(X[name])
        return predictions

    def predict_proba(self, X):
        """
        Runs probability inference on the trained model for any of the tasks the model was trained for.
        Falls back to normal predict when probabilities are not available for a task, eg Regression. Input and
        output formats are the same as for each of the individual tasks.

        :param X: A dictionary mapping from task name to data, in the format required by the task type.
        :return: A dictionary mapping from task name to the predictions for that task.
        """
        predictions = {}
        for name, model_class in self.config.tasks.items():
            if name not in X:
                continue
            pred_model = self._build_task_model(name, model_class)
            try:
                predictions[name] = pred_model.predict_proba(X[name])
            except FinetuneError as e:
                LOGGER.warning(
                    (
                        "Probabilities are not available for {} and failed with exception {}."
                        "Falling back to regular predictions for this task."
                    ).format(name, e)
                )
                predictions[name] = pred_model.predict(X[name])
        return predictions

    def finetune(self, X, Y=None, batch_size=None):
        """
        :param X: A dictionary mapping from task name to inputs in the same format required for each of the models.
        :param Y: A dictionary mapping from task name to targets in the same format required for each of the models.
        :param batch_size: Number of examples per batch. When N_GPUS > 1, this number corresponds to the number of
            training examples provided to each GPU.
        :return: Whatever ``BaseModel.finetune`` returns.
        """
        sequence_tasks = [task_name for task_name, t in self.config.tasks.items() if t == SequenceLabeler]
        if sequence_tasks:
            # Work on shallow copies so the caller's dictionaries are not rewritten in
            # place; otherwise a second call with the same arguments would convert
            # already-converted sequence data again.
            X = dict(X)
            if Y is not None:
                Y = dict(Y)
        for t in sequence_tasks:
            X[t], Y[t], *_ = indico_to_finetune_sequence(X[t], labels=Y[t], multi_label=False, none_value="<PAD>")
        return super().finetune(X, Y=Y, batch_size=batch_size)

    @staticmethod
    def _target_model(config, featurizer_state, targets, n_outputs, train=False, reuse=None, task_id=None, **kwargs):
        """
        Build the target models of every task and select the one matching ``task_id``.

        :param config: Config object exposing ``tasks`` and ``task_name_to_id``.
        :param featurizer_state: Dictionary holding the featurizer output under "features".
        :param targets: Targets forwarded to each task's target model.
        :param n_outputs: Not used by this implementation.
        :param train: Training flag forwarded to each task's target model.
        :param reuse: Variable reuse flag forwarded to each task's target model.
        :param task_id: Tensor holding the id of the task the current batch belongs to.
        :return: Dictionary with the selected task's "logits" and "losses".
        """
        pred_fn_pairs = []
        # Drop axis 1 of the features when it has size 1; otherwise leave them untouched.
        featurizer_state["features"] = tf.cond(
            tf.equal(tf.shape(featurizer_state["features"])[1], 1),
            true_fn=lambda: tf.squeeze(featurizer_state["features"], [1]),
            false_fn=lambda: featurizer_state["features"]
        )
        targets_i = targets
        for task in config.tasks:
            pred, loss_logits = get_loss_logits_fn(task, featurizer_state, config, targets_i, train, reuse, task_id)
            pred_fn_pairs.append((pred, loss_logits))

        # NOTE(review): the name of the function receiving these arguments was lost from
        # the source as received; ``tf.case`` is the TF 1.x op taking
        # (pred_fn_pairs, default, exclusive, strict, name) -- confirm.
        losses_logits = tf.case(
            pred_fn_pairs,
            default=None,
            exclusive=True,
            strict=True,
            name='top_selection'
        )
        return {
            "logits": losses_logits[1],
            "losses": losses_logits[0]
        }

    def _predict_op(self, logits, **kwargs):
        return tf.no_op()

    def _predict_proba_op(self, logits, **kwargs):
        return tf.no_op()