Source code for malaya.torch_model.llm

import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict
import logging

logger = logging.getLogger(__name__)

# `TextIteratorStreamer` is an optional dependency: when it cannot be imported
# the name is bound to None so the rest of the module still loads.
try:
    from transformers import TextIteratorStreamer
except Exception:
    # Catch `Exception` rather than `BaseException` so that KeyboardInterrupt
    # and SystemExit raised while importing are not silently swallowed.
    TextIteratorStreamer = None


class LLM:
    """
    Chat-style wrapper around a causal language model.

    The tokenizer is always loaded with `AutoTokenizer`. The model is loaded
    either with `AutoModelForCausalLM` or, when `use_ctranslate2` is True,
    with `malaya_boilerplate.converter.ctranslate2_generator`.
    """

    def __init__(
        self,
        model,
        use_ctranslate2=False,
        **kwargs,
    ):
        """
        Parameters
        ----------
        model: str
            Model name or path; also used to pick the prompt format
            (must contain `llama2` or `mistral`).
        use_ctranslate2: bool, optional (default=False)
            If True, load the model through `ctranslate2_generator`.
        **kwargs: vector arguments forwarded to both the tokenizer and the model loader.
        """
        self.model_name = model
        self.tokenizer = AutoTokenizer.from_pretrained(model, **kwargs)
        self.use_ctranslate2 = use_ctranslate2

        if self.use_ctranslate2:
            # Imported lazily so the dependency is only needed on this path.
            from malaya_boilerplate.converter import ctranslate2_generator

            self.model = ctranslate2_generator(model, **kwargs)
        else:
            self.model = AutoModelForCausalLM.from_pretrained(model, **kwargs)

    def _build_prompt(self, query, keys, roles):
        """
        Validate `query` and render it as a Llama2-style chat prompt string.

        Raises
        ------
        ValueError
            If `query` is too short, a message misses a required key, a role
            is not in `roles`, the first message is not `system`, the last
            message is not `user`, or the messages in between do not strictly
            alternate user / assistant.
        """
        if len(query) < 2:
            raise ValueError('`query` must have length of at least 2.')

        for i, message in enumerate(query):
            for k in keys:
                if k not in message:
                    raise ValueError(f'{i} message does not have `{k}` key.')
            if message['role'] not in roles:
                raise ValueError(f'Only accept `{roles}` for {i} message role.')

        if query[0]['role'] != 'system':
            raise ValueError('first message of `query` must `system` role.')
        if query[-1]['role'] != 'user':
            raise ValueError('last message of `query` must `user` role.')

        system = query[0]['content']
        user_query = query[-1]['content']

        # Everything between the system message and the final user message
        # must be complete (user, assistant) turns in that order; comparing
        # counts alone would pair out-of-order messages incorrectly.
        history = query[1:-1]
        user_turns = history[0::2]
        assistant_turns = history[1::2]
        alternating = (
            len(history) % 2 == 0
            and all(m['role'] == 'user' for m in user_turns)
            and all(m['role'] == 'assistant' for m in assistant_turns)
        )
        if not alternating:
            raise ValueError(
                'model only support `system, `user` and `assistant` roles, starting with `system`, then `user` and alternating (u/a/u/a/u...)')

        texts = [f'<s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n']
        for u, a in zip(user_turns, assistant_turns):
            texts.append(f'{u["content"].strip()} [/INST] {a["content"].strip()} </s><s>[INST] ')
        texts.append(f'{user_query.strip()} [/INST]')
        return ''.join(texts)

    def _encode_prompt(self, prompt):
        """
        Convert a prompt string into the input expected by the loaded model:
        a list of tokens for ctranslate2, otherwise an `input_ids` tensor
        moved to the model device.
        """
        logger.debug('prompt: %s', prompt)

        if self.use_ctranslate2:
            return self.tokenizer.tokenize(prompt)

        # Special tokens are already written explicitly in the prompt.
        inputs = self.tokenizer(prompt, add_special_tokens=False, return_tensors='pt')
        return inputs['input_ids'].to(self.model.device)

    def _get_input_llama(
        self,
        query,
        keys={'role', 'content'},
        roles={'system', 'user', 'assistant'},
    ):
        """
        We follow exactly Llama2 chatbot tokens.
        https://github.com/facebookresearch/llama/blob/main/llama/generation.py
        """
        return self._encode_prompt(self._build_prompt(query, keys, roles))

    def _get_input_mistral(
        self,
        query,
        keys={'role', 'content'},
        roles={'system', 'user', 'assistant'},
    ):
        """
        Build model input for Mistral based models using the same
        Llama2-style template as `_get_input_llama`.

        `system` is part of the accepted roles because the template requires
        the first message to be a `system` message; without it every call
        was rejected.
        """
        # NOTE(review): this reuses the Llama2 `<<SYS>>` template; confirm it
        # matches the chat template the Mistral checkpoints were tuned with.
        return self._encode_prompt(self._build_prompt(query, keys, roles))

    def _get_input(self, query):
        """
        Dispatch to the prompt builder matching `self.model_name`.

        Raises
        ------
        ValueError
            If the model name contains neither `llama2` nor `mistral`.
        """
        if 'llama2' in self.model_name:
            return self._get_input_llama(query=query)
        elif 'mistral' in self.model_name:
            return self._get_input_mistral(query=query)
        else:
            raise ValueError('Currently we only support Llama2 and Mistral based.')

    def generate(
        self,
        query: List[Dict[str, str]],
        **kwargs,
    ):
        """
        Generate respond from user input.

        Parameters
        ----------
        query: List[Dict[str, str]]
            [
                {
                    'role': 'system',
                    'content': 'anda adalah AI yang berguna',
                },
                {
                    'role': 'user',
                    'content': 'makanan apa yang sedap?',
                }
            ]
        **kwargs: vector arguments pass to huggingface `generate` method.
            Read more at https://huggingface.co/docs/transformers/main_classes/text_generation

            If you are using `use_ctranslate2`, vector arguments pass to ctranslate2 `generate_batch` method.
            Read more at https://opennmt.net/CTranslate2/python/ctranslate2.Generator.html

        Returns
        -------
        result: str
        """
        input_ids = self._get_input(query=query)

        if self.use_ctranslate2:
            results = self.model.generate_batch([input_ids], **kwargs)
            # A batch call returns one result per prompt; accept a bare
            # result object as well in case the generator wrapper unwraps it.
            first = results[0] if isinstance(results, (list, tuple)) else results
            tokens = first.sequences[0]
            return self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(tokens))

        with torch.no_grad():
            generation_output = self.model.generate(
                input_ids=input_ids,
                return_dict_in_generate=False,
                output_scores=False,
                **kwargs,
            )
        s = generation_output[0]
        return self.tokenizer.decode(s, skip_special_tokens=True)

    def generate_stream(
        self,
        query: List[Dict[str, str]],
        **kwargs,
    ):
        """
        Generate respond from user input in streaming mode.

        Parameters
        ----------
        query: List[Dict[str, str]]
            Same message format as `generate`.
        **kwargs: vector arguments pass to huggingface `generate` method.
            Read more at https://huggingface.co/docs/transformers/main_classes/text_generation

            If you are using `use_ctranslate2`, vector arguments pass to ctranslate2 `generate_tokens` method.

        Yields
        ------
        result: str
            The text accumulated so far, re-yielded each time it grows.

        Raises
        ------
        ModuleNotFoundError
            If `transformers.TextIteratorStreamer` is unavailable and
            `use_ctranslate2` is False.
        """
        input_ids = self._get_input(query=query)

        if self.use_ctranslate2:
            step_results = self.model.generate_tokens(
                input_ids,
                **kwargs
            )

            tokens_buffer = []
            text_output = ''

            for step_result in step_results:
                # A leading '▁' marks the start of a new word, so the ids
                # buffered so far form a complete word and can be decoded.
                is_new_word = step_result.token.startswith('▁')

                if is_new_word and tokens_buffer:
                    word = self.tokenizer.decode(tokens_buffer)
                    if word:
                        text_output += ' ' + word
                        yield text_output
                    tokens_buffer = []

                tokens_buffer.append(step_result.token_id)

            # Flush the last, still-buffered word.
            if tokens_buffer:
                word = self.tokenizer.decode(tokens_buffer)
                if word:
                    text_output += ' ' + word
                    yield text_output

        else:
            if TextIteratorStreamer is None:
                raise ModuleNotFoundError(
                    '`transformers.TextIteratorStreamer` is not available, please install latest transformers library.'
                )

            streamer = TextIteratorStreamer(
                self.tokenizer,
                timeout=10.,
                skip_prompt=True,
                skip_special_tokens=True,
            )
            generate_kwargs = dict(
                {'input_ids': input_ids},
                streamer=streamer,
                **kwargs,
            )
            # Generation runs in a worker thread while this generator drains
            # the streamer.
            t = Thread(target=self.model.generate, kwargs=generate_kwargs)
            t.start()

            outputs = []
            for text in streamer:
                outputs.append(text)
                yield ''.join(outputs)

            # The streamer is exhausted once generation ends; reap the worker.
            t.join()