{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Classification HuggingFace" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "This tutorial is available as an IPython notebook at [Malaya/example/zeroshot-classification-huggingface](https://github.com/huseinzol05/Malaya/tree/master/example/zeroshot-classification-huggingface).\n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "This module was trained on both standard and local (including social media) language structures, so it is safe to use for both.\n", " \n", "
" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "os.environ['CUDA_VISIBLE_DEVICES'] = ''\n", "os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import logging\n", "\n", "logging.basicConfig(level=logging.INFO)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 3.1 s, sys: 3.47 s, total: 6.57 s\n", "Wall time: 2.28 s\n" ] } ], "source": [ "%%time\n", "import malaya" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### what is zero-shot classification\n", "\n", "Commonly we supervised a machine learning on specific labels, negative / positive for sentiment, anger / happy / sadness for emotion and etc. The model cannot give an output if we want to know how much percentage of 'jealous' in emotion analysis model because supported labels are only {anger, happy, sadness}. Imagine, for example, trying to identify a text without ever having seen one 'jealous' label before, impossible. **So, zero-shot trying to solve this problem.**\n", "\n", "zero-shot learning refers to the process by which a machine learns how to recognize objects (image, text, any features) without any labeled training data to help in the classification.\n", "\n", "[Yin et al. (2019)](https://arxiv.org/abs/1909.00161) stated in his paper, any pretrained language model finetuned on text similarity actually can acted as an out-of-the-box zero-shot text classifier." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So, we are going to use transformer models from `malaya.similarity.semantic.huggingface` with a little tweaks." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### List available HuggingFace models" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:malaya.similarity.semantic:tested on matched dev set translated MNLI, https://huggingface.co/datasets/mesolitica/translated-MNLI\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
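<table><thead><tr><th></th><th>Size (MB)</th><th>macro precision</th><th>macro recall</th><th>macro f1-score</th></tr></thead><tbody>
<tr><th>mesolitica/finetune-mnli-t5-super-tiny-standard-bahasa-cased</th><td>50.7</td><td>0.74562</td><td>0.74574</td><td>0.74501</td></tr>
<tr><th>mesolitica/finetune-mnli-t5-tiny-standard-bahasa-cased</th><td>139.0</td><td>0.76584</td><td>0.76565</td><td>0.76542</td></tr>
<tr><th>mesolitica/finetune-mnli-t5-small-standard-bahasa-cased</th><td>242.0</td><td>0.78067</td><td>0.78063</td><td>0.78010</td></tr>
<tr><th>mesolitica/finetune-mnli-t5-base-standard-bahasa-cased</th><td>892.0</td><td>0.78903</td><td>0.79064</td><td>0.78918</td></tr>
</tbody></table>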
\n", "
" ], "text/plain": [ " Size (MB) \\\n", "mesolitica/finetune-mnli-t5-super-tiny-standard... 50.7 \n", "mesolitica/finetune-mnli-t5-tiny-standard-bahas... 139.0 \n", "mesolitica/finetune-mnli-t5-small-standard-baha... 242.0 \n", "mesolitica/finetune-mnli-t5-base-standard-bahas... 892.0 \n", "\n", " macro precision \\\n", "mesolitica/finetune-mnli-t5-super-tiny-standard... 0.74562 \n", "mesolitica/finetune-mnli-t5-tiny-standard-bahas... 0.76584 \n", "mesolitica/finetune-mnli-t5-small-standard-baha... 0.78067 \n", "mesolitica/finetune-mnli-t5-base-standard-bahas... 0.78903 \n", "\n", " macro recall \\\n", "mesolitica/finetune-mnli-t5-super-tiny-standard... 0.74574 \n", "mesolitica/finetune-mnli-t5-tiny-standard-bahas... 0.76565 \n", "mesolitica/finetune-mnli-t5-small-standard-baha... 0.78063 \n", "mesolitica/finetune-mnli-t5-base-standard-bahas... 0.79064 \n", "\n", " macro f1-score \n", "mesolitica/finetune-mnli-t5-super-tiny-standard... 0.74501 \n", "mesolitica/finetune-mnli-t5-tiny-standard-bahas... 0.76542 \n", "mesolitica/finetune-mnli-t5-small-standard-baha... 0.78010 \n", "mesolitica/finetune-mnli-t5-base-standard-bahas... 
0.78918 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "malaya.zero_shot.classification.available_huggingface()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load HuggingFace model\n", "\n", "```python\n", "def huggingface(model: str = 'mesolitica/finetune-mnli-t5-small-standard-bahasa-cased', **kwargs):\n", " \"\"\"\n", " Load HuggingFace model to zeroshot text classification.\n", "\n", " Parameters\n", " ----------\n", " model: str, optional (default='mesolitica/finetune-mnli-t5-small-standard-bahasa-cased')\n", " Check available models at `malaya.zero_shot.classification.available_huggingface()`.\n", "\n", " Returns\n", " -------\n", " result: malaya.torch_model.huggingface.ZeroShotClassification\n", " \"\"\"\n", "```" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "model = malaya.zero_shot.classification.huggingface()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Predict batch\n", "\n", "```python\n", "def predict_proba(\n", " self,\n", " strings: List[str],\n", " labels: List[str],\n", " prefix: str = 'ayat ini berkaitan tentang',\n", " multilabel: bool = True,\n", "):\n", " \"\"\"\n", " classify list of strings and return probability.\n", "\n", " Parameters\n", " ----------\n", " strings: List[str]\n", " labels: List[str]\n", " prefix: str, optional (default='ayat ini berkaitan tentang')\n", " prefix of labels to zero shot. Playing around with prefix can get better results.\n", " multilabel: bool, optional (default=True)\n", " probability of labels can be more than 1.0\n", " \"\"\"\n", "```\n", "\n", "Because this is zero-shot classification, we need to supply the candidate labels to the model ourselves." 
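, "\n", "
To build intuition for the `multilabel` flag, here is a minimal sketch of one common way to turn raw per-label scores into probabilities. It assumes independent sigmoid scores for the multilabel case and a softmax across labels otherwise; the raw scores are made up, and this page does not confirm that Malaya computes its probabilities this way.

```python
import math


def sigmoid(value):
    return 1.0 / (1.0 + math.exp(-value))


def softmax(values):
    # Subtract the maximum before exponentiating for numerical stability.
    shift = max(values)
    exps = [math.exp(value - shift) for value in values]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical raw scores for three labels; not real model output.
raw_scores = {'kerajaan': 2.0, 'makanan': -3.0, 'bantuan rakyat': 4.0}

# Multilabel style: every label is judged on its own, so the sum can exceed 1.0.
independent = {label: sigmoid(value) for label, value in raw_scores.items()}

# Multiclass style: labels compete with each other, so the sum is exactly 1.0.
competing = dict(zip(raw_scores, softmax(list(raw_scores.values()))))

print(sum(independent.values()))
print(sum(competing.values()))
```

With independent scoring several labels can be close to 1.0 at once, while competing scoring forces the labels to share a total of 1.0.
"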
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# copied from Twitter\n", "\n", "string = 'gov macam bengong, kami nk pilihan raya, gov backdoor, sakai'" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" ] }, { "data": { "text/plain": [ "[{'najib razak': 0.6651765,\n", " 'mahathir': 0.987833,\n", " 'kerajaan': 0.9912515,\n", " 'PRU': 0.9841426,\n", " 'anarki': 0.45587578}]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.predict_proba([string], labels = ['najib razak', 'mahathir', 'kerajaan', 'PRU', 'anarki'])" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "string = 'tolong order foodpanda jab, lapar'" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'makan': 0.9698464,\n", " 'makanan': 0.9735605,\n", " 'novel': 0.19823082,\n", " 'buku': 0.00313239,\n", " 'kerajaan': 0.12976034,\n", " 'food delivery': 0.99331254}]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.predict_proba([string], labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model understood that `order foodpanda` is closely related to `makan`, `makanan` and `food delivery`." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "string = 'kerajaan sebenarnya sangat prihatin dengan rakyat, bagi duit bantuan'" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "[{'makan': 0.0004689095,\n", " 'makanan': 0.0026079589,\n", " 'novel': 0.29850212,\n", " 'buku': 0.025044106,\n", " 'kerajaan': 0.76523817,\n", " 'food delivery': 0.0044676424,\n", " 'kerajaan jahat': 0.0023713536,\n", " 'kerajaan prihatin': 0.9468328,\n", " 'bantuan rakyat': 0.9923975}]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.predict_proba([string], labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery',\n", " 'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### able to infer for mixed MS and EN" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "string = 'Hi guys! I noticed semalam & harini dah ramai yang dapat cookies ni kan. 
So harini i nak share some post mortem of our first batch:'" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'makan': 0.17769064,\n", " 'makanan': 0.94145703,\n", " 'novel': 0.51651853,\n", " 'buku': 0.21957111,\n", " 'kerajaan': 0.11726684,\n", " 'food delivery': 0.903062,\n", " 'kerajaan jahat': 0.33357194,\n", " 'kerajaan prihatin': 0.14763993,\n", " 'bantuan rakyat': 0.5784646,\n", " 'biskut': 0.8355128,\n", " 'very helpful': 0.39513826,\n", " 'sharing experiences': 0.64116335,\n", " 'sharing session': 0.675511}]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.predict_proba([string], labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery',\n", " 'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat',\n", " 'biskut', 'very helpful', 'sharing experiences',\n", " 'sharing session'])" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'makan': 0.23804268,\n", " 'makanan': 0.94474393,\n", " 'novel': 0.8238379,\n", " 'buku': 0.3343829,\n", " 'kerajaan': 0.092507444,\n", " 'food delivery': 0.94236046,\n", " 'kerajaan jahat': 0.15810412,\n", " 'kerajaan prihatin': 0.13604635,\n", " 'bantuan rakyat': 0.55307525,\n", " 'biskut': 0.92333925,\n", " 'very helpful': 0.39841577,\n", " 'sharing experiences': 0.7563246,\n", " 'sharing session': 0.86674726}]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.predict_proba([string], labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery',\n", " 'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat',\n", " 'biskut', 'very helpful', 'sharing experiences',\n", " 'sharing session'],\n", " prefix = 'teks ini berkaitan tentang')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Multiclass but not multilabel\n", "\n", "To make the probabilities sum to 1.0, set `multilabel=False`." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'makan': 0.0013036507,\n", " 'makanan': 0.0012489067,\n", " 'novel': 0.007235752,\n", " 'buku': 0.0022450346,\n", " 'kerajaan': 0.070251726,\n", " 'food delivery': 0.0042558503,\n", " 'kerajaan jahat': 0.0022728115,\n", " 'kerajaan prihatin': 0.20736308,\n", " 'bantuan rakyat': 0.57145786,\n", " 'biskut': 0.0020565772,\n", " 'very helpful': 0.11333891,\n", " 'sharing experiences': 0.007458821,\n", " 'sharing session': 0.00951114}]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "string = 'kerajaan sebenarnya sangat prihatin dengan rakyat, bagi duit bantuan'\n", "\n", "model.predict_proba([string], labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery',\n", " 'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat',\n", " 'biskut', 'very helpful', 'sharing experiences',\n", " 'sharing session'], multilabel = False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Stacking models\n", "\n", "For more information, read https://malaya.readthedocs.io/en/latest/Stack.html\n", "\n", "If you want to stack zero-shot classification models, you need to pass the labels as a keyword parameter,\n", "\n", "```python\n", "malaya.stack.predict_stack([model1, model2], List[str], labels = List[str])\n", "```\n", "\n", "We will pass `labels` as `**kwargs`." 
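, "\n", "
As a rough picture of what stacking does with per-label probabilities, the sketch below combines the outputs of several models with a geometric mean. The geometric mean is an assumption made for illustration (this page does not state which aggregation `malaya.stack.predict_stack` uses), and `geometric_mean`, `stack_predictions` and the probabilities are hypothetical.

```python
import math


def geometric_mean(values):
    # Assumes every value is strictly positive, since log(0) is undefined.
    return math.exp(sum(math.log(value) for value in values) / len(values))


def stack_predictions(predictions):
    # predictions: one dict of label -> probability per model, all for the same string.
    labels = predictions[0].keys()
    return {
        label: geometric_mean([prediction[label] for prediction in predictions])
        for label in labels
    }


# Made-up outputs of two models for the same string and the same labels.
model_a = {'kerajaan': 0.9, 'makanan': 0.1}
model_b = {'kerajaan': 0.4, 'makanan': 0.1}

stacked = stack_predictions([model_a, model_b])
print(stacked)
```

Under this assumption, stacking several copies of the same model returns that model's own probabilities up to floating point rounding.
"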
] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'makan': 0.00046890916,\n", " 'makanan': 0.0026079628,\n", " 'novel': 0.29850233,\n", " 'buku': 0.02504399,\n", " 'kerajaan': 0.7652382,\n", " 'food delivery': 0.004467653,\n", " 'kerajaan jahat': 0.0023713524,\n", " 'kerajaan prihatin': 0.9468329,\n", " 'bantuan rakyat': 0.99239755,\n", " 'comel': 0.00077307917,\n", " 'kerajaan syg sgt kepada rakyat': 0.9818335}]" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "string = 'kerajaan sebenarnya sangat prihatin dengan rakyat, bagi duit bantuan'\n", "labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery', \n", " 'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat', 'comel', 'kerajaan syg sgt kepada rakyat']\n", "malaya.stack.predict_stack([model, model, model], [string], \n", " labels = labels)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }