# Source code for malaya.augmentation
import random
import json
import inspect
import numpy as np
import string as string_function
from collections import defaultdict
from malaya.function import check_file
from tensorflow.keras.preprocessing.sequence import pad_sequences
from malaya.text.tatabahasa import consonants, vowels
from malaya.text.function import augmentation_textcleaning, simple_textcleaning
from malaya.path import PATH_AUGMENTATION, S3_PATH_AUGMENTATION
from herpetologist import check_type
from typing import Callable
_synonym_dict = None

def to_ids(string, tokenizer):
    words = []
    for no, word in enumerate(string):
        if word == '[MASK]':
            words.append(word)
        else:
            words.extend(tokenizer.tokenize(word))
    masked_tokens = ['[CLS]'] + words + ['[SEP]']
    masked_ids = tokenizer.convert_tokens_to_ids(masked_tokens)
    return masked_ids, masked_ids.index(tokenizer.vocab['[MASK]'])
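
# `to_ids` converts a token list containing a single '[MASK]' placeholder into
# WordPiece ids and reports where the mask landed. A minimal sketch, assuming
# a BERT-style tokenizer exposing `tokenize`, `convert_tokens_to_ids` and
# `vocab`:
#
#   masked_ids, mask_index = to_ids(['saya', '[MASK]', 'makan'], tokenizer)
#   # masked_ids starts with the '[CLS]' id and ends with the '[SEP]' id;
#   # masked_ids[mask_index] is the '[MASK]' id.
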
def replace(string, threshold):
    # work on a copy so repeated calls do not compound replacements in place
    string = string[:]
    for no, word in enumerate(string):
        if word in _synonym_dict and random.random() > threshold:
            w = random.choice(_synonym_dict[word])
            string[no] = w
    return string
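
# With threshold = 0.0, every word with an entry in `_synonym_dict` is swapped
# (random.random() > 0.0 is effectively always true). For example, given
# _synonym_dict == {'suka': ['gemar']}:
#
#   replace(['saya', 'suka', 'makan'], 0.0)  # -> ['saya', 'gemar', 'makan']
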
def _make_upper(p, o):
    # recase each word of `p` using the casing of the corresponding word in `o`
    p_split = p.split()
    o_split = o.split()
    return ' '.join(
        [
            s.title() if no < len(o_split) and o_split[no][0].isupper() else s
            for no, s in enumerate(p_split)
        ]
    )
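
# _make_upper('saya suka', 'Saya suka') -> 'Saya suka': casing is copied from
# the original string word by word, so both inputs should split into the same
# number of words.
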
@check_type
def synonym(
    string: str,
    threshold: float = 0.5,
    top_n: int = 5,
    cleaning=augmentation_textcleaning,
    **kwargs
):
    """
    Augment a string using synonyms, https://github.com/huseinzol05/Malaya-Dataset#90k-synonym

    Parameters
    ----------
    string: str
    threshold: float, optional (default=0.5)
        random selection for a word.
    top_n: int, optional (default=5)
        number of augmented strings returned. Length of returned result should be equal to top_n.
    cleaning: function, optional (default=malaya.text.function.augmentation_textcleaning)
        function to clean text.

    Returns
    -------
    result: List[str]
    """
    if not isinstance(cleaning, Callable) and cleaning is not None:
        raise ValueError('cleaning must be a callable type or None')

    global _synonym_dict
    if _synonym_dict is None:
        path = check_file(
            PATH_AUGMENTATION['synonym'],
            S3_PATH_AUGMENTATION['synonym'],
            **kwargs
        )
        files = list(path.values())
        synonyms = defaultdict(list)
        for file in files:
            with open(file) as fopen:
                data = json.load(fopen)
            for i in data:
                if not len(i[1]):
                    continue
                synonyms[i[0]].extend(i[1])
                for r in i[1]:
                    synonyms[r].append(i[0])
        for k, v in synonyms.items():
            synonyms[k] = list(set(v))
        _synonym_dict = synonyms

    original_string = string
    if cleaning:
        string = cleaning(string)
    string = string.split()
    augmented = []
    for i in range(top_n):
        string_ = replace(string, threshold)
        augmented.append(_make_upper(' '.join(string_), original_string))
    return augmented
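
# Usage sketch (the synonym dictionary downloads on first call via
# `check_file`; the sentence and outputs here are illustrative only):
#
#   import malaya
#   malaya.augmentation.synonym('bapa sedang membaca surat khabar')
#   # -> a list of 5 variants with random words swapped for dictionary synonyms
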
@check_type
def wordvector(
    string: str,
    wordvector,
    threshold: float = 0.5,
    top_n: int = 5,
    soft: bool = False,
    cleaning=augmentation_textcleaning,
):
    """
    Augment a string using a wordvector.

    Parameters
    ----------
    string: str
    wordvector: object
        wordvector interface object.
    threshold: float, optional (default=0.5)
        random selection for a word.
    top_n: int, optional (default=5)
        number of nearest neighbors returned. Length of returned result should be equal to top_n.
    soft: bool, optional (default=False)
        if True, a word not in the dictionary will be replaced with the nearest word by Jaro-Winkler ratio.
        if False, it will throw an exception if a word is not in the dictionary.
    cleaning: function, optional (default=malaya.text.function.augmentation_textcleaning)
        function to clean text.

    Returns
    -------
    result: List[str]
    """
    if not isinstance(cleaning, Callable) and cleaning is not None:
        raise ValueError('cleaning must be a callable type or None')
    if not hasattr(wordvector, 'batch_n_closest'):
        raise ValueError('wordvector must have `batch_n_closest` method')
    if not hasattr(wordvector, '_dictionary'):
        raise ValueError('wordvector must have `_dictionary` attribute')

    from malaya.preprocessing import _tokenizer

    if cleaning:
        string = cleaning(string)
    string = _tokenizer(string)
    original_string = string[:]
    selected = []
    for no, w in enumerate(string):
        if w in string_function.punctuation:
            continue
        if w[0].isupper():
            continue
        if random.random() > threshold:
            selected.append((no, w))

    if not len(selected):
        raise ValueError(
            'no words can be augmented, make sure words available are not punctuation or proper nouns.'
        )

    indices, words = [i[0] for i in selected], [i[1] for i in selected]
    batch_parameters = list(
        inspect.signature(wordvector.batch_n_closest).parameters.keys()
    )
    if 'soft' in batch_parameters:
        results = wordvector.batch_n_closest(
            words, num_closest=top_n, soft=soft
        )
    else:
        results = wordvector.batch_n_closest(words, num_closest=top_n)

    augmented = []
    for i in range(top_n):
        string_ = string[:]
        for no in range(len(results)):
            string_[indices[no]] = results[no][i]
        augmented.append(
            _make_upper(' '.join(string_), ' '.join(original_string))
        )
    return augmented
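
# Usage sketch. The wordvector loader names below follow the Malaya docs but
# should be treated as assumptions for this version:
#
#   vocab, embedded = malaya.wordvector.load_wiki()
#   wv = malaya.wordvector.load(embedded, vocab)
#   malaya.augmentation.wordvector('saya suka makan ayam', wv, soft=True)
#   # -> 5 variants where each selected word is replaced by one of its top-5
#   #    nearest neighbours in the embedding space
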
@check_type
def transformer(
    string: str,
    model,
    threshold: float = 0.5,
    top_p: float = 0.9,
    top_k: int = 100,
    temperature: float = 1.0,
    top_n: int = 5,
    cleaning=None,
):
    """
    Augment a string using transformer + nucleus sampling / top-k sampling.

    Parameters
    ----------
    string: str
    model: object
        transformer interface object. Right now it only supports BERT, ALBERT and ELECTRA.
    threshold: float, optional (default=0.5)
        random selection for a word.
    top_p: float, optional (default=0.9)
        cumulative sum of probabilities to sample a word.
        If top_p is bigger than 0, the model will use nucleus sampling, else top-k sampling.
    top_k: int, optional (default=100)
        k for top-k sampling.
    temperature: float, optional (default=1.0)
        logits * temperature.
    top_n: int, optional (default=5)
        number of samples per masked word. Length of returned result should be equal to top_n.
    cleaning: function, optional (default=None)
        function to clean text.

    Returns
    -------
    result: List[str]
    """
    if not isinstance(cleaning, Callable) and cleaning is not None:
        raise ValueError('cleaning must be a callable type or None')
    if not hasattr(model, 'samples'):
        raise ValueError('model must have `samples` attribute')
    if not (threshold > 0 and threshold < 1):
        raise ValueError('threshold must be bigger than 0 and less than 1')
    if not top_p > 0:
        raise ValueError('top_p must be bigger than 0')
    if not top_k > 0:
        raise ValueError('top_k must be bigger than 0')
    if not 0 < temperature <= 1.0:
        raise ValueError('temperature must satisfy 0 < temperature <= 1.0')
    if not top_n > 0:
        raise ValueError('top_n must be bigger than 0')
    if top_n > top_k:
        raise ValueError('top_k must be bigger than top_n')

    from malaya.preprocessing import _tokenizer

    if cleaning:
        string = cleaning(string)
    string = _tokenizer(string)
    results = []
    for token_idx, token in enumerate(string):
        if token in string_function.punctuation:
            continue
        if token[0].isupper():
            continue
        if token.isdigit():
            continue
        if random.random() > threshold:
            results.append(token_idx)

    if not len(results):
        raise ValueError(
            'no words can be augmented, make sure words available are not punctuation or proper nouns.'
        )

    maskeds, indices, input_masks, input_segments = [], [], [], []
    for index in results:
        new = string[:]
        new[index] = '[MASK]'
        mask, ind = to_ids(new, model._tokenizer)
        maskeds.append(mask)
        indices.append(ind)
        input_masks.append([1] * len(mask))
        input_segments.append([0] * len(mask))

    masked_padded = pad_sequences(maskeds, padding='post')
    input_masks = pad_sequences(input_masks, padding='post')
    input_segments = pad_sequences(input_segments, padding='post')
    batch_indices = np.array([np.arange(len(indices)), indices]).T
    samples = model._sess.run(
        model.samples,
        feed_dict={
            model.X: masked_padded,
            model.MASK: input_masks,
            model.top_p: top_p,
            model.top_k: top_k,
            model.temperature: temperature,
            model.indices: batch_indices,
            model.k: top_n,
            model.segment_ids: input_segments,
        },
    )
    outputs = []
    for i in range(samples.shape[1]):
        sample_i = samples[:, i]
        samples_tokens = model._tokenizer.convert_ids_to_tokens(
            sample_i.tolist()
        )
        if hasattr(model._tokenizer, 'sp_model'):
            new_splitted = ['▁' + w if len(w) > 1 else w for w in string]
        else:
            new_splitted = string[:]
        for no, index in enumerate(results):
            new_splitted[index] = samples_tokens[no]
        if hasattr(model._tokenizer, 'sp_model'):
            new = ''.join(model._tokenizer.sp_model.DecodePieces(new_splitted))
        else:
            new = ' '.join(new_splitted)
        outputs.append(new)
    return outputs
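
# Usage sketch. The transformer loader call is an assumption based on the
# Malaya transformer API of this version:
#
#   model = malaya.transformer.load(model='electra')
#   malaya.augmentation.transformer('saya suka makan ayam', model)
#   # -> 5 sentences where each selected word is replaced by a token sampled
#   #    from the masked-LM distribution at that position
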
def _replace(word, replace_dict, threshold=0.5):
    word = list(word)
    for i in range(len(word)):
        if word[i] in replace_dict and random.random() >= threshold:
            word[i] = replace_dict[word[i]]
    return ''.join(word)
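
# With threshold = 0.0 the swap is deterministic, because
# random.random() >= 0.0 always holds:
#
#   _replace('makan', {'n': 'm'}, threshold=0.0)  # -> 'makam'
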
def replace_similar_consonants(word: str, threshold: float = 0.8):
    """
    Naively replace consonants with similar consonants in a word.

    Parameters
    ----------
    word: str
    threshold: float, optional (default=0.8)

    Returns
    -------
    result: str
    """
    replace_consonants = {
        'n': 'm',
        't': 'y',
        'r': 't',
        'g': 'f',
        'j': 'k',
        'k': 'l',
        'd': 'f',
        'b': 'n',
    }
    return _replace(word=word, replace_dict=replace_consonants, threshold=threshold)
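
# e.g. replace_similar_consonants('rumah', threshold=0.0) -> 'tumah':
# 'r' maps to 't', while 'u', 'm', 'a' and 'h' have no mapping and survive.
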
def replace_similar_vowels(word: str, threshold: float = 0.8):
    """
    Naively replace vowels with similar vowels in a word.

    Parameters
    ----------
    word: str
    threshold: float, optional (default=0.8)

    Returns
    -------
    result: str
    """
    replace_vowels = {'u': 'i', 'i': 'o', 'o': 'u'}
    return _replace(word=word, replace_dict=replace_vowels, threshold=threshold)
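
# e.g. replace_similar_vowels('ubi', threshold=0.0) -> 'ibo'
# ('u' -> 'i' and 'i' -> 'o'; deterministic since random.random() >= 0.0).
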
@check_type
def socialmedia_form(word: str):
    """
    Augment a word into social media form.

    Parameters
    ----------
    word: str

    Returns
    -------
    result: List[str]
    """
    word = simple_textcleaning(word)
    if not len(word):
        raise ValueError('word is too short to augment shortform.')

    results = []

    if len(word) > 1:
        if word[-1] == 'a' and word[-2] in consonants:
            results.append(word[:-1] + 'e')
        if word[0] == 'f' and word[-1] == 'r':
            results.append('p' + word[1:])
        if word[-2] in consonants and word[-1] in vowels:
            results.append(word + 'k')
        if word[-2] in vowels and word[-1] == 'h':
            results.append(word[:-1])

    if len(word) > 2:
        if word[-3] in consonants and word[-2:] == 'ar':
            results.append(word[:-2] + 'o')
        if word[0] == 'h' and word[1] in vowels and word[2] in consonants:
            results.append(word[1:])
        if word[-3] in consonants and word[-2:] == 'ng':
            results.append(word[:-2] + 'g')
        if word[1:3] == 'ng':
            results.append(word[:1] + word[2:])

    return list(set(results))
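
# e.g. socialmedia_form('kenapa') returns ['kenapak', 'kenape'] in arbitrary
# order: the trailing 'a' after a consonant becomes 'e', and a
# consonant + vowel ending gains a trailing 'k'.
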
def vowel_alternate(word: str, threshold: float = 0.5):
    """
    Augment a word into a vowel-alternate (vowel-dropped) form.

    vowel_alternate('singapore') -> 'sngpore'
    vowel_alternate('kampung') -> 'kmpng'
    vowel_alternate('ayam') -> 'aym'

    Parameters
    ----------
    word: str
    threshold: float, optional (default=0.5)

    Returns
    -------
    result: str
    """
    word = simple_textcleaning(word)
    if not len(word):
        raise ValueError('word is too short to augment shortform.')

    word = list(word)
    i = 0
    while i < len(word) - 2:
        subword = word[i: i + 3]
        if (
            subword[0] in consonants
            and subword[1] in vowels
            and subword[2] in consonants
            and random.random() >= threshold
        ):
            word.pop(i + 1)
        i += 1
    return ''.join(word)
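
# With threshold = 0.0 every consonant-vowel-consonant window drops its vowel,
# making the result deterministic:
#
#   vowel_alternate('kampung', threshold=0.0)  # -> 'kmpng'
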