Source code for malaya.model.huggingface
import tensorflow as tf
from herpetologist import check_type
from typing import List
class Generator:
    """
    Pair a model with its tokenizer to generate text from input strings.
    """

    def __init__(self, model, tokenizer, initial_text='', **kwargs):
        """
        Parameters
        ----------
        model: object whose `generate` method is called with the padded batch.
        tokenizer: object whose `encode`, `pad` and `decode` methods are used.
        initial_text: str, optional (default='')
            Prefix prepended to every input string before it is encoded.
        **kwargs: accepted but not used by this class.
        """
        self._model = model
        self._tokenizer = tokenizer
        self._initial_text = initial_text

    @check_type
    def generate(self, strings: List[str], **kwargs):
        """
        Generate texts from the input.

        Parameters
        ----------
        strings : List[str]
        **kwargs: keyword arguments passed to huggingface `generate` method.

        Returns
        -------
        result: List[str]
        """
        # Nothing to encode or pad for an empty batch; return early instead
        # of handing an empty list to the tokenizer and the model.
        if not strings:
            return []

        # Encode each prefixed string on its own; `encode` returns a batch of
        # one, so `[0]` keeps only that single sequence.
        encoded = [
            {
                'input_ids': self._tokenizer.encode(
                    f'{self._initial_text}{s}', return_tensors='tf'
                )[0]
            }
            for s in strings
        ]
        # Pad every sequence to the longest one so they form a single batch.
        padded = self._tokenizer.pad(encoded, padding='longest')
        outputs = self._model.generate(**padded, **kwargs)
        return [
            self._tokenizer.decode(o, skip_special_tokens=True)
            for o in outputs
        ]