{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "This tutorial is available as an IPython notebook at [Malaya/example/zeroshot-classification](https://github.com/huseinzol05/Malaya/tree/master/example/zeroshot-classification).\n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "This module is trained on both standard and local (including social media) language structures, so it is safe to use for both.\n", " \n", "
" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "os.environ['CUDA_VISIBLE_DEVICES'] = ''\n", "os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/husein/.local/lib/python3.8/site-packages/bitsandbytes/cextension.py:34: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.\n", " warn(\"The installed version of bitsandbytes was compiled without GPU support. \"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "/home/husein/.local/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32\n", "CPU times: user 3.27 s, sys: 3.22 s, total: 6.5 s\n", "Wall time: 2.64 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/husein/dev/malaya/malaya/tokenizer.py:214: FutureWarning: Possible nested set at position 3397\n", " self.tok = re.compile(r'({})'.format('|'.join(pipeline)))\n", "/home/husein/dev/malaya/malaya/tokenizer.py:214: FutureWarning: Possible nested set at position 3927\n", " self.tok = re.compile(r'({})'.format('|'.join(pipeline)))\n" ] } ], "source": [ "%%time\n", "import malaya" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### what is zero-shot classification\n", "\n", "Commonly, we train a supervised machine learning model on a fixed set of labels: negative / positive for sentiment, anger / happy / sadness for emotion, and so on. Such a model cannot tell us how strongly a text expresses 'jealous', because its supported labels are only {anger, happy, sadness}. Asking it to recognize 'jealous' without ever having seen a single 'jealous' label is impossible. 
**Zero-shot classification tries to solve this problem.**\n", "\n", "Zero-shot learning refers to the process by which a machine learns how to recognize objects (images, text, or any other features) without any labeled training data to help in the classification.\n", "\n", "[Yin et al. (2019)](https://arxiv.org/abs/1909.00161) state in their paper that any pretrained language model fine-tuned on text similarity can act as an out-of-the-box zero-shot text classifier." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So, we are going to use transformer models from `malaya.similarity.semantic.huggingface` with a few tweaks." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### List available HuggingFace models" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'mesolitica/finetune-mnli-nanot5-small': {'Size (MB)': 148,\n", " 'macro precision': 0.87125,\n", " 'macro recall': 0.87131,\n", " 'macro f1-score': 0.87127},\n", " 'mesolitica/finetune-mnli-nanot5-base': {'Size (MB)': 892,\n", " 'macro precision': 0.78903,\n", " 'macro recall': 0.79064,\n", " 'macro f1-score': 0.78918}}" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "malaya.zero_shot.classification.available_huggingface" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load HuggingFace model\n", "\n", "```python\n", "def huggingface(\n", " model: str = 'mesolitica/finetune-mnli-t5-small-standard-bahasa-cased',\n", " force_check: bool = True,\n", " **kwargs,\n", "):\n", " \"\"\"\n", " Load HuggingFace model to zeroshot text classification.\n", "\n", " Parameters\n", " ----------\n", " model: str, optional (default='mesolitica/finetune-mnli-t5-small-standard-bahasa-cased')\n", " Check available models at `malaya.zero_shot.classification.available_huggingface()`.\n", " force_check: bool, optional (default=True)\n", " Force check model one of malaya model.\n", " Set to False if you 
have your own huggingface model.\n", "\n", " Returns\n", " -------\n", " result: malaya.torch_model.huggingface.ZeroShotClassification\n", " \"\"\"\n", "```" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "model = malaya.zero_shot.classification.huggingface()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### predict batch\n", "\n", "```python\n", "def predict_proba(\n", " self,\n", " strings: List[str],\n", " labels: List[str],\n", " prefix: str = 'ayat ini berkaitan tentang',\n", " multilabel: bool = True,\n", "):\n", " \"\"\"\n", " classify list of strings and return probability.\n", "\n", " Parameters\n", " ----------\n", " strings: List[str]\n", " labels: List[str]\n", " prefix: str, optional (default='ayat ini berkaitan tentang')\n", " prefix of labels to zero shot. Playing around with prefix can get better results.\n", " multilabel: bool, optional (default=True)\n", " probability of labels can be more than 1.0\n", " \"\"\"\n", "```\n", "\n", "Because this is zero-shot classification, we need to supply the candidate labels to the model." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# copied from Twitter\n", "\n", "string = 'gov macam bengong, kami nk pilihan raya, gov backdoor, sakai'" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "You're using a PreTrainedTokenizerFast tokenizer. 
Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" ] }, { "data": { "text/plain": [ "[{'najib razak': 0.089086466,\n", " 'mahathir': 0.8503896,\n", " 'kerajaan': 0.31621307,\n", " 'PRU': 0.5521264,\n", " 'anarki': 0.018142236}]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.predict_proba([string], labels = ['najib razak', 'mahathir', 'kerajaan', 'PRU', 'anarki'])" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "string = 'tolong order foodpanda jab, lapar'" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'makan': 0.9966216,\n", " 'makanan': 0.9912846,\n", " 'novel': 0.01200958,\n", " 'buku': 0.0026836568,\n", " 'kerajaan': 0.005800651,\n", " 'food delivery': 0.94829154}]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.predict_proba([string], labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model understood that `order foodpanda` is closely related to `makan`, `makanan` and `food delivery`." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "string = 'kerajaan sebenarnya sangat prihatin dengan rakyat, bagi duit bantuan'" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "[{'makan': 0.0023917605,\n", " 'makanan': 0.002768525,\n", " 'novel': 0.0035945452,\n", " 'buku': 0.0028883144,\n", " 'kerajaan': 0.9981665,\n", " 'food delivery': 0.0029965744,\n", " 'kerajaan jahat': 0.95778364,\n", " 'kerajaan prihatin': 0.9981933,\n", " 'bantuan rakyat': 0.99804246}]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.predict_proba([string], labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery',\n", " 'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### able to infer for mixed MS and EN" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "string = 'Hi guys! I noticed semalam & harini dah ramai yang dapat cookies ni kan. 
So harini i nak share some post mortem of our first batch:'" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'makan': 0.007691883,\n", " 'makanan': 0.997271,\n", " 'novel': 0.039510652,\n", " 'buku': 0.03565315,\n", " 'kerajaan': 0.0074525476,\n", " 'food delivery': 0.9393526,\n", " 'kerajaan jahat': 0.0053522647,\n", " 'kerajaan prihatin': 0.011083162,\n", " 'bantuan rakyat': 0.060150616,\n", " 'biskut': 0.9302781,\n", " 'very helpful': 0.07355973,\n", " 'sharing experiences': 0.9778896,\n", " 'sharing session': 0.014371477}]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.predict_proba([string], labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery',\n", " 'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat',\n", " 'biskut', 'very helpful', 'sharing experiences',\n", " 'sharing session'])" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'makan': 0.0014243807,\n", " 'makanan': 0.004838416,\n", " 'novel': 0.0019961353,\n", " 'buku': 0.003897282,\n", " 'kerajaan': 0.004189471,\n", " 'food delivery': 0.97480994,\n", " 'kerajaan jahat': 0.0018161167,\n", " 'kerajaan prihatin': 0.0054033417,\n", " 'bantuan rakyat': 0.0054734466,\n", " 'biskut': 0.018219633,\n", " 'very helpful': 0.03659028,\n", " 'sharing experiences': 0.98463523,\n", " 'sharing session': 0.013350475}]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.predict_proba([string], labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery',\n", " 'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat',\n", " 'biskut', 'very helpful', 'sharing experiences',\n", " 'sharing session'],\n", " prefix = 'teks ini berkaitan tentang')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Multiclasses but not multilabel\n", "\n", "Sum of 
the label probabilities must equal 1.0; to get that, set `multilabel=False`." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'makan': 0.00062935066,\n", " 'makanan': 0.00067746383,\n", " 'novel': 0.0007715335,\n", " 'buku': 0.0006922778,\n", " 'kerajaan': 0.2833456,\n", " 'food delivery': 0.0007045073,\n", " 'kerajaan jahat': 0.05875754,\n", " 'kerajaan prihatin': 0.28552753,\n", " 'bantuan rakyat': 0.27457199,\n", " 'biskut': 0.0007160352,\n", " 'very helpful': 0.09099287,\n", " 'sharing experiences': 0.0012673552,\n", " 'sharing session': 0.0013456849}]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "string = 'kerajaan sebenarnya sangat prihatin dengan rakyat, bagi duit bantuan'\n", "\n", "model.predict_proba([string], labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery',\n", " 'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat',\n", " 'biskut', 'very helpful', 'sharing experiences',\n", " 'sharing session'], multilabel = False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Stacking models\n", "\n", "For more information, read https://malaya.readthedocs.io/en/latest/Stack.html\n", "\n", "If you want to stack zero-shot classification models, you need to pass the labels as a keyword argument:\n", "\n", "```python\n", "malaya.stack.predict_stack([model1, model2], List[str], labels = List[str])\n", "```\n", "\n", "`labels` is passed through as `**kwargs`." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'makan': 0.0023917593,\n", " 'makanan': 0.002768525,\n", " 'novel': 0.0035945452,\n", " 'buku': 0.0028883128,\n", " 'kerajaan': 0.9981665,\n", " 'food delivery': 0.0029965725,\n", " 'kerajaan jahat': 0.95778376,\n", " 'kerajaan prihatin': 0.9981934,\n", " 'bantuan rakyat': 0.9980425,\n", " 'comel': 0.0031943405,\n", " 'kerajaan syg sgt kepada rakyat': 0.99586475}]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "string = 'kerajaan sebenarnya sangat prihatin dengan rakyat, bagi duit bantuan'\n", "labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery', \n", " 'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat', 'comel', 'kerajaan syg sgt kepada rakyat']\n", "malaya.stack.predict_stack([model, model, model], [string], \n", " labels = labels)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }