Source code for malaya.model.pegasus

import tensorflow as tf
from malaya.text.rouge import postprocess_summary
from malaya.text.function import summarization_textcleaning
from malaya.model.abstract import Seq2Seq
from herpetologist import check_type
from typing import List

pad_sequences = tf.keras.preprocessing.sequence.pad_sequences


class Summarization(Seq2Seq):
    """
    Summarize strings by encoding them with the tokenizer, running them
    through ``self._execute`` and decoding the returned token ids.

    Public entry points are :meth:`greedy_decoder` and
    :meth:`nucleus_decoder`; both delegate to :meth:`_summarize`.
    """

    def __init__(self, input_nodes, output_nodes, sess, tokenizer):
        """
        Store the graph handles and tokenizer used by the decoders.

        Parameters
        ----------
        input_nodes:
            input nodes of the model, stored as ``self._input_nodes``.
        output_nodes:
            output nodes of the model, stored as ``self._output_nodes``.
        sess:
            session object, stored as ``self._sess``.
        tokenizer:
            object exposing ``encode(str)`` and ``decode(ids)``.
        """
        self._input_nodes = input_nodes
        self._output_nodes = output_nodes
        self._sess = sess
        self._tokenizer = tokenizer

    def _summarize(
        self,
        strings,
        top_p=0.7,
        temperature=1.0,
        postprocess=True,
        **kwargs,
    ):
        """
        Clean, encode, run and decode a batch of strings.

        Parameters
        ----------
        strings: List[str]
        top_p: float, (default=0.7)
            fed to the `top_p` input of the model.
        temperature: float, (default=1.0)
            fed to the `temperature` input of the model.
        postprocess: bool, optional (default=True)
            If True, each decoded summary is passed through
            `postprocess_summary` together with its ORIGINAL (uncleaned)
            input string.
        **kwargs:
            forwarded to `postprocess_summary`.

        Returns
        -------
        result: List[str]
        """
        # An empty batch has no sequence length to pad to, so return early
        # instead of tokenizing nothing and calling the model.
        if not strings:
            return []

        cleaned = [summarization_textcleaning(string) for string in strings]

        # Every encoded sequence gets id 1 appended before padding
        # (presumably the end-of-sequence id -- confirm against the tokenizer).
        batch_x = [self._tokenizer.encode(string) + [1] for string in cleaned]
        batch_x = pad_sequences(batch_x, padding='post')

        outputs = self._execute(
            inputs=[batch_x, top_p, temperature],
            input_labels=['Placeholder', 'top_p', 'temperature'],
            output_labels=['logits'],
        )

        # Despite its label, the 'logits' output is decoded directly as
        # token ids, one row per input string.
        generated = outputs['logits'].tolist()

        results = []
        for source, token_ids in zip(strings, generated):
            summary = self._tokenizer.decode(token_ids)
            if postprocess:
                summary = postprocess_summary(source, summary, **kwargs)
            results.append(summary)
        return results

    @check_type
    def greedy_decoder(
        self,
        strings: List[str],
        temperature: float = 0.0,
        postprocess: bool = False,
        **kwargs,
    ):
        """
        Summarize strings using greedy decoder.

        Parameters
        ----------
        strings: List[str]
        temperature: float, (default=0.0)
            logits * -log(random.uniform) * temperature.
        postprocess: bool, optional (default=False)
            If True, will filter sentence generated using ROUGE score and removed
            international news publisher.

        Returns
        -------
        result: List[str]
        """
        # Greedy decoding is the shared pipeline with `top_p` pinned to 0.0.
        return self._summarize(
            strings=strings,
            top_p=0.0,
            temperature=temperature,
            postprocess=postprocess,
            **kwargs,
        )

    @check_type
    def nucleus_decoder(
        self,
        strings: List[str],
        top_p: float = 0.7,
        temperature: float = 0.2,
        postprocess: bool = False,
        **kwargs,
    ):
        """
        Summarize strings using nucleus decoder.

        Parameters
        ----------
        strings: List[str]
        top_p: float, (default=0.7)
            cumulative distribution and cut off as soon as the CDF exceeds `top_p`.
        temperature: float, (default=0.2)
            logits * -log(random.uniform) * temperature.
        postprocess: bool, optional (default=False)
            If True, will filter sentence generated using ROUGE score and removed
            international news publisher.

        Returns
        -------
        result: List[str]
        """
        return self._summarize(
            strings=strings,
            top_p=top_p,
            temperature=temperature,
            postprocess=postprocess,
            **kwargs,
        )