"""Comparison target model: classify pairs of documents (``finetune.target_models.comparison``)."""

import numpy as np
import tensorflow as tf

from finetune.base import BaseModel
from finetune.target_models.classifier import Classifier, ClassificationPipeline
from finetune.encoding.input_encoder import ArrayEncodedOutput

class ComparisonPipeline(ClassificationPipeline):
    """Input pipeline that encodes every text pair in both orders for :class:`Comparison`."""

    def _format_for_encoding(self, X):
        # Wrap the example in a list before it reaches the encoder.
        return [X]

    def _text_to_ids(self, pair, Y=None, pad_token=None):
        """
        Encode a single comparison example in both orderings.

        :param pair: a sequence of two texts (one row of the ``[batch, 2]`` input).
        :param Y: accepted for interface compatibility; labels are not passed to
            the parent encoder here (it is always called with ``Y=None``).
        :param pad_token: forwarded unchanged to the parent encoder.
        :returns: a generator yielding one ``ArrayEncodedOutput``. Its ``token_ids``
            and ``mask`` stack the forward (``pair``) and backward (``pair[::-1]``)
            encodings along a new leading axis of size 2. ``tokens`` is the list
            ``[forward_tokens, backward_tokens]``. All other fields come from the
            forward encoding.
        :raises ValueError: if ``config.chunk_long_sequences`` is enabled.
        """
        # Use an explicit raise rather than `assert`, which is stripped under `python -O`.
        if self.config.chunk_long_sequences:
            raise ValueError("Chunk Long Sequences is not compatible with comparison")

        # Take the first encoding from the parent's generator, once for each ordering.
        forward = next(super()._text_to_ids(pair, Y=None, pad_token=pad_token))
        backward = next(super()._text_to_ids(pair[::-1], Y=None, pad_token=pad_token))

        fields = forward._asdict()
        fields["tokens"] = [forward.tokens, backward.tokens]
        fields["token_ids"] = np.stack([forward.token_ids, backward.token_ids], 0)
        fields["mask"] = np.stack([forward.mask, backward.mask], 0)
        yield ArrayEncodedOutput(**fields)

    def feed_shape_type_def(self):
        """
        Return the ``(types, shapes)`` specification of one encoded comparison example.

        Both input features carry a leading axis of size 2, because ``_text_to_ids``
        always stacks exactly two encodings: forward and backward. The target shape
        is taken unchanged from the parent pipeline.
        """
        TS = tf.TensorShape
        # NOTE(review): assumes the parent returns
        # ((input_types, target_type), (input_shapes, target_shape)); confirm
        # against ClassificationPipeline.feed_shape_type_def.
        _, (_, target_shape) = super().feed_shape_type_def()
        return ({"tokens": tf.int32, "mask": tf.float32}, tf.float32), (
            {
                "tokens": TS([2, self.config.max_length, 2]),
                "mask": TS([2, self.config.max_length]),
            },
            target_shape,
        )

class Comparison(Classifier):
    r"""
    Compares two documents to solve a classification task.

    :param config: A :py:class:`finetune.config.Settings` object or None (for default config).
    :param \**kwargs: key-value pairs of config items to override.
    """

    def _get_input_pipeline(self):
        # Every pair is encoded in both orders by ComparisonPipeline.
        return ComparisonPipeline(self.config)

    @staticmethod
    def _target_model(config, *, featurizer_state, targets, n_outputs, train=False, reuse=None, **kwargs):
        """
        Pool the featurizer outputs over axis 1 and delegate to the classifier head.

        ``sequence_features`` and ``features`` are each summed over axis 1 and then
        passed through ``tf.abs``. Every argument is forwarded to
        :meth:`Classifier._target_model`.

        NOTE(review): axis 1 presumably indexes the forward/backward orderings
        produced by ComparisonPipeline. Confirm against the featurizer's output shapes.
        """
        # Shallow copy so the caller's dict is not mutated by the reassignments below.
        featurizer_state = featurizer_state.copy()
        featurizer_state["sequence_features"] = tf.abs(
            tf.reduce_sum(featurizer_state["sequence_features"], 1)
        )
        featurizer_state["features"] = tf.abs(tf.reduce_sum(featurizer_state["features"], 1))
        return Classifier._target_model(
            config,
            featurizer_state=featurizer_state,
            targets=targets,
            n_outputs=n_outputs,
            train=train,
            reuse=reuse,
            **kwargs
        )

    # NOTE(review): the public methods below call BaseModel directly, bypassing
    # any Classifier override. This is presumably intentional; confirm.

    def predict(self, pairs):
        """
        Produces a list of most likely class labels as determined by the fine-tuned model.

        :param pairs: Array of text, shape [batch, 2]
        :returns: list of class labels.
        """
        return BaseModel.predict(self, pairs)

    def predict_proba(self, pairs):
        """
        Produces a probability distribution over classes for each pair in ``pairs``.

        :param pairs: Array of text, shape [batch, 2]
        :returns: list of dictionaries. Each dictionary maps from a class label to its assigned class probability.
        """
        return BaseModel.predict_proba(self, pairs)

    def featurize(self, pairs):
        """
        Embeds inputs in learned feature space. Can be called before or after calling :meth:`finetune`.

        :param pairs: Array of text, shape [batch, 2]
        :returns: np.array of features of shape (n_examples, embedding_size).
        """
        return BaseModel.featurize(self, pairs)