{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['id', 'gender', 'masterCategory', 'subCategory', 'articleType', 'baseColour', 'season', 'year', 'usage', 'productDisplayName', 'image'],\n",
       "    num_rows: 44072\n",
       "})"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load dataset\n",
    "import numpy\n",
    "from datasets import load_from_disk\n",
    "\n",
    "# Path to the local copy of the dataset; edit this to point at your own copy.\n",
    "DATASET_PATH = \"../fashion-product-images-small/dataset/\"\n",
    "\n",
    "# Alternatively, load the dataset from the huggingface datasets hub:\n",
    "#     from datasets import load_dataset\n",
    "#     dataset = load_dataset(\"ashraq/fashion-product-images-small\", split=\"train\")\n",
    "\n",
    "# Load from local path\n",
    "dataset = load_from_disk(DATASET_PATH)\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
gendermasterCategorysubCategoryarticleTypebaseColourseasonyearusage
0MenApparelTopwearShirtsNavy BlueFall2011.0Casual
1MenApparelBottomwearJeansBlueSummer2012.0Casual
2WomenAccessoriesWatchesWatchesSilverWinter2016.0Casual
3MenApparelBottomwearTrack PantsBlackFall2011.0Casual
4MenApparelTopwearTshirtsGreySummer2012.0Casual
\n", "
"
      ],
      "text/plain": [
       " gender masterCategory subCategory articleType baseColour season year \n",
       "0 Men Apparel Topwear Shirts Navy Blue Fall 2011.0 \\\n",
       "1 Men Apparel Bottomwear Jeans Blue Summer 2012.0 \n",
       "2 Women Accessories Watches Watches Silver Winter 2016.0 \n",
       "3 Men Apparel Bottomwear Track Pants Black Fall 2011.0 \n",
       "4 Men Apparel Topwear Tshirts Grey Summer 2012.0 \n",
       "\n",
       " usage \n",
       "0 Casual \n",
       "1 Casual \n",
       "2 Casual \n",
       "3 Casual \n",
       "4 Casual "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# extract ids, texts, images\n",
    "ids = dataset[\"id\"]\n",
    "texts = dataset[\"productDisplayName\"]\n",
    "images = dataset[\"image\"]\n",
    "# map each product id to its row position so search hits can be looked up in `images`\n",
    "pos = numpy.arange(len(ids))\n",
    "id_pos = dict(zip(ids, pos))\n",
    "# convert the remaining metadata into a pandas dataframe\n",
    "# NOTE: this rebinds `dataset` (the id column is dropped), so re-run the load cell before re-running this one\n",
    "dataset = dataset.remove_columns([\"id\", \"productDisplayName\", \"image\"]).to_pandas()\n",
    "dataset.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'lexical_term_count': '8296',\n",
       " 'lexical_record_count': '44072',\n",
       " 'lexical_algorithm': 'bm25',\n",
       " 'ef_construct': '100',\n",
       " 'current_record_count': '44072',\n",
       " 'delete_record_count': '0',\n",
       " 'distance_method': 'COSINE',\n",
       " 'M': '16',\n",
       " 'algorithm': 'HNSW',\n",
       " 'data_type': 'FLOAT32',\n",
       " 'attribute_data_size': '32229974',\n",
       " 'index_data_size': '155131952',\n",
       " 'dimension': '512',\n",
       " 'data_count': '44072'}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Configure environment\n",
    "\n",
    "# Load a CLIP model from huggingface\n",
    "from sentence_transformers import SentenceTransformer\n",
    "import torch, os\n",
    "model = SentenceTransformer(\n",
    "    'sentence-transformers/clip-ViT-B-32',\n",
    "    device='cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    )\n",
    "\n",
    "# Connect to Tair. Export TAIR_VECTOR_ENDPOINT (a redis://user:passwd@host:port URL)\n",
    "# before starting the kernel; the default below is only a masked placeholder.\n",
    "from tair import TairCluster as Tair\n",
    "tair_vector_endpoint = os.getenv(\"TAIR_VECTOR_ENDPOINT\", \"redis://user:passwd@r-bp1cu1a7hatj2c****.redis.rds.aliyuncs.com:6379\")\n",
    "# vector size is taken from the model's embedding of one sample image\n",
    "dimension = len(model.encode([images[0]])[0])\n",
    "index_name = \"hybrid_index_3\"\n",
    "distance_type = \"cosine\"\n",
    "index_type = \"HNSW\"\n",
    "\n",
    "client = Tair.from_url(tair_vector_endpoint)\n",
    "# create the hybrid (vector + BM25) index only if it does not exist yet\n",
    "if client.tvs_get_index(index_name) is None:\n",
    "    kwargs = {\"lexical_algorithm\":\"bm25\", \"hybrid_ratio\":0.5}\n",
    "    client.tvs_create_index(index_name, dimension, distance_type, index_type, **kwargs)\n",
    "client.tvs_get_index(index_name)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'lexical_term_count': '8296',\n",
       " 'lexical_record_count': '44072',\n",
       " 'lexical_algorithm': 'bm25',\n",
       " 'ef_construct': '100',\n",
       " 'current_record_count': '44072',\n",
       " 'delete_record_count': '0',\n",
       " 'distance_method': 'COSINE',\n",
       " 'M': '16',\n",
       " 'algorithm': 'HNSW',\n",
       " 'data_type': 'FLOAT32',\n",
       " 'attribute_data_size': '32229974',\n",
       " 'index_data_size': '155131952',\n",
       " 'dimension': '512',\n",
       " 'data_count': '44072'}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Upsert data to Tair\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "batch_size = 200\n",
    "\n",
    "# Skip the slow upload when the index already holds one record per dataset row\n",
    "# (tvs_get_index reports data_count as a string).\n",
    "if client.tvs_get_index(index_name)[\"data_count\"] != str(len(dataset)):\n",
    "    for i in tqdm(range(0, len(dataset), batch_size)):\n",
    "        i_end = min(i+batch_size, len(dataset))\n",
    "\n",
    "        id_batch = ids[i:i_end]\n",
    "        attr_batch = dataset.iloc[i:i_end].to_dict(orient=\"records\")\n",
    "        vector_batch = model.encode(images[i:i_end]).tolist()\n",
    "        text_batch = texts[i:i_end]\n",
    "\n",
    "        # upload the documents to tair hybrid index; the product name goes in the TEXT attribute\n",
    "        for key, vector, text, attr in zip(id_batch, vector_batch, text_batch, attr_batch):\n",
    "            attr[\"TEXT\"] = text\n",
    "            client.tvs_hset(index_name, str(key), vector, False, **attr)\n",
    "\n",
    "# show index description after uploading the documents\n",
    "client.tvs_get_index(index_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from IPython.display import HTML\n",
    "from io import BytesIO\n",
    "from base64 import b64encode\n",
    "\n",
    "# function to display product images\n",
    "def display_result(image_batch):\n",
    "    \"\"\"Return an IPython HTML grid showing each image in `image_batch` as an inline PNG thumbnail.\n",
    "\n",
    "    Each image is serialized with `img.save(buffer, format='png')` and embedded as a base64\n",
    "    data URI, so the items are presumably PIL images (as stored in the dataset's image column).\n",
    "    \"\"\"\n",
    "    figures = []\n",
    "    for img in image_batch:\n",
    "        b = BytesIO()\n",
    "        img.save(b, format='png')\n",
    "        figures.append(f'''\n",
    "            <figure style=\"margin: 5px !important;\">\n",
    "              <img src=\"data:image/png;base64,{b64encode(b.getvalue()).decode('utf-8')}\" style=\"width: 90px; height: 120px\" >\n",
    "            </figure>\n",
    "        ''')\n",
    "    return HTML(data=f'''\n",
    "        <div style=\"display: flex; flex-flow: row wrap; text-align: center;\">\n",
    "        {''.join(figures)}\n",
    "        </div>\n",
    "    ''')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(b'33506', 61.000003814697266), (b'27229', 62.0), (b'24615', 62.999996185302734), (b'17617', 64.0), (b'12830', 65.0), (b'4191', 65.99273681640625), (b'2313', 67.0), (b'33273', 68.0), (b'39327', 69.0), (b'38942', 70.0), (b'38943', 71.0), (b'33138', 72.0), (b'11260', 73.0), (b'4989', 73.9908676147461), (b'12850', 75.0), (b'41747', 76.0), (b'39845', 77.0), (b'36213', 78.0), (b'39846', 79.0), (b'24945', 80.0)]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " "
      ],
      "text/plain": [
       ""
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hybrid (vector + BM25 text) search for the query, with hybrid_ratio=0.9999\n",
    "text = \"Green Kidswear\"\n",
    "topk = 20\n",
    "filter_str = None\n",
    "vector = model.encode([text])[0]\n",
    "kwargs = dict(TEXT=text, hybrid_ratio=0.9999)\n",
    "result = client.tvs_knnsearch(index_name, topk, vector, False, filter_str, **kwargs)\n",
    "# hits are (id, score) pairs with the id as bytes; map each id back to its image row\n",
    "top_img = [images[id_pos[int(hit[0])]] for hit in result]\n",
    "print(result)\n",
    "display_result(top_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(b'4322', 1.0), (b'4904', 2.0), (b'3806', 3.0), (b'4907', 4.0), (b'8334', 5.0), (b'4989', 5.549999713897705), (b'3818', 6.999999523162842), (b'8341', 8.0), (b'8335', 9.0), (b'4927', 10.0), (b'4191', 10.860758781433105), (b'4209', 11.0), (b'4898', 12.0), (b'4305', 13.999999046325684), (b'4925', 14.999999046325684), (b'4311', 16.0), (b'4972', 17.0), (b'4192', 18.0), (b'5000', 19.0), (b'4899', 20.0)]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " "
      ],
      "text/plain": [
       ""
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same query at the opposite extreme of the blend, hybrid_ratio=0.0001\n",
    "text = \"Green Kidswear\"\n",
    "topk = 20\n",
    "filter_str = None\n",
    "vector = model.encode([text])[0]\n",
    "kwargs = dict(TEXT=text, hybrid_ratio=0.0001)\n",
    "result = client.tvs_knnsearch(index_name, topk, vector, False, filter_str, **kwargs)\n",
    "top_img = [images[id_pos[int(hit[0])]] for hit in result]\n",
    "print(result)\n",
    "display_result(top_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " "
      ],
      "text/plain": [
       ""
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same query with hybrid_ratio=0.5, the value the index was created with\n",
    "text = \"Green Kidswear\"\n",
    "topk = 20\n",
    "filter_str = None\n",
    "vector = model.encode([text])[0]\n",
    "kwargs = dict(TEXT=text, hybrid_ratio=0.5)\n",
    "result = client.tvs_knnsearch(index_name, topk, vector, False, filter_str, **kwargs)\n",
    "top_img = [images[id_pos[int(hit[0])]] for hit in result]\n",
    "display_result(top_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(b'4191', 32.93181610107422), (b'8366', 41.976192474365234), (b'24615', 61.000003814697266), (b'4323', 61.000003814697266), (b'8361', 62.0), (b'17617', 62.0), (b'4904', 62.999996185302734), (b'4927', 64.0), (b'2313', 64.0), (b'4209', 65.0), (b'33138', 65.0), (b'8335', 66.0), (b'38942', 66.0), (b'38943', 67.0), (b'4898', 67.0), (b'4311', 68.0), (b'12850', 68.0), (b'41747', 69.0), (b'39846', 70.0), (b'8334', 70.0)]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " \n", "
\n", " "
      ],
      "text/plain": [
       ""
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same hybrid query (hybrid_ratio=0.5) with an attribute filter on subCategory == \"Topwear\"\n",
    "topk = 20\n",
    "text = \"Green Kidswear\"\n",
    "vector = model.encode([text])[0]\n",
    "filter_str = \"subCategory == \\\"Topwear\\\"\"\n",
    "kwargs = {\"TEXT\" : text, \"hybrid_ratio\" : 0.5}\n",
    "result = client.tvs_knnsearch(index_name, topk, vector, False, filter_str, **kwargs)\n",
    "print(result)\n",
    "top_img = [images[id_pos[int(item[0])]] for item in result]\n",
    "display_result(top_img)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}