Sheng Lei commited on
Commit
28de1fd
1 Parent(s): 5ea700d

Add application file

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ trained_model/
2
+ dist/
3
+ .idea
.gitignore~ ADDED
File without changes
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
2
+ # you will also find guides on how best to write your Dockerfile
3
+
4
+ FROM python:3.9
5
+
6
+ RUN useradd -m -u 1000 user
7
+
8
+ WORKDIR /app
9
+
10
+ COPY --chown=user ./requirements.txt requirements.txt
11
+
12
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
13
+
14
+ COPY --chown=user . /app
15
+
16
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,11 +1,7 @@
1
- ---
2
- title: Restricted Item Detector
3
- emoji: 👁
4
- colorFrom: gray
5
- colorTo: green
6
- sdk: docker
7
- pinned: false
8
- license: mit
9
- ---
10
 
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
1
+ # hackweek2024-sup-genai-tools
 
 
 
 
 
 
 
 
2
 
3
+ ## Poetry setup
4
+ to start, first install dependency by
5
+ `$ poetry install`
6
+ then you can start the venv by
7
+ `$ poetry shell`
app.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+
3
+ app = FastAPI()
4
+
5
+ @app.get("/")
6
+ def greet_json():
7
+ return {"Hello": "World!"}
csv/block_items.csv ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Date,Host,Service,Blocked Cart Items,Message
2
+ "2024-06-25T21:51:35.523Z","""i-0ec7dc300989b0f58""","""cash-commerce-browser""","""[{\""name\"":\""Promotional Email GiftCard $10\"",\""image_url\"":\""https://target.scene7.com/is/image/Target//GUEST_337c880d-e2c0-4794-9784-962c63cf0646?qlt=80&fmt=pjpeg\""}, {\""name\"":\""Promotional Email GiftCard $10\"",\""image_url\"":\""https://target.scene7.com/is/image/Target//GUEST_337c880d-e2c0-4794-9784-962c63cf0646?qlt=80&fmt=pjpeg\""}]""","Found restricted items: [ShoppingCartProduct{name=Promotional Email GiftCard $10, image_url=https://target.scene7.com/is/image/Target//GUEST_337c880d-e2c0-4794-9784-962c63cf0646?qlt=80&fmt=pjpeg}, ShoppingCartProduct{name=Promotional Email GiftCard $10, image_url=https://target.scene7.com/is/image/Target//GUEST_337c880d-e2c0-4794-9784-962c63cf0646?qlt=80&fmt=pjpeg}]; {customer_token=C_gqed91mw3, flow_token=4s6ZZrxNlMdUWqbN59nXHURXi, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_3mfvyukp, sup_token=BRAND_57fqf0l740hpnbxev71c6d56e}"
3
+ "2024-06-25T21:43:14.348Z","""i-0ec7dc300989b0f58""","""cash-commerce-browser""","""[{\""name\"":\""$100 Vanilla® Visa® Gift Box Gift Card (plus $5.44 Purchase Fee)\"",\""url\"":\""https://www.walmart.com/ip/seort/55844027\""}]""","Found restricted items: [ShoppingCartProduct{name=$100 Vanilla® Visa® Gift Box Gift Card (plus $5.44 Purchase Fee), url=https://www.walmart.com/ip/seort/55844027}]; {customer_token=C_86dpxgy6b, flow_token=4c546519-bc7c-4ae7-bbf0-5daf158351b1, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_fopbedhc, sup_token=null}"
4
+ "2024-06-25T21:43:05.816Z","""i-044ba56b80a95063d""","""cash-commerce-browser""","""[{\""name\"":\""Promotional Email GiftCard $10\"",\""image_url\"":\""https://target.scene7.com/is/image/Target//GUEST_337c880d-e2c0-4794-9784-962c63cf0646?qlt=80&fmt=pjpeg\""}]""","Found restricted items: [ShoppingCartProduct{name=Promotional Email GiftCard $10, image_url=https://target.scene7.com/is/image/Target//GUEST_337c880d-e2c0-4794-9784-962c63cf0646?qlt=80&fmt=pjpeg}]; {customer_token=C_fsexdpmbh, flow_token=fRqPMaEwTcPsgLuLQ8mCvf9df, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_3mfvyukp, sup_token=BRAND_57fqf0l740hpnbxev71c6d56e}"
5
+ "2024-06-25T21:35:36.381Z","""i-0ec7dc300989b0f58""","""cash-commerce-browser""","""[{\""name\"":\""GPS7000 GPS Tracker for Vehicles - Hidden Tracking Device for Any Vehicle - Easy Installation on Car's Battery- 10 Days of Service - Subscription Required - Low Cost Subscription Plan Options\"",\""url\"":\""https://www.amazon.com/gp/aw/d/B0BYK99LZC/ref=ox_sc_act_title_delete_4?smid=AWU3BDL6BD5T5&psc=1\"",\""image_url\"":\""https://m.media-amazon.com/images/I/61E6Cex5dUL._AC_AA135_.jpg\""}]""","Found restricted items: [ShoppingCartProduct{name=GPS7000 GPS Tracker for Vehicles - Hidden Tracking Device for Any Vehicle - Easy Installation on Car's Battery- 10 Days of Service - Subscription Required - Low Cost Subscription Plan Options, url=https://www.amazon.com/gp/aw/d/B0BYK99LZC/ref=ox_sc_act_title_delete_4?smid=AWU3BDL6BD5T5&psc=1, image_url=https://m.media-amazon.com/images/I/61E6Cex5dUL._AC_AA135_.jpg}]; {customer_token=C_5cpf4hygz, flow_token=73658682-d99e-4742-b6b4-8ebb7f1543ad, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_ih4h4tzb, sup_token=BRAND_djwngg0n5zluxwp131plevfvw}"
6
+ "2024-06-25T21:35:24.733Z","""i-044ba56b80a95063d""","""cash-commerce-browser""","""[{\""name\"":\""GPS7000 GPS Tracker for Vehicles - Hidden Tracking Device for Any Vehicle - Easy Installation on Car's Battery- 10 Days of Service - Subscription Required - Low Cost Subscription Plan Options\"",\""url\"":\""https://www.amazon.com/gp/aw/d/B0BYK99LZC/ref=ox_sc_act_title_delete_4?smid=AWU3BDL6BD5T5&psc=1\"",\""image_url\"":\""https://m.media-amazon.com/images/I/61E6Cex5dUL._AC_AA135_.jpg\""}]""","Found restricted items: [ShoppingCartProduct{name=GPS7000 GPS Tracker for Vehicles - Hidden Tracking Device for Any Vehicle - Easy Installation on Car's Battery- 10 Days of Service - Subscription Required - Low Cost Subscription Plan Options, url=https://www.amazon.com/gp/aw/d/B0BYK99LZC/ref=ox_sc_act_title_delete_4?smid=AWU3BDL6BD5T5&psc=1, image_url=https://m.media-amazon.com/images/I/61E6Cex5dUL._AC_AA135_.jpg}]; {customer_token=C_5cpf4hygz, flow_token=73658682-d99e-4742-b6b4-8ebb7f1543ad, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_ih4h4tzb, sup_token=BRAND_djwngg0n5zluxwp131plevfvw}"
7
+ "2024-06-25T21:35:20.107Z","""i-003fbf26706b7af3e""","""cash-commerce-browser""","""[{\""name\"":\""GPS7000 GPS Tracker for Vehicles - Hidden Tracking Device for Any Vehicle - Easy Installation on Car's Battery- 10 Days of Service - Subscription Required - Low Cost Subscription Plan Options\"",\""url\"":\""https://www.amazon.com/gp/aw/d/B0BYK99LZC/ref=ox_sc_act_title_delete_4?smid=AWU3BDL6BD5T5&psc=1\"",\""image_url\"":\""https://m.media-amazon.com/images/I/61E6Cex5dUL._AC_AA135_.jpg\""}]""","Found restricted items: [ShoppingCartProduct{name=GPS7000 GPS Tracker for Vehicles - Hidden Tracking Device for Any Vehicle - Easy Installation on Car's Battery- 10 Days of Service - Subscription Required - Low Cost Subscription Plan Options, url=https://www.amazon.com/gp/aw/d/B0BYK99LZC/ref=ox_sc_act_title_delete_4?smid=AWU3BDL6BD5T5&psc=1, image_url=https://m.media-amazon.com/images/I/61E6Cex5dUL._AC_AA135_.jpg}]; {customer_token=C_5cpf4hygz, flow_token=73658682-d99e-4742-b6b4-8ebb7f1543ad, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_ih4h4tzb, sup_token=BRAND_djwngg0n5zluxwp131plevfvw}"
8
+ "2024-06-25T21:35:01.987Z","""i-003fbf26706b7af3e""","""cash-commerce-browser""","""[{\""name\"":\""GPS7000 GPS Tracker for Vehicles - Hidden Tracking Device for Any Vehicle - Easy Installation on Car's Battery- 10 Days of Service - Subscription Required - Low Cost Subscription Plan Options\"",\""url\"":\""https://www.amazon.com/gp/aw/d/B0BYK99LZC/ref=ox_sc_act_title_delete_4?smid=AWU3BDL6BD5T5&psc=1\"",\""image_url\"":\""https://m.media-amazon.com/images/I/61E6Cex5dUL._AC_AA135_.jpg\""}]""","Found restricted items: [ShoppingCartProduct{name=GPS7000 GPS Tracker for Vehicles - Hidden Tracking Device for Any Vehicle - Easy Installation on Car's Battery- 10 Days of Service - Subscription Required - Low Cost Subscription Plan Options, url=https://www.amazon.com/gp/aw/d/B0BYK99LZC/ref=ox_sc_act_title_delete_4?smid=AWU3BDL6BD5T5&psc=1, image_url=https://m.media-amazon.com/images/I/61E6Cex5dUL._AC_AA135_.jpg}]; {customer_token=C_5cpf4hygz, flow_token=73658682-d99e-4742-b6b4-8ebb7f1543ad, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_ih4h4tzb, sup_token=BRAND_djwngg0n5zluxwp131plevfvw}"
9
+ "2024-06-25T21:34:07.013Z","""i-044ba56b80a95063d""","""cash-commerce-browser""","""[{\""name\"":\""$50 Vanilla Visa eGift Card\"",\""url\"":\""https://www.samsclub.com/p/vanilla-e-gift-visa-various-amount/prod25810992\"",\""image_url\"":\""https://scene7.samsclub.com/is/image/samsclub/0079936655659_A?wid=200&hei=200\""}]""","Found restricted items: [ShoppingCartProduct{name=$50 Vanilla Visa eGift Card, url=https://www.samsclub.com/p/vanilla-e-gift-visa-various-amount/prod25810992, image_url=https://scene7.samsclub.com/is/image/samsclub/0079936655659_A?wid=200&hei=200}]; {customer_token=C_haed91mp3, flow_token=Q81SuACF2yWsd6Fn6PRWDsBK4, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_dsjbdob6, sup_token=null}"
10
+ "2024-06-25T21:33:34.552Z","""i-0ec7dc300989b0f58""","""cash-commerce-browser""","""[{\""name\"":\""GPS7000 GPS Tracker for Vehicles - Hidden Tracking Device for Any Vehicle - Easy Installation on Car's Battery- 10 Days of Service - Subscription Required - Low Cost Subscription Plan Options\"",\""url\"":\""https://www.amazon.com/gp/aw/d/B0BYK99LZC/ref=ox_sc_act_title_delete_4?smid=AWU3BDL6BD5T5&psc=1\"",\""image_url\"":\""https://m.media-amazon.com/images/I/61E6Cex5dUL._AC_AA135_.jpg\""}]""","Found restricted items: [ShoppingCartProduct{name=GPS7000 GPS Tracker for Vehicles - Hidden Tracking Device for Any Vehicle - Easy Installation on Car's Battery- 10 Days of Service - Subscription Required - Low Cost Subscription Plan Options, url=https://www.amazon.com/gp/aw/d/B0BYK99LZC/ref=ox_sc_act_title_delete_4?smid=AWU3BDL6BD5T5&psc=1, image_url=https://m.media-amazon.com/images/I/61E6Cex5dUL._AC_AA135_.jpg}]; {customer_token=C_5cpf4hygz, flow_token=73658682-d99e-4742-b6b4-8ebb7f1543ad, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_ih4h4tzb, sup_token=BRAND_djwngg0n5zluxwp131plevfvw}"
11
+ "2024-06-25T21:31:29.788Z","""i-0c92416a905681db6""","""cash-commerce-browser""","""[{\""name\"":\""$200 Vanilla Visa Shiny Bow Gift Card (plus $6.88 Purchase Fee)\"",\""url\"":\""https://www.walmart.com/ip/seort/1042177970\""}]""","Found restricted items: [ShoppingCartProduct{name=$200 Vanilla Visa Shiny Bow Gift Card (plus $6.88 Purchase Fee), url=https://www.walmart.com/ip/seort/1042177970}]; {customer_token=C_smecd0mhm, flow_token=wnHQN78zduV3T55TsN8s0CKHA, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_fopbedhc, sup_token=BRAND_4sxo1et5hf8sug8lh500413rj}"
12
+ "2024-06-25T21:30:38.963Z","""i-044ba56b80a95063d""","""cash-commerce-browser""","""[{\""name\"":\""$200 Vanilla Visa Shiny Bow Gift Card (plus $6.88 Purchase Fee)\"",\""image_url\"":\""https://i5.walmartimages.com/seo/200-Vanilla-Visa-Shiny-Bow-Gift-Card-plus-6-88-Purchase-Fee_1cc86583-2dcb-4677-bca4-16f685df6170.d43209d754cbb7ab3c09354d6cf906a7.jpeg?odnHeight=48&odnWidth=48&odnBg=FFFFFF\""}]""","Found restricted items: [ShoppingCartProduct{name=$200 Vanilla Visa Shiny Bow Gift Card (plus $6.88 Purchase Fee), image_url=https://i5.walmartimages.com/seo/200-Vanilla-Visa-Shiny-Bow-Gift-Card-plus-6-88-Purchase-Fee_1cc86583-2dcb-4677-bca4-16f685df6170.d43209d754cbb7ab3c09354d6cf906a7.jpeg?odnHeight=48&odnWidth=48&odnBg=FFFFFF}]; {customer_token=C_smecd0mhm, flow_token=CCvc7KfXlUWnhdqtcxORrG4Sp, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_fopbedhc, sup_token=BRAND_4sxo1et5hf8sug8lh500413rj}"
13
+ "2024-06-25T21:29:54.170Z","""i-0ec7dc300989b0f58""","""cash-commerce-browser""","""[{\""name\"":\""$200 Vanilla Visa Shiny Bow Gift Card (plus $6.88 Purchase Fee)\"",\""url\"":\""https://www.walmart.com/ip/seort/1042177970\""}]""","Found restricted items: [ShoppingCartProduct{name=$200 Vanilla Visa Shiny Bow Gift Card (plus $6.88 Purchase Fee), url=https://www.walmart.com/ip/seort/1042177970}]; {customer_token=C_smecd0mhm, flow_token=oucI3kCcOLy76Hiv8Tb2vxgfa, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_fopbedhc, sup_token=BRAND_4sxo1et5hf8sug8lh500413rj}"
14
+ "2024-06-25T21:26:19.296Z","""i-0c92416a905681db6""","""cash-commerce-browser""","""[{\""name\"":\""$200 Vanilla® Visa® Gift Box Gift Card (plus $6.88 Purchase Fee)\"",\""url\"":\""https://www.walmart.com/ip/seort/20625678\""}]""","Found restricted items: [ShoppingCartProduct{name=$200 Vanilla® Visa® Gift Box Gift Card (plus $6.88 Purchase Fee), url=https://www.walmart.com/ip/seort/20625678}]; {customer_token=C_002v20mt1, flow_token=iNJAgMr3PvHc07lbZAkdrnEIn, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_fopbedhc, sup_token=BRAND_4sxo1et5hf8sug8lh500413rj}"
15
+ "2024-06-25T21:20:06.661Z","""i-003fbf26706b7af3e""","""cash-commerce-browser""","""[{\""name\"":\""Promotional Email GiftCard $20\"",\""image_url\"":\""https://target.scene7.com/is/image/Target//GUEST_337c880d-e2c0-4794-9784-962c63cf0646?qlt=80&fmt=pjpeg\""}, {\""name\"":\""Promotional Email GiftCard $10\"",\""image_url\"":\""https://target.scene7.com/is/image/Target//GUEST_337c880d-e2c0-4794-9784-962c63cf0646?qlt=80&fmt=pjpeg\""}]""","Found restricted items: [ShoppingCartProduct{name=Promotional Email GiftCard $20, image_url=https://target.scene7.com/is/image/Target//GUEST_337c880d-e2c0-4794-9784-962c63cf0646?qlt=80&fmt=pjpeg}, ShoppingCartProduct{name=Promotional Email GiftCard $10, image_url=https://target.scene7.com/is/image/Target//GUEST_337c880d-e2c0-4794-9784-962c63cf0646?qlt=80&fmt=pjpeg}]; {customer_token=C_ssdwxqyqz, flow_token=roXsSw6mIgIqe5XjF8bO8tH2o, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_3mfvyukp, sup_token=BRAND_57fqf0l740hpnbxev71c6d56e}"
16
+ "2024-06-25T21:18:25.022Z","""i-044ba56b80a95063d""","""cash-commerce-browser""","""[{\""name\"":\""Xbox $100 Gift Card - [Digital]\"",\""url\"":\""https://www.walmart.com/ip/seort/48695417\""}]""","Found restricted items: [ShoppingCartProduct{name=Xbox $100 Gift Card - \[Digital\], url=https://www.walmart.com/ip/seort/48695417}]; {customer_token=C_h9ccehmmq, flow_token=ZichXe9bjvCk2NYseQpasHyxW, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_fopbedhc, sup_token=BRAND_4sxo1et5hf8sug8lh500413rj}"
17
+ "2024-06-25T21:15:07.437Z","""i-0ec7dc300989b0f58""","""cash-commerce-browser""","""[{\""name\"":\""$25 Vanilla Visa Shiny Bow Gift Card (plus $3.44 Purchase Fee), quantity 2\"",\""image_url\"":\""https://i5.walmartimages.com/seo/25-Vanilla-Visa-Shiny-Bow-Gift-Card-plus-3-44-Purchase-Fee_852b53b6-7090-407d-a47b-06618b04d244.3c1e4f0aae77e8bf05fdd0660a354377.jpeg?odnHeight=48&odnWidth=48&odnBg=FFFFFF\""}, {\""name\"":\""$50 Vanilla Visa Shiny Bow Gift Card (plus $3.94 Purchase Fee), quantity 2\"",\""image_url\"":\""https://i5.walmartimages.com/seo/50-Vanilla-Visa-Shiny-Bow-Gift-Card-plus-3-94-Purchase-Fee_97c67191-c411-43a8-8174-5bed3ccfa2ed.bfc91450d0b09abcfbab003e6edf9fcf.jpeg?odnHeight=48&odnWidth=48&odnBg=FFFFFF\""}]""","Found restricted items: [ShoppingCartProduct{name=$25 Vanilla Visa Shiny Bow Gift Card (plus $3.44 Purchase Fee)\, quantity 2, image_url=https://i5.walmartimages.com/seo/25-Vanilla-Visa-Shiny-Bow-Gift-Card-plus-3-44-Purchase-Fee_852b53b6-7090-407d-a47b-06618b04d244.3c1e4f0aae77e8bf05fdd0660a354377.jpeg?odnHeight=48&odnWidth=48&odnBg=FFFFFF}, ShoppingCartProduct{name=$50 Vanilla Visa Shiny Bow Gift Card (plus $3.94 Purchase Fee)\, quantity 2, image_url=https://i5.walmartimages.com/seo/50-Vanilla-Visa-Shiny-Bow-Gift-Card-plus-3-94-Purchase-Fee_97c67191-c411-43a8-8174-5bed3ccfa2ed.bfc91450d0b09abcfbab003e6edf9fcf.jpeg?odnHeight=48&odnWidth=48&odnBg=FFFFFF}]; {customer_token=C_gawfj8yrr, flow_token=08e30143-ac4d-428a-963a-f9bbd10cfdb5, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_fopbedhc, sup_token=BRAND_4sxo1et5hf8sug8lh500413rj}"
18
+ "2024-06-25T21:15:01.385Z","""i-003fbf26706b7af3e""","""cash-commerce-browser""","""[{\""name\"":\""$25 Vanilla Visa Shiny Bow Gift Card (plus $3.44 Purchase Fee)\"",\""image_url\"":\""https://i5.walmartimages.com/seo/25-Vanilla-Visa-Shiny-Bow-Gift-Card-plus-3-44-Purchase-Fee_852b53b6-7090-407d-a47b-06618b04d244.3c1e4f0aae77e8bf05fdd0660a354377.jpeg?odnHeight=48&odnWidth=48&odnBg=FFFFFF\""}, {\""name\"":\""$50 Vanilla Visa Shiny Bow Gift Card (plus $3.94 Purchase Fee), quantity 2\"",\""image_url\"":\""https://i5.walmartimages.com/seo/50-Vanilla-Visa-Shiny-Bow-Gift-Card-plus-3-94-Purchase-Fee_97c67191-c411-43a8-8174-5bed3ccfa2ed.bfc91450d0b09abcfbab003e6edf9fcf.jpeg?odnHeight=48&odnWidth=48&odnBg=FFFFFF\""}]""","Found restricted items: [ShoppingCartProduct{name=$25 Vanilla Visa Shiny Bow Gift Card (plus $3.44 Purchase Fee), image_url=https://i5.walmartimages.com/seo/25-Vanilla-Visa-Shiny-Bow-Gift-Card-plus-3-44-Purchase-Fee_852b53b6-7090-407d-a47b-06618b04d244.3c1e4f0aae77e8bf05fdd0660a354377.jpeg?odnHeight=48&odnWidth=48&odnBg=FFFFFF}, ShoppingCartProduct{name=$50 Vanilla Visa Shiny Bow Gift Card (plus $3.94 Purchase Fee)\, quantity 2, image_url=https://i5.walmartimages.com/seo/50-Vanilla-Visa-Shiny-Bow-Gift-Card-plus-3-94-Purchase-Fee_97c67191-c411-43a8-8174-5bed3ccfa2ed.bfc91450d0b09abcfbab003e6edf9fcf.jpeg?odnHeight=48&odnWidth=48&odnBg=FFFFFF}]; {customer_token=C_gawfj8yrr, flow_token=08e30143-ac4d-428a-963a-f9bbd10cfdb5, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_fopbedhc, sup_token=BRAND_4sxo1et5hf8sug8lh500413rj}"
19
+ "2024-06-25T21:14:58.239Z","""i-003fbf26706b7af3e""","""cash-commerce-browser""","""[{\""name\"":\""$25 Vanilla Visa Shiny Bow Gift Card (plus $3.44 Purchase Fee)\"",\""image_url\"":\""https://i5.walmartimages.com/seo/25-Vanilla-Visa-Shiny-Bow-Gift-Card-plus-3-44-Purchase-Fee_852b53b6-7090-407d-a47b-06618b04d244.3c1e4f0aae77e8bf05fdd0660a354377.jpeg?odnHeight=48&odnWidth=48&odnBg=FFFFFF\""}, {\""name\"":\""$50 Vanilla Visa Shiny Bow Gift Card (plus $3.94 Purchase Fee)\"",\""image_url\"":\""https://i5.walmartimages.com/seo/50-Vanilla-Visa-Shiny-Bow-Gift-Card-plus-3-94-Purchase-Fee_97c67191-c411-43a8-8174-5bed3ccfa2ed.bfc91450d0b09abcfbab003e6edf9fcf.jpeg?odnHeight=48&odnWidth=48&odnBg=FFFFFF\""}]""","Found restricted items: [ShoppingCartProduct{name=$25 Vanilla Visa Shiny Bow Gift Card (plus $3.44 Purchase Fee), image_url=https://i5.walmartimages.com/seo/25-Vanilla-Visa-Shiny-Bow-Gift-Card-plus-3-44-Purchase-Fee_852b53b6-7090-407d-a47b-06618b04d244.3c1e4f0aae77e8bf05fdd0660a354377.jpeg?odnHeight=48&odnWidth=48&odnBg=FFFFFF}, ShoppingCartProduct{name=$50 Vanilla Visa Shiny Bow Gift Card (plus $3.94 Purchase Fee), image_url=https://i5.walmartimages.com/seo/50-Vanilla-Visa-Shiny-Bow-Gift-Card-plus-3-94-Purchase-Fee_97c67191-c411-43a8-8174-5bed3ccfa2ed.bfc91450d0b09abcfbab003e6edf9fcf.jpeg?odnHeight=48&odnWidth=48&odnBg=FFFFFF}]; {customer_token=C_gawfj8yrr, flow_token=08e30143-ac4d-428a-963a-f9bbd10cfdb5, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_fopbedhc, sup_token=BRAND_4sxo1et5hf8sug8lh500413rj}"
20
+ "2024-06-25T21:13:44.815Z","""i-0c92416a905681db6""","""cash-commerce-browser""","""[{\""name\"":\""Xbox $100 Gift Card - [Digital]\"",\""url\"":\""https://www.walmart.com/ip/seort/48695417\""}]""","Found restricted items: [ShoppingCartProduct{name=Xbox $100 Gift Card - \[Digital\], url=https://www.walmart.com/ip/seort/48695417}]; {customer_token=C_h9ccehmmq, flow_token=W9fV7QLzgEeyuDVEf2101TEnm, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_fopbedhc, sup_token=BRAND_4sxo1et5hf8sug8lh500413rj}"
21
+ "2024-06-25T21:13:36.779Z","""i-0ec7dc300989b0f58""","""cash-commerce-browser""","""[{\""name\"":\""Home Depot Birthday Cupcake eGift\"",\""url\"":\""https://www.homedepot.com/gift-cards/p/Home-Depot-Birthday-Cupcake-eGift/53L1QLAXWKVF06WJ8DLHBYT6S8/5J0QS6XYWSCJ3H9HG3F6YSRL2C\"",\""image_url\"":\""https://images.thdstatic.com/giftcards/catalog/53L1QLAXWKVF06WJ8DLHBYT6S8/xxlarge/5J0QS6XYWSCJ3H9HG3F6YSRL2C_0608202315:26:21.PNG\""}]""","Found restricted items: [ShoppingCartProduct{name=Home Depot Birthday Cupcake eGift, url=https://www.homedepot.com/gift-cards/p/Home-Depot-Birthday-Cupcake-eGift/53L1QLAXWKVF06WJ8DLHBYT6S8/5J0QS6XYWSCJ3H9HG3F6YSRL2C, image_url=https://images.thdstatic.com/giftcards/catalog/53L1QLAXWKVF06WJ8DLHBYT6S8/xxlarge/5J0QS6XYWSCJ3H9HG3F6YSRL2C_0608202315:26:21.PNG}]; {customer_token=C_ssd2wpm6w, flow_token=V53QOpItzVeaNwXwPC6cewpnt, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_kiqt4xl6, sup_token=BRAND_78qp2zjkqjto3v3xxr5ibueyv}"
22
+ "2024-06-25T21:12:33.161Z","""i-003fbf26706b7af3e""","""cash-commerce-browser""","""[{\""name\"":\""$100 Vanilla® Visa® Gift Box Gift Card (plus $5.44 Purchase Fee)\"",\""url\"":\""https://www.walmart.com/ip/seort/55844027\"",\""image_url\"":\""https://i5.walmartimages.com/seo/100-Vanilla-Visa-Gift-Box-Gift-Card-plus-5-44-Purchase-Fee_d1b5b130-1ae0-4690-a1ee-4e06ff34661a.1e16d3785d608d00f8636a7a91f40182.jpeg?odnHeight=72&odnWidth=72&odnBg=FFFFFF\""}, {\""name\"":\""$50 Vanilla Visa Shiny Bow Gift Card (plus $3.94 Purchase Fee)\"",\""url\"":\""https://www.walmart.com/ip/seort/1529427774\"",\""image_url\"":\""https://i5.walmartimages.com/seo/50-Vanilla-Visa-Shiny-Bow-Gift-Card-plus-3-94-Purchase-Fee_97c67191-c411-43a8-8174-5bed3ccfa2ed.bfc91450d0b09abcfbab003e6edf9fcf.jpeg?odnHeight=72&odnWidth=72&odnBg=FFFFFF\""}, {\""name\"":\""$25 Vanilla Visa Shiny Bow Gift Card (plus $3.44 Purchase Fee)\"",\""url\"":\""https://www.walmart.com/ip/seort/1314241502\"",\""image_url\"":\""https://i5.walmartimages.com/seo/25-Vanilla-Visa-Shiny-Bow-Gift-Card-plus-3-44-Purchase-Fee_852b53b6-7090-407d-a47b-06618b04d244.3c1e4f0aae77e8bf05fdd0660a354377.jpeg?odnHeight=72&odnWidth=72&odnBg=FFFFFF\""}]""","Found restricted items: [ShoppingCartProduct{name=$100 Vanilla® Visa® Gift Box Gift Card (plus $5.44 Purchase Fee), url=https://www.walmart.com/ip/seort/55844027, image_url=https://i5.walmartimages.com/seo/100-Vanilla-Visa-Gift-Box-Gift-Card-plus-5-44-Purchase-Fee_d1b5b130-1ae0-4690-a1ee-4e06ff34661a.1e16d3785d608d00f8636a7a91f40182.jpeg?odnHeight=72&odnWidth=72&odnBg=FFFFFF}, ShoppingCartProduct{name=$50 Vanilla Visa Shiny Bow Gift Card (plus $3.94 Purchase Fee), url=https://www.walmart.com/ip/seort/1529427774, image_url=https://i5.walmartimages.com/seo/50-Vanilla-Visa-Shiny-Bow-Gift-Card-plus-3-94-Purchase-Fee_97c67191-c411-43a8-8174-5bed3ccfa2ed.bfc91450d0b09abcfbab003e6edf9fcf.jpeg?odnHeight=72&odnWidth=72&odnBg=FFFFFF}, ShoppingCartProduct{name=$25 Vanilla Visa Shiny Bow Gift Card (plus $3.44 Purchase Fee), url=https://www.walmart.com/ip/seort/1314241502, image_url=https://i5.walmartimages.com/seo/25-Vanilla-Visa-Shiny-Bow-Gift-Card-plus-3-44-Purchase-Fee_852b53b6-7090-407d-a47b-06618b04d244.3c1e4f0aae77e8bf05fdd0660a354377.jpeg?odnHeight=72&odnWidth=72&odnBg=FFFFFF}]; {customer_token=C_gawfj8yrr, flow_token=08e30143-ac4d-428a-963a-f9bbd10cfdb5, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_fopbedhc, sup_token=BRAND_4sxo1et5hf8sug8lh500413rj}"
23
+ "2024-06-25T21:09:57.657Z","""i-003fbf26706b7af3e""","""cash-commerce-browser""","""[{\""name\"":\""Birthday Celebration Target Giftcard\"",\""url\"":\""https://www.target.com/p/birthday-celebration-target-giftcard--no-aasa/-/A-81480862\"",\""image_url\"":\""https://target.scene7.com/is/image/Target//GUEST_c2ea8817-d6da-437c-9d85-c39b2e96953d?qlt=80&fmt=pjpeg\""}]""","Found restricted items: [ShoppingCartProduct{name=Birthday Celebration Target Giftcard, url=https://www.target.com/p/birthday-celebration-target-giftcard--no-aasa/-/A-81480862, image_url=https://target.scene7.com/is/image/Target//GUEST_c2ea8817-d6da-437c-9d85-c39b2e96953d?qlt=80&fmt=pjpeg}]; {customer_token=C_8xdwa1m6n, flow_token=9RBhrYZk8g7j2WDYyM3GbfdES, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_3mfvyukp, sup_token=BRAND_57fqf0l740hpnbxev71c6d56e}"
24
+ "2024-06-25T21:09:49.051Z","""i-003fbf26706b7af3e""","""cash-commerce-browser""","""[{\""name\"":\""Birthday Celebration Target Giftcard\"",\""url\"":\""https://www.target.com/p/birthday-celebration-target-giftcard--no-aasa/-/A-81480862\"",\""image_url\"":\""https://target.scene7.com/is/image/Target//GUEST_c2ea8817-d6da-437c-9d85-c39b2e96953d?qlt=80&fmt=pjpeg\""}]""","Found restricted items: [ShoppingCartProduct{name=Birthday Celebration Target Giftcard, url=https://www.target.com/p/birthday-celebration-target-giftcard--no-aasa/-/A-81480862, image_url=https://target.scene7.com/is/image/Target//GUEST_c2ea8817-d6da-437c-9d85-c39b2e96953d?qlt=80&fmt=pjpeg}]; {customer_token=C_8xdwa1m6n, flow_token=J7liXzVSAPSDIaK4CWVLC7bBi, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_3mfvyukp, sup_token=BRAND_57fqf0l740hpnbxev71c6d56e}"
25
+ "2024-06-25T21:09:35.160Z","""i-003fbf26706b7af3e""","""cash-commerce-browser""","""[{\""name\"":\""Promotional Email GiftCard $10\"",\""image_url\"":\""https://target.scene7.com/is/image/Target//GUEST_337c880d-e2c0-4794-9784-962c63cf0646?qlt=80&fmt=pjpeg\""}]""","Found restricted items: [ShoppingCartProduct{name=Promotional Email GiftCard $10, image_url=https://target.scene7.com/is/image/Target//GUEST_337c880d-e2c0-4794-9784-962c63cf0646?qlt=80&fmt=pjpeg}]; {customer_token=C_fsexdpmbh, flow_token=jVDz28RrrN9tAkim71XsZGZMD, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_3mfvyukp, sup_token=BRAND_57fqf0l740hpnbxev71c6d56e}"
26
+ "2024-06-25T21:09:35.022Z","""i-0ec7dc300989b0f58""","""cash-commerce-browser""","""[{\""name\"":\""Birthday Celebration Target Giftcard\"",\""url\"":\""https://www.target.com/p/birthday-celebration-target-giftcard--no-aasa/-/A-81480862\"",\""image_url\"":\""\""}]""","Found restricted items: [ShoppingCartProduct{name=Birthday Celebration Target Giftcard, url=https://www.target.com/p/birthday-celebration-target-giftcard--no-aasa/-/A-81480862, image_url=}]; {customer_token=C_8xdwa1m6n, flow_token=9pZoMnaqRmVXuWQ9P9pSLefbO, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_3mfvyukp, sup_token=BRAND_57fqf0l740hpnbxev71c6d56e}"
27
+ "2024-06-25T21:09:07.723Z","""i-044ba56b80a95063d""","""cash-commerce-browser""","""[{\""name\"":\""Nintendo Switch Family Online 12 Month Gift Card [Digital]\"",\""url\"":\""https://www.walmart.com/ip/seort/875981752\""}]""","Found restricted items: [ShoppingCartProduct{name=Nintendo Switch Family Online 12 Month Gift Card \[Digital\], url=https://www.walmart.com/ip/seort/875981752}]; {customer_token=C_gq2ekammv, flow_token=03c0b437-4fc5-4967-9692-3efddd21a0d6, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_fopbedhc, sup_token=BRAND_4sxo1et5hf8sug8lh500413rj}"
28
+ "2024-06-25T21:08:46.049Z","""i-044ba56b80a95063d""","""cash-commerce-browser""","""[{\""name\"":\""Promotional Email GiftCard $10\"",\""image_url\"":\""https://target.scene7.com/is/image/Target//GUEST_337c880d-e2c0-4794-9784-962c63cf0646?qlt=80&fmt=pjpeg\""}]""","Found restricted items: [ShoppingCartProduct{name=Promotional Email GiftCard $10, image_url=https://target.scene7.com/is/image/Target//GUEST_337c880d-e2c0-4794-9784-962c63cf0646?qlt=80&fmt=pjpeg}]; {customer_token=C_fsexdpmbh, flow_token=e5xlrZ3F45YirfY5AHNFlAvtV, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_3mfvyukp, sup_token=BRAND_57fqf0l740hpnbxev71c6d56e}"
29
+ "2024-06-25T20:55:28.994Z","""i-044ba56b80a95063d""","""cash-commerce-browser""","""[{\""name\"":\""$50 Vanilla® Mastercard® Celebration Dots Gift Card (plus $3.94 Purchase Fee)\"",\""image_url\"":\""https://i5.walmartimages.com/seo/50-Vanilla-Mastercard-Celebration-Dots-Gift-Card-plus-3-94-Purchase-Fee_04e596b9-6d0a-41c9-96ed-7bdf11468cd2.d9453dcf7909a341bf31ee4dce04c833.jpeg?odnHeight=48&odnWidth=48&odnBg=FFFFFF\""}]""","Found restricted items: [ShoppingCartProduct{name=$50 Vanilla® Mastercard® Celebration Dots Gift Card (plus $3.94 Purchase Fee), image_url=https://i5.walmartimages.com/seo/50-Vanilla-Mastercard-Celebration-Dots-Gift-Card-plus-3-94-Purchase-Fee_04e596b9-6d0a-41c9-96ed-7bdf11468cd2.d9453dcf7909a341bf31ee4dce04c833.jpeg?odnHeight=48&odnWidth=48&odnBg=FFFFFF}]; {customer_token=C_nywqkaygt, flow_token=IOAzdzUmK2FLIN1KiCt8xQZQd, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_fopbedhc, sup_token=BRAND_4sxo1et5hf8sug8lh500413rj}"
30
+ "2024-06-25T20:55:26.016Z","""i-0c92416a905681db6""","""cash-commerce-browser""","""[{\""name\"":\""$50 Vanilla® Mastercard® Celebration Dots Gift Card (plus $3.94 Purchase Fee)\"",\""url\"":\""https://www.walmart.com/ip/seort/46518449\""}]""","Found restricted items: [ShoppingCartProduct{name=$50 Vanilla® Mastercard® Celebration Dots Gift Card (plus $3.94 Purchase Fee), url=https://www.walmart.com/ip/seort/46518449}]; {customer_token=C_nywqkaygt, flow_token=dEIHJrVnTIH6xsmTYspEfBcle, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_fopbedhc, sup_token=BRAND_4sxo1et5hf8sug8lh500413rj}"
31
+ "2024-06-25T20:55:13.996Z","""i-0c92416a905681db6""","""cash-commerce-browser""","""[{\""name\"":\""$50 Vanilla® Mastercard® Celebration Dots Gift Card (plus $3.94 Purchase Fee)\"",\""url\"":\""https://www.walmart.com/ip/seort/46518449\""}, {\""name\"":\""$50 Vanilla® Visa® eGift Card (plus $3.94 Purchase Fee)\"",\""url\"":\""https://www.walmart.com/ip/seort/761861368\""}]""","Found restricted items: [ShoppingCartProduct{name=$50 Vanilla® Mastercard® Celebration Dots Gift Card (plus $3.94 Purchase Fee), url=https://www.walmart.com/ip/seort/46518449}, ShoppingCartProduct{name=$50 Vanilla® Visa® eGift Card (plus $3.94 Purchase Fee), url=https://www.walmart.com/ip/seort/761861368}]; {customer_token=C_nywqkaygt, flow_token=MJgyeF7Eew00jGOPMLCG2rHyD, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_fopbedhc, sup_token=BRAND_4sxo1et5hf8sug8lh500413rj}"
32
+ "2024-06-25T20:52:39.628Z","""i-0c92416a905681db6""","""cash-commerce-browser""","""[{\""name\"":\""$50 Vanilla® Visa® eGift Card (plus $3.94 Purchase Fee)\"",\""url\"":\""https://www.walmart.com/ip/seort/761861368\""}]""","Found restricted items: [ShoppingCartProduct{name=$50 Vanilla® Visa® eGift Card (plus $3.94 Purchase Fee), url=https://www.walmart.com/ip/seort/761861368}]; {customer_token=C_nywqkaygt, flow_token=WJm0vCQG5TcMjVjTvpMqrfOGs, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_fopbedhc, sup_token=BRAND_4sxo1et5hf8sug8lh500413rj}"
33
+ "2024-06-25T20:52:36.059Z","""i-0ec7dc300989b0f58""","""cash-commerce-browser""","""[{\""name\"":\""$50 Vanilla Visa Shiny Bow Gift Card (plus $3.94 Purchase Fee)\"",\""url\"":\""https://www.walmart.com/ip/seort/1529427774\""}, {\""name\"":\""$50 Vanilla® Visa® eGift Card (plus $3.94 Purchase Fee)\"",\""url\"":\""https://www.walmart.com/ip/seort/761861368\""}]""","Found restricted items: [ShoppingCartProduct{name=$50 Vanilla Visa Shiny Bow Gift Card (plus $3.94 Purchase Fee), url=https://www.walmart.com/ip/seort/1529427774}, ShoppingCartProduct{name=$50 Vanilla® Visa® eGift Card (plus $3.94 Purchase Fee), url=https://www.walmart.com/ip/seort/761861368}]; {customer_token=C_nywqkaygt, flow_token=a5YVHg3kkZn8fb3e3TT1b9GlZ, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_fopbedhc, sup_token=BRAND_4sxo1et5hf8sug8lh500413rj}"
34
+ "2024-06-25T20:52:02.686Z","""i-0ec7dc300989b0f58""","""cash-commerce-browser""","""[{\""name\"":\""Delta Airlines Wedding $250 Gift Card (Email Delivery)\"",\""url\"":\""https://www.target.com/p/delta-airlines-wedding--250-gift-card--email-delivery---no-aasa/-/A-89537719\"",\""image_url\"":\""https://target.scene7.com/is/image/Target//GUEST_0de9a4fb-7f0e-4d85-806f-e841d7664a59?qlt=80&fmt=pjpeg\""}]""","Found restricted items: [ShoppingCartProduct{name=Delta Airlines Wedding $250 Gift Card (Email Delivery), url=https://www.target.com/p/delta-airlines-wedding--250-gift-card--email-delivery---no-aasa/-/A-89537719, image_url=https://target.scene7.com/is/image/Target//GUEST_0de9a4fb-7f0e-4d85-806f-e841d7664a59?qlt=80&fmt=pjpeg}]; {customer_token=C_80exppmwa, flow_token=pnOyvjetUVkuZfhIOQwOQCINH, payment_method=PAYMENT_METHOD_SINGLE_USE_PAYMENT, merchant_token=M_3mfvyukp, sup_token=BRAND_57fqf0l740hpnbxev71c6d56e}"
genai_SDK/Seq2Seq.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ from torch.utils.data import DataLoader, random_split
7
+ import pandas as pd
8
+ from tqdm import tqdm
9
+ import time
10
+ from .Utilities import LanguageDataset
11
+
12
+ class Seq2Seq():
13
+ """
14
+ Base class for Seq2Seq (text-generation models). This class will be inherited by wrappers of transformers like GPT2
15
+ and T5.
16
+
17
+ Attributes:
18
+
19
+ Methods:
20
+
21
+ """
22
+
23
+ def __init__(self, gpu=0, max_length=0, model_path=None):
24
+
25
+ # Load Seq2Seq to device based on available hardware
26
+ if torch.cuda.is_available():
27
+ self.device = torch.device('cuda')
28
+ else:
29
+ try:
30
+ self.device = torch.device('mps') # Apple Silicon
31
+ except Exception:
32
+ self.device = torch.device('cpu')
33
+
34
+ # GPU that model will run on
35
+ self.gpu = gpu
36
+
37
+ # Model specs
38
+ if model_path: self.model = torch.load(model_path).to(self.device)
39
+ else: self.model = None
40
+ self.model_name = ""
41
+ self.tokenizer = None
42
+ self.max_length = max_length
43
+
44
+ # Training specs
45
+ self.train_loader = None
46
+ self.valid_loader = None
47
+ self.results = pd.DataFrame(columns=['epoch', 'model_arch', 'batch_size', 'gpu', 'training_loss', 'validation_loss', 'epoch_duration_sec'])
48
+
49
+ def load_data(self, df, batch_size, train_ratio=0.8):
50
+ self.batch_size = batch_size
51
+ dataset = LanguageDataset(df, self.tokenizer)
52
+ train_size = int(0.8*len(dataset))
53
+ valid_size = len(dataset) - train_size
54
+ train_data, valid_data = random_split(dataset, [train_size, valid_size])
55
+ self.max_length = dataset.max_length
56
+ self.train_loader = DataLoader(train_data, batch_size=self.batch_size, shuffle=True)
57
+ self.valid_loader = DataLoader(valid_data, batch_size=self.batch_size)
58
+
59
+ """ Return training results """
60
+ def summary(self):
61
+ return self.results
62
+
63
+ """ Save model to path """
64
+ def to_pt(self, path):
65
+ torch.save(self.model, path)
66
+
67
+
68
+ class GPT2(Seq2Seq):
69
+ """
70
+ This is the GPT2 implementation of Seq2Seq.
71
+ """
72
+
73
+ def __init__(self, gpu, model_name, batch_size=16):
74
+ super().__init__(gpu, max_length=0)
75
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
76
+ self.model_name = model_name
77
+ self.model = GPT2LMHeadModel.from_pretrained(self.model_name).to(self.device)
78
+ self.tokenizer = GPT2Tokenizer.from_pretrained(self.model_name)
79
+ self.tokenizer.pad_token = self.tokenizer.eos_token
80
+
81
+ def train(self, num_epochs=3, train_ratio=0.8):
82
+ criterion = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
83
+ optimizer = optim.Adam(self.model.parameters(), lr=5e-4)
84
+
85
+ # Init a results dataframe
86
+
87
+ results = pd.DataFrame(columns=['epoch', 'transformer', 'batch_size', 'gpu',
88
+ 'training_loss', 'validation_loss', 'epoch_duration_sec'])
89
+ # The training loop
90
+ for epoch in range(num_epochs):
91
+ start_time = time.time() # Start the timer for the epoch
92
+
93
+ # Training
94
+ ## This line tells the self.model we're in 'learning mode'
95
+ self.model.train()
96
+ epoch_training_loss = 0
97
+ train_iterator = tqdm(self.train_loader,
98
+ desc=f"Training Epoch {epoch + 1}/{num_epochs} Batch Size: {self.batch_size}, Transformer: {self.model_name}")
99
+ for batch in train_iterator:
100
+ optimizer.zero_grad()
101
+ inputs = batch['input_ids'].squeeze(1).to(self.device)
102
+ targets = inputs.clone()
103
+ outputs = self.model(input_ids=inputs, labels=targets)
104
+ loss = outputs.loss
105
+ loss.backward()
106
+ optimizer.step()
107
+ train_iterator.set_postfix({'Training Loss': loss.item()})
108
+ epoch_training_loss += loss.item()
109
+ avg_epoch_training_loss = epoch_training_loss / len(train_iterator)
110
+
111
+ # Validation
112
+ ## This line below tells the self.model to 'stop learning'
113
+ self.model.eval()
114
+ epoch_validation_loss = 0
115
+ total_loss = 0
116
+ valid_iterator = tqdm(self.valid_loader, desc=f"Validation Epoch {epoch + 1}/{num_epochs}")
117
+ with torch.no_grad():
118
+ for batch in valid_iterator:
119
+ inputs = batch['input_ids'].squeeze(1).to(self.device)
120
+ targets = inputs.clone()
121
+ outputs = self.model(input_ids=inputs, labels=targets)
122
+ loss = outputs.loss
123
+ total_loss += loss
124
+ valid_iterator.set_postfix({'Validation Loss': loss.item()})
125
+ epoch_validation_loss += loss.item()
126
+
127
+ avg_epoch_validation_loss = epoch_validation_loss / len(self.valid_loader)
128
+
129
+ end_time = time.time() # End the timer for the epoch
130
+ epoch_duration_sec = end_time - start_time # Calculate the duration in seconds
131
+
132
+ new_row = {'transformer': self.model_name,
133
+ 'batch_size': self.batch_size,
134
+ 'gpu': self.gpu,
135
+ 'epoch': epoch + 1,
136
+ 'training_loss': avg_epoch_training_loss,
137
+ 'validation_loss': avg_epoch_validation_loss,
138
+ 'epoch_duration_sec': epoch_duration_sec} # Add epoch_duration to the dataframe
139
+
140
+ self.results.loc[len(self.results)] = new_row
141
+ print(f"Epoch: {epoch + 1}, Validation Loss: {total_loss / len(self.valid_loader)}")
142
+
143
+ def generate_text(self, input_str, top_k=16, top_p=0.95, temperature=1.0, repetition_penalty=1.2):
144
+ # Encode string to tokens
145
+ input_ids= self.tokenizer.encode(input_str, return_tensors='pt').to(self.device)
146
+
147
+ # Feed tokens to model and get outcome tokens
148
+ output = self.model.generate(
149
+ input_ids,
150
+ max_length=self.max_length,
151
+ num_return_sequences=1,
152
+ do_sample=True,
153
+ top_k=top_k,
154
+ top_p=top_p,
155
+ temperature=temperature,
156
+ repetition_penalty=repetition_penalty
157
+ )
158
+
159
+ # Decode tokens to string
160
+ decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True)
161
+ return decoded_output
162
+
163
+ class FlanT5(Seq2Seq):
164
+ """
165
+ This is the T5 implementation of Seq2Seq - it is designed to support T5 models of various sizes.
166
+ """
167
+ def __init__(self, gpu, model_name, batch_size=16):
168
+ super().__init__(gpu, max_length=0)
169
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
170
+ self.model_name = model_name
171
+ self.model = T5ForConditionalGeneration.from_pretrained(self.model_name).to(self.device)
172
+ self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
173
+ self.tokenizer.pad_token = self.tokenizer.eos_token
174
+
175
+ def train(self, num_epochs=3, train_ratio=0.8):
176
+ criterion = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
177
+ optimizer = optim.Adam(self.model.parameters(), lr=5e-4)
178
+
179
+ # Init a results dataframe
180
+
181
+ self.results = pd.DataFrame(columns=['epoch', 'transformer', 'batch_size', 'gpu',
182
+ 'training_loss', 'validation_loss', 'epoch_duration_sec'])
183
+ # The training loop
184
+ for epoch in range(num_epochs):
185
+ start_time = time.time() # Start the timer for the epoch
186
+
187
+ # Training
188
+ ## This line tells the model we're in 'learning mode'
189
+ self.model.train()
190
+ epoch_training_loss = 0
191
+ train_iterator = tqdm(self.train_loader,
192
+ desc=f"Training Epoch {epoch + 1}/{num_epochs} Batch Size: {self.batch_size}, Transformer: {self.model_name}")
193
+ for batch in train_iterator:
194
+ optimizer.zero_grad()
195
+ inputs = batch['input_ids'].squeeze(1).to(self.device)
196
+ targets = batch['labels'].squeeze(1).to(self.device)
197
+ outputs = self.model(input_ids=inputs, labels=targets)
198
+ loss = outputs.loss
199
+ loss.backward()
200
+ optimizer.step()
201
+ train_iterator.set_postfix({'Training Loss': loss.item()})
202
+ epoch_training_loss += loss.item()
203
+ avg_epoch_training_loss = epoch_training_loss / len(train_iterator)
204
+
205
+ # Validation
206
+ ## This line below tells the model to 'stop learning'
207
+ self.model.eval()
208
+ epoch_validation_loss = 0
209
+ total_loss = 0
210
+ valid_iterator = tqdm(self.valid_loader, desc=f"Validation Epoch {epoch + 1}/{num_epochs}")
211
+ with torch.no_grad():
212
+ for batch in valid_iterator:
213
+ inputs = batch['input_ids'].squeeze(1).to(self.device)
214
+ targets = batch['labels'].squeeze(1).to(self.device)
215
+ outputs = self.model(input_ids=inputs, labels=targets)
216
+ loss = outputs.loss
217
+ total_loss += loss
218
+ valid_iterator.set_postfix({'Validation Loss': loss.item()})
219
+ epoch_validation_loss += loss.item()
220
+
221
+ avg_epoch_validation_loss = epoch_validation_loss / len(self.valid_loader)
222
+
223
+ end_time = time.time() # End the timer for the epoch
224
+ epoch_duration_sec = end_time - start_time # Calculate the duration in seconds
225
+
226
+ new_row = {'transformer': self.model_name,
227
+ 'batch_size': self.batch_size,
228
+ 'gpu': self.gpu,
229
+ 'epoch': epoch + 1,
230
+ 'training_loss': avg_epoch_training_loss,
231
+ 'validation_loss': avg_epoch_validation_loss,
232
+ 'epoch_duration_sec': epoch_duration_sec} # Add epoch_duration to the dataframe
233
+
234
+ self.results.loc[len(self.results)] = new_row
235
+ print(f"Epoch: {epoch + 1}, Validation Loss: {total_loss / len(self.valid_loader)}")
236
+
237
+ def generate_text(self, input_str, top_k=16, top_p=0.95, temperature=1.0, repetition_penalty=1.2):
238
+ # Encode input string into tensors via the FlanT5 tokenizer
239
+ input_ids = self.tokenizer.encode(input_str, return_tensors='pt', max_length=self.max_length, truncation=True).to(self.device)
240
+ # Run tensors through model to get output tensor values
241
+ output_ids = self.model.generate(input_ids,
242
+ max_length=self.max_length,
243
+ do_sample=True,
244
+ top_k=top_k,
245
+ top_p=top_p,
246
+ temperature=temperature,
247
+ repetition_penalty=repetition_penalty)
248
+ # Decode output tensors to text vi
249
+ output_str = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
250
+ return output_str
genai_SDK/Utilities.py ADDED
@@ -0,0 +1,563 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+ import pandas as pd
4
+ from torch.utils.data import Dataset
5
+
6
+ class LanguageDataset(Dataset):
7
+ def __init__(self, df, tokenizer):
8
+ # Make sure data is compatible
9
+ if len(df.columns) !=2:
10
+ raise Exception("Dataset can only have two columns!")
11
+
12
+ self.data = df.to_dict(orient='records')
13
+ self.tokenizer = tokenizer
14
+
15
+ # set the length of smallest square needed
16
+ self.max_length = smallest_square_length(df)
17
+ self.labels = df.columns
18
+
19
+ def __len__(self):
20
+ return len(self.data)
21
+ def __getitem__(self, i):
22
+ X = self.data[i][self.labels[0]]
23
+ Y = self.data[i][self.labels[1]]
24
+ if str(type(self.tokenizer)) == "<class 'transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer'>":
25
+ return self.tokenizer.encode_plus(X + ' | ' + Y,
26
+ return_tensors='pt',
27
+ max_length = self.max_length,
28
+ padding='max_length',
29
+ truncation=True)
30
+ elif str(type(self.tokenizer)) == "<class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>":
31
+ input_tokens = self.tokenizer.encode_plus(
32
+ X,
33
+ max_length=self.max_length,
34
+ padding='max_length',
35
+ truncation=True,
36
+ return_tensors='pt'
37
+ )
38
+ target_tokens = self.tokenizer.encode_plus(
39
+ Y,
40
+ max_length=self.max_length,
41
+ padding='max_length',
42
+ truncation=True,
43
+ return_tensors='pt'
44
+ )
45
+ return {
46
+ 'input_ids': input_tokens['input_ids'].squeeze(),
47
+ # 'attention_mask': input_tokens['attention_mask'].squeeze(),
48
+ 'labels': target_tokens['input_ids'].squeeze()
49
+ }
50
+
51
+
52
+
53
+ def smallest_square_length(df):
54
+ col1 = df[df.columns[0]].astype(str).apply(lambda x: len(x)).max()
55
+ col2 = df[df.columns[1]].astype(str).apply(lambda x: len(x)).max()
56
+
57
+ max_length = max(col1, col2)
58
+
59
+ x = 2
60
+ while x < max_length:
61
+ x = x * 2
62
+
63
+ return x
64
+
65
+ def levenshtein_distance(str1, str2):
66
+ """
67
+ Computes the Levenshtein distance between two strings.
68
+ Parameters:
69
+ str1 (str): The first string.
70
+ str2 (str): The second string.
71
+ Returns:
72
+ int: The Levenshtein distance between the two strings.
73
+ """
74
+ m, n = len(str1), len(str2)
75
+ dp = [[0] * (n + 1) for _ in range(m + 1)]
76
+
77
+ for i in range(m + 1):
78
+ dp[i][0] = i
79
+
80
+ for j in range(n + 1):
81
+ dp[0][j] = j
82
+
83
+ for i in range(1, m + 1):
84
+ for j in range(1, n + 1):
85
+ if str1[i - 1] == str2[j - 1]:
86
+ dp[i][j] = dp[i - 1][j - 1]
87
+ else:
88
+ dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1])
89
+
90
+ return dp[m][n]
91
+
92
+ def grid_search(model, tokenizer, input_str, topK_values, topP_values, temperature_values, repetition_penalty_values, expected_output):
93
+ """
94
+ Conducts a grid search over specified hyperparameters to find the best text generation settings (GPT series).
95
+ Parameters:
96
+ - model: The pre-trained model used for text generation.
97
+ - tokenizer: The tokenizer associated with the model.
98
+ - input_str: The input string to the model for text generation.
99
+ - topK_values: A list of integer values for the topK sampling hyperparameter.
100
+ - topP_values: A list of float values for the topP (nucleus) sampling hyperparameter.
101
+ - temperature_values: A list of float values for the temperature setting of the model.
102
+ - repetition_penalty_values: A list of float values for penalizing repetitions in the generated text.
103
+ - expected_output: The expected output string against which generated texts are evaluated using the Levenshtein distance.
104
+ Returns:
105
+ - results: A pandas DataFrame containing the combination of hyperparameters, the generated output for each combination, and its Levenshtein distance from the expected output.
106
+ Notes:
107
+ - The function prints out the best hyperparameters found during the search, based on the smallest Levenshtein distance.
108
+ - Levenshtein distance measures the number of edits required to transform one string into another.
109
+ """
110
+ results = pd.DataFrame(columns=['topK', 'topP', 'temperature', 'repetition_penalty', 'generated_output', 'levenshtein_distance'])
111
+ min_distance = 9999999
112
+ for topK in topK_values:
113
+ for topP in topP_values:
114
+ for temperature in temperature_values:
115
+ for repetition_penalty in repetition_penalty_values:
116
+ # try:
117
+ generated_output = model.generate_text(input_str, topK, topP, temperature, repetition_penalty)
118
+ # print(generated_output)
119
+ distance = levenshtein_distance(generated_output, expected_output)
120
+ if distance < min_distance:
121
+ print(f'topK={topK}, topP={topP}, temperature={temperature}, repetition_penalty={repetition_penalty}, levenshtein_distance={distance}')
122
+ min_distance = distance
123
+
124
+ new_row = {'topK': topK,
125
+ 'topP': topP,
126
+ 'temperature': temperature,
127
+ 'repetition_penalty': repetition_penalty,
128
+ 'generated_output': generated_output,
129
+ 'levenshtein_distance': distance
130
+ }
131
+ results.loc[len(results)] = new_row
132
+
133
+ return results.sort_values(by='levenshtein_distance', ascending=True)
134
+
135
+
136
+ def to_coreml(gpt_model, path=''):
137
+ import torch
138
+
139
+ device = torch.device('mps')
140
+
141
+ if torch.cuda.is_available():
142
+ device = torch.device('cuda')
143
+ else:
144
+ try:
145
+ device = torch.device('mps') # Apple Silicon
146
+ except Exception:
147
+ device = torch.device('cpu')
148
+ if path != '': lm_head_model = torch.load(path, map_location=device)
149
+ else: lm_head_model = gpt_model.model
150
+
151
+ """
152
+ Recreate the Core ML model from scratch using
153
+ coremltools' neural_network.NeuralNetworkBuilder
154
+ """
155
+ import coremltools
156
+ import coremltools.models.datatypes as datatypes
157
+ from coremltools.models import neural_network as neural_network
158
+ from coremltools.models.utils import save_spec
159
+ import numpy as np
160
+ import torch
161
+ model_name = 'model'
162
+
163
+ model = lm_head_model.transformer
164
+
165
+ wte = model.wte.weight.data.cpu().numpy().transpose() # shape (768, 50257) /!\ i hate this
166
+ wpe = model.wpe.weight.data.cpu().numpy().transpose() # shape (768, 1024)
167
+
168
+ sequence_length = 128
169
+ steps = model.config.n_layer
170
+
171
+ # build model
172
+ input_features = [
173
+ ('input_ids', datatypes.Array(sequence_length)),
174
+ ('position_ids', datatypes.Array(sequence_length)),
175
+ ]
176
+ output_features = [('output_logits', None)]
177
+
178
+ builder = neural_network.NeuralNetworkBuilder(
179
+ input_features,
180
+ output_features,
181
+ mode=None,
182
+ disable_rank5_shape_mapping=True,
183
+ )
184
+ builder.add_expand_dims(
185
+ name='input_ids_expanded_to_rank5',
186
+ input_name='input_ids',
187
+ output_name='input_ids_expanded_to_rank5',
188
+ axes=(1, 2, 3, 4)
189
+ )
190
+ builder.add_expand_dims(
191
+ name='position_ids_expanded_to_rank5',
192
+ input_name='position_ids',
193
+ output_name='position_ids_expanded_to_rank5',
194
+ axes=(1, 2, 3, 4)
195
+ )
196
+ builder.add_embedding(
197
+ name='token_embeddings',
198
+ input_name='input_ids_expanded_to_rank5',
199
+ output_name='token_embeddings',
200
+ W=wte,
201
+ b=None,
202
+ input_dim=50257,
203
+ output_channels=768,
204
+ has_bias=False,
205
+ )
206
+ builder.add_embedding(
207
+ name='positional_embeddings',
208
+ input_name='position_ids_expanded_to_rank5',
209
+ output_name='positional_embeddings',
210
+ W=wpe,
211
+ b=None,
212
+ input_dim=1024,
213
+ output_channels=768,
214
+ has_bias=False,
215
+ )
216
+
217
+ # Input:, Output: (seq, 1, 768, 1, 1)
218
+ builder.add_add_broadcastable(
219
+ name='embeddings_addition',
220
+ input_names=['token_embeddings', 'positional_embeddings'],
221
+ output_name=f'{0}_previous_block'
222
+ )
223
+
224
+ for i in range(steps):
225
+ print(i)
226
+ ln_weight = model.h[i].ln_1.weight.data.cpu().numpy().reshape((1, 1, 768, 1, 1))
227
+ ln_bias = model.h[i].ln_1.bias.data.cpu().numpy().reshape((1, 1, 768, 1, 1))
228
+ ln_epsilon = model.h[i].ln_1.eps
229
+
230
+ builder.add_mvn(
231
+ name=f"{i}_block_ln_1",
232
+ input_name=f"{i}_previous_block",
233
+ # output_name=f"{i}_block_ln_1_output",
234
+ output_name=f"{i}_block_ln_1",
235
+ across_channels=True,
236
+ normalize_variance=True,
237
+ epsilon=ln_epsilon
238
+ )
239
+
240
+ builder.add_scale(
241
+ name=f"{i}_block_ln_1_scaled",
242
+ input_name=f"{i}_block_ln_1",
243
+ output_name=f"{i}_block_ln_1_scaled",
244
+ W=ln_weight,
245
+ b=ln_bias,
246
+ has_bias=True,
247
+ shape_scale=[768],
248
+ shape_bias=[768]
249
+ )
250
+
251
+ builder.add_transpose(
252
+ name=f"{i}_block_ln_1_reshape",
253
+ input_name=f"{i}_block_ln_1_scaled",
254
+ output_name=f"{i}_block_ln_1_scaled_transposed",
255
+ axes=(1, 0, 2, 3, 4)
256
+ )
257
+
258
+
259
+ conv_1D_bias = model.h[i].attn.c_attn.bias.data.cpu().numpy().reshape((1, 1, 2304, 1, 1))
260
+ conv_1D_weights = model.h[i].attn.c_attn.weight.cpu().data.numpy().transpose().reshape((1, 768, 2304, 1, 1))
261
+
262
+ builder.add_inner_product(
263
+ name=f"{i}_block_attn_conv",
264
+ input_name=f"{i}_block_ln_1_scaled_transposed",
265
+ output_name=f"{i}_block_attn_conv",
266
+ input_channels=768,
267
+ output_channels=2304,
268
+ W=conv_1D_weights,
269
+ b=conv_1D_bias,
270
+ has_bias=True
271
+ )
272
+
273
+ builder.add_split(
274
+ name=f"{i}_block_attn_qkv_split",
275
+ input_name=f"{i}_block_attn_conv",
276
+ output_names=[f"{i}_block_attn_q", f"{i}_block_attn_k", f"{i}_block_attn_v"]
277
+ )
278
+
279
+ builder.add_rank_preserving_reshape(
280
+ name=f"{i}_block_attn_q_reshape",
281
+ input_name=f"{i}_block_attn_q",
282
+ output_name=f"{i}_block_attn_q_reshape",
283
+ output_shape=(1, 1, sequence_length, 12, 64)
284
+ )
285
+
286
+ builder.add_transpose(
287
+ name=f"{i}_block_attn_q_reshape_permuted",
288
+ input_name=f"{i}_block_attn_q_reshape",
289
+ output_name=f"{i}_block_attn_q_reshape_permuted",
290
+ axes=(0, 1, 3, 2, 4)
291
+ )
292
+
293
+ builder.add_rank_preserving_reshape(
294
+ name=f"{i}_block_attn_k_reshape",
295
+ input_name=f"{i}_block_attn_k",
296
+ output_name=f"{i}_block_attn_k_reshape",
297
+ output_shape=(1, 1, sequence_length, 12, 64)
298
+ )
299
+
300
+ builder.add_transpose(
301
+ name=f"{i}_block_attn_k_reshape_permuted",
302
+ input_name=f"{i}_block_attn_k_reshape",
303
+ output_name=f"{i}_block_attn_k_reshape_permuted",
304
+ axes=(0, 1, 3, 4, 2)
305
+ )
306
+
307
+ builder.add_rank_preserving_reshape(
308
+ name=f"{i}_block_attn_v_reshape",
309
+ input_name=f"{i}_block_attn_v",
310
+ output_name=f"{i}_block_attn_v_reshape",
311
+ output_shape=(1, 1, sequence_length, 12, 64)
312
+ )
313
+
314
+ builder.add_transpose(
315
+ name=f"{i}_block_attn_v_reshape_permuted",
316
+ input_name=f"{i}_block_attn_v_reshape",
317
+ output_name=f"{i}_block_attn_v_reshape_permuted",
318
+ axes=(0, 1, 3, 2, 4)
319
+ )
320
+
321
+ builder.add_batched_mat_mul(
322
+ name=f"{i}_block_attn_qv_matmul",
323
+ input_names=[f"{i}_block_attn_q_reshape_permuted", f"{i}_block_attn_k_reshape_permuted"],
324
+ output_name=f"{i}_block_attn_qv_matmul"
325
+ )
326
+
327
+ builder.add_scale(
328
+ name=f"{i}_block_attn_qv_matmul_scaled",
329
+ input_name=f"{i}_block_attn_qv_matmul",
330
+ output_name=f"{i}_block_attn_qv_matmul_scaled",
331
+ W=np.array(1/8),
332
+ b=0,
333
+ has_bias=False
334
+ )
335
+
336
+ bias_0 = model.h[i].attn.bias
337
+ nd = ns = sequence_length
338
+ b = (model.h[i].attn.bias[:, :, ns-nd:ns, :ns]).unsqueeze(0)
339
+
340
+ builder.add_scale(
341
+ name=f"{i}_block_attn_bias",
342
+ input_name=f"{i}_block_attn_qv_matmul_scaled",
343
+ output_name=f"{i}_block_attn_bias",
344
+ W=b,
345
+ b=None,
346
+ has_bias=False,
347
+ shape_scale=[1, sequence_length, sequence_length]
348
+ )
349
+
350
+ bias_constant_0 = -1e4 * torch.logical_not(b)
351
+
352
+ builder.add_bias(
353
+ name=f"{i}_block_attn_afterbias",
354
+ input_name=f"{i}_block_attn_bias",
355
+ output_name=f"{i}_block_attn_afterbias",
356
+ # output_name=f"output_logits",
357
+ b=bias_constant_0,
358
+ shape_bias=[1, sequence_length, sequence_length],
359
+ )
360
+
361
+ builder.add_squeeze(
362
+ name=f"{i}_squeezit",
363
+ input_name=f"{i}_block_attn_afterbias",
364
+ output_name=f"{i}_squeezit",
365
+ axes=[0, 1]
366
+ )
367
+
368
+ builder.add_softmax(
369
+ name=f"{i}_block_attn_softmax",
370
+ input_name=f"{i}_squeezit",
371
+ output_name=f"{i}_block_attn_softmax",
372
+ )
373
+
374
+ builder.add_expand_dims(
375
+ name=f"{i}_expandit",
376
+ input_name=f"{i}_block_attn_softmax",
377
+ output_name=f"{i}_expandit",
378
+ axes=[0, 1]
379
+ )
380
+
381
+ builder.add_batched_mat_mul(
382
+ name=f"{i}_block_full_attention",
383
+ input_names=[f"{i}_expandit", f"{i}_block_attn_v_reshape_permuted"],
384
+ output_name=f"{i}_block_full_attention"
385
+ )
386
+
387
+ builder.add_transpose(
388
+ name=f"{i}_block_full_attention_merged_t",
389
+ input_name=f"{i}_block_full_attention",
390
+ output_name=f"{i}_block_full_attention_merged_t",
391
+ axes=[0, 1, 3, 2, 4]
392
+ )
393
+
394
+ builder.add_rank_preserving_reshape(
395
+ name=f"{i}_block_full_attention_merged",
396
+ input_name=f"{i}_block_full_attention_merged_t",
397
+ output_name=f"{i}_block_full_attention_merged",
398
+ output_shape=[1, 1, 1, sequence_length, 768]
399
+ )
400
+
401
+ builder.add_transpose(
402
+ name=f"{i}_block_attn_conv_proj_t",
403
+ input_name=f"{i}_block_full_attention_merged",
404
+ output_name=f"{i}_block_attn_conv_proj_t",
405
+ axes=[0, 3, 4, 1, 2]
406
+ )
407
+
408
+ conv_1D_proj_bias = model.h[i].attn.c_proj.bias.data.cpu().numpy().reshape((1, 1, 768, 1, 1))
409
+ conv_1D_proj_weights = model.h[i].attn.c_proj.weight.data.cpu().numpy().transpose().reshape((1, 768, 768, 1, 1))
410
+
411
+ # Input:, Output: (1, 3, 768, 1, 1)
412
+ builder.add_inner_product(
413
+ name=f"{i}_block_attn_conv_proj",
414
+ input_name=f"{i}_block_attn_conv_proj_t",
415
+ output_name=f"{i}_block_attn_conv_proj",
416
+ input_channels=768,
417
+ output_channels=768,
418
+ W=conv_1D_proj_weights,
419
+ b=conv_1D_proj_bias,
420
+ has_bias=True
421
+ )
422
+
423
+ # Input: (seq, 1, 768, 1, 1), Output: (1, seq, 768, 1, 1)
424
+ builder.add_transpose(
425
+ name=f"{i}_previous_block_t",
426
+ input_name=f'{i}_previous_block',
427
+ output_name=f"{i}_previous_block_t",
428
+ axes=[1, 0, 2, 3, 4]
429
+ )
430
+
431
+ # Input: [(1, seq, 768, 1, 1), (1, seq, 768, 1, 1)], Output: (1, seq, 768, 1, 1)
432
+ builder.add_add_broadcastable(
433
+ name=f"{i}_block_xa_sum",
434
+ input_names=[f"{i}_previous_block_t", f"{i}_block_attn_conv_proj"],
435
+ output_name=f"{i}_block_xa_sum",
436
+ # output_name=f"output_logits"
437
+ )
438
+
439
+ ln_2_weight = model.h[i].ln_2.weight.data.cpu().numpy().reshape((1, 1, 768, 1, 1))
440
+ ln_2_bias = model.h[i].ln_2.bias.data.cpu().numpy().reshape((1, 1, 768, 1, 1))
441
+ ln_2_epsilon = model.h[i].ln_2.eps
442
+
443
+ # Input: (1, seq, 768, 1, 1), Output:
444
+ builder.add_mvn(
445
+ name=f"{i}_block_ln_2",
446
+ input_name=f"{i}_block_xa_sum",
447
+ output_name=f"{i}_block_ln_2",
448
+ across_channels=True,
449
+ normalize_variance=True,
450
+ epsilon=ln_2_epsilon
451
+ )
452
+
453
+ builder.add_scale(
454
+ name=f"{i}_block_ln_2_scaled",
455
+ input_name=f"{i}_block_ln_2",
456
+ # output_name=f"output_logits",
457
+ output_name=f"{i}_block_ln_2_scaled",
458
+ W=ln_2_weight,
459
+ b=ln_2_bias,
460
+ has_bias=True,
461
+ shape_scale=[768],
462
+ shape_bias=[768]
463
+ )
464
+
465
+ mlp_conv_1D_fc_bias = model.h[i].mlp.c_fc.bias.data.cpu().numpy().reshape((1, 1, 3072, 1, 1))
466
+ mlp_conv_1D_fc_weights = model.h[i].mlp.c_fc.weight.data.cpu().numpy().transpose().reshape((1, 768, 3072, 1, 1))
467
+
468
+ # Input:, Output: (1, 3, 3072, 1, 1)
469
+ builder.add_inner_product(
470
+ name=f"{i}_block_mlp_conv_fc",
471
+ input_name=f"{i}_block_ln_2_scaled",
472
+ output_name=f"{i}_block_mlp_conv_fc",
473
+ # output_name=f"output_logits",
474
+ input_channels=768,
475
+ output_channels=3072,
476
+ W=mlp_conv_1D_fc_weights,
477
+ b=mlp_conv_1D_fc_bias,
478
+ has_bias=True
479
+ )
480
+
481
+ builder.add_gelu(
482
+ name=f"{i}_block_mlp_gelu",
483
+ input_name=f"{i}_block_mlp_conv_fc",
484
+ output_name=f"{i}_block_mlp_gelu",
485
+ # output_name=f"output_logits",
486
+ mode='TANH_APPROXIMATION'
487
+ )
488
+
489
+ mlp_conv_1D_proj_bias = model.h[i].mlp.c_proj.bias.data.cpu().numpy().reshape((1, 1, 768, 1, 1))
490
+ mlp_conv_1D_proj_weights = model.h[i].mlp.c_proj.weight.data.cpu().numpy().transpose().reshape((1, 3072, 768, 1, 1))
491
+
492
+ # Input:, Output: (1, 3, 3072, 1, 1)
493
+ builder.add_inner_product(
494
+ name=f"{i}_block_mlp_conv_proj",
495
+ input_name=f"{i}_block_mlp_gelu",
496
+ output_name=f"{i}_block_mlp_conv_proj",
497
+ # output_name=f"output_logits",
498
+ input_channels=3072,
499
+ output_channels=768,
500
+ W=mlp_conv_1D_proj_weights,
501
+ b=mlp_conv_1D_proj_bias,
502
+ has_bias=True
503
+ )
504
+
505
+ builder.add_add_broadcastable(
506
+ name=f"{i}_block_xm_sum",
507
+ input_names=[f"{i}_block_xa_sum", f"{i}_block_mlp_conv_proj"],
508
+ # output_name=f"output_logits"
509
+ output_name=f"{i + 1}_previous_block_final"
510
+ )
511
+
512
+ builder.add_transpose(
513
+ name=f"{i}_block_xm_sum_t",
514
+ input_name=f"{i + 1}_previous_block_final",
515
+ output_name=f"{i + 1}_previous_block",
516
+ axes=[1, 0, 2, 3, 4]
517
+ )
518
+
519
+
520
+ ln_f_weight = model.ln_f.weight.data.cpu().numpy().reshape((1, 1, 768, 1, 1))
521
+ ln_f_bias = model.ln_f.bias.data.cpu().numpy().reshape((1, 1, 768, 1, 1))
522
+ ln_f_epsilon = model.ln_f.eps
523
+
524
+ # Input: (1, seq, 768, 1, 1), Output:
525
+ builder.add_mvn(
526
+ name=f"ln_f",
527
+ input_name=f"{steps}_previous_block_final",
528
+ output_name=f"ln_f",
529
+ # output_name=f"output_logits",
530
+ across_channels=True,
531
+ normalize_variance=True,
532
+ epsilon=ln_f_epsilon
533
+ )
534
+
535
+ builder.add_scale(
536
+ name=f"ln_f_scaled",
537
+ input_name=f"ln_f",
538
+ output_name=f"ln_f_scaled",
539
+ # output_name=f"output_logits",
540
+ W=ln_f_weight,
541
+ b=ln_f_bias,
542
+ has_bias=True,
543
+ shape_scale=[768],
544
+ shape_bias=[768]
545
+ )
546
+
547
+ lm_head_weights = lm_head_model.lm_head.weight.data.cpu().numpy().reshape((1, 50257, 768, 1, 1))
548
+
549
+ builder.add_inner_product(
550
+ name="lm_head",
551
+ input_name="ln_f_scaled",
552
+ output_name="output_logits",
553
+ input_channels=768,
554
+ output_channels=50257,
555
+ W=lm_head_weights,
556
+ b=None,
557
+ has_bias=False
558
+ )
559
+
560
+ # compile spec to model
561
+ mlmodel = coremltools.models.MLModel(builder.spec)
562
+
563
+ save_spec(builder.spec, f'{model_name}-{sequence_length}-{steps}.mlmodel')
genai_SDK/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .Utilities import *
2
+ from .Seq2Seq import *
genai_SDK/__pycache__/Seq2Seq.cpython-312.pyc ADDED
Binary file (13.4 kB). View file
 
genai_SDK/__pycache__/Utilities.cpython-312.pyc ADDED
Binary file (21.4 kB). View file
 
genai_SDK/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (241 Bytes). View file
 
main.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset, DatasetDict, Dataset
2
+ import pandas as pd
3
+ import ast
4
+
5
+ import genai_SDK
6
+ from genai_SDK.Seq2Seq import GPT2
7
+
8
+ # Load dataset from huggingface
9
+ dataset = load_dataset("erwanlc/cocktails_recipe_no_brand")
10
+
11
+ # Convert to a pandas dataframe
12
+ data = [{'title': item['title'], 'raw_ingredients': item['raw_ingredients']} for item in dataset['train']]
13
+ df = pd.DataFrame(data)
14
+
15
+ # Just extract the ingredient names, nothing else
16
+ df.raw_ingredients = df.raw_ingredients.apply(lambda x: ', '.join([y[1] for y in ast.literal_eval(x)]))
17
+ #display(df.head())
18
+
19
+ model = GPT2(gpu=0, model_name="distilgpt2")
20
+ model.load_data(df=df, batch_size=8)
21
+
22
+ model.train(num_epochs=2)
23
+
24
+ print(model.generate_text("Annual Planning"))
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "hackweek2024-sup-genai-tools"
3
+ version = "0.1.0"
4
+ description = ""
5
+ authors = ["Your Name <you@example.com>"]
6
+ readme = "README.md"
7
+
8
+ [tool.poetry.dependencies]
9
+ python = "^3.12"
10
+ torch = "^2.3.1"
11
+ torchtext = "^0.18.0"
12
+ transformers = "^4.41.2"
13
+ sentencepiece = "^0.2.0"
14
+ pandas = "^2.2.2"
15
+ tqdm = "^4.66.4"
16
+ datasets = "^2.20.0"
17
+ scikit-learn = "^1.5.0"
18
+ accelerate = "^0.31.0"
19
+ fastapi = "^0.111.0"
20
+ uvicorn = "^0.30.1"
21
+
22
+
23
+ [build-system]
24
+ requires = ["poetry-core"]
25
+ build-backend = "poetry.core.masonry.api"
requirement.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ torch
4
+ torchtext
5
+ transformers
6
+ sentencepiece
7
+ pandas
8
+ tqdm
9
+ datasets
10
+ scikit
11
+ accelerate
restrictedItems/parse.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv, json
2
+
3
+ with open('csv/block_items.csv', newline='') as csvfile:
4
+ reader = csv.reader(csvfile, delimiter=',', quotechar='"')
5
+ for i, row in enumerate(reader):
6
+ if i == 0:
7
+ continue
8
+ items = json.loads(json.loads(row[3]))
9
+
10
+ for item in items:
11
+ print(item['name'])
restrictedItems/predict.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertTokenizer, BertForSequenceClassification
2
+ import torch
3
+
4
+ # Load the trained model and tokenizer
5
+ model = BertForSequenceClassification.from_pretrained("/Users/slei/hackweek2024-sup-genai-tools/hackweek2024-sup-genai-tools/trained_model")
6
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
7
+
8
+ # Function to predict the class of a single input text
9
+ def predict(text):
10
+ # Preprocess the input text
11
+ inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
12
+
13
+ # Make predictions
14
+ with torch.no_grad():
15
+ outputs = model(**inputs)
16
+
17
+ # Get the predicted class
18
+ logits = outputs.logits
19
+ predicted_class = torch.argmax(logits, dim=1).item()
20
+
21
+ return predicted_class
22
+
23
+ label_map = {0: 'Allowed Item', 1: 'Restricted Item'}
24
+
25
+ def main():
26
+ while True:
27
+ # Prompting the user for input
28
+ user_input = input("Enter something: ")
29
+
30
+ predicted_class = predict(user_input)
31
+
32
+ # Map the predicted class to a human-readable label
33
+ predicted_label = label_map[predicted_class]
34
+
35
+ # Displaying the user input
36
+ print(f'The item "{user_input}" is classified as: "{predicted_label}"')
37
+
38
+ if __name__ == "__main__":
39
+ main()
restrictedItems/train.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
2
+ from sklearn.model_selection import train_test_split
3
+ import torch
4
+
5
+ # Load dataset
6
+ # Assume dataset is a list of tuples (text, label)
7
+ # dataset = [("item description 1", 0), ("item description 2", 1), ...]
8
+
9
+ restricted_dataset = [
10
+ ("Promotional Email GiftCard $10", 1),
11
+ ("$100 Vanilla® Visa", 1),
12
+ ("Promotional Email GiftCard $10", 1),
13
+ ("GPS7000 GPS Tracker for", 1),
14
+ ("$50 Vanilla Visa eGift", 1),
15
+ ("GPS7000 GPS Tracker for", 1),
16
+ ("$200 Vanilla Visa Shiny", 1),
17
+ ("$200 Vanilla Visa Shiny", 1),
18
+ ("$200 Vanilla Visa Shiny", 1),
19
+ ("$200 Vanilla® Visa", 1),
20
+ ("Promotional Email GiftCard $20", 1),
21
+ ("Xbox $100 Gift Card", 1),
22
+ ("$25 Vanilla Visa Shiny", 1),
23
+ ("$25 Vanilla Visa Shiny", 1),
24
+ ("$25 Vanilla Visa Shiny", 1),
25
+ ("Xbox $100 Gift Card", 1),
26
+ ("Home Depot Birthday Cupcake", 1),
27
+ ("$100 Vanilla® Visa", 1),
28
+ ("Birthday Celebration Target Giftcard", 1),
29
+ ("Birthday Celebration Target Giftcard", 1),
30
+ ("Promotional Email GiftCard $10", 1),
31
+ ("Birthday Celebration Target Giftcard", 1),
32
+ ("Nintendo Switch Family Online", 1),
33
+ ("Promotional Email GiftCard $10", 1),
34
+ ("$50 Vanilla® Mastercard", 1),
35
+ ("$50 Vanilla® Mastercard", 1),
36
+ ("$50 Vanilla® Mastercard", 1),
37
+ ("$50 Vanilla® Visa", 1),
38
+ ("$50 Vanilla Visa Shiny", 1),
39
+ ("Delta Airlines Wedding $250", 1)
40
+ ]
41
+
42
+
43
+ normal_dataset =[
44
+ ("Kerrygold Grass-Fed Pure Irish Garlic & Herb Butter Stick, 3", 0),
45
+ ("bettergoods Garlic, Parmesan, & Basil Butter, 3 oz", 0),
46
+ ("Birds Eye Savory Herb Riced Cauliflower, 10 oz (Frozen)", 0),
47
+ ("Great Value Root Blend, Beets, Carrots, Parsnips and Sweet Potatoes", 0),
48
+ ("Fresh Blueberries, 18 oz Container", 0),
49
+ ("Mixpresso 3 Piece Black Canisters Sets For The Kitchen, Kitchen Jars With", 0),
50
+ ("Freshness Guaranteed Chicken Breast Tenderloins, 2.25 - 3.2", 0),
51
+ ("Kiolbassa Smoked Meats Beef Hickory Smoked Sausage, 4 links - 13oz", 0),
52
+ ("Hot Pockets Frozen Snacks, Pepperoni Pizza Buttery Crust, 5 Sandwiches", 0),
53
+ ("Kool Aid Jammers Tropical Punch Kids Drink 0% Juice Box Pouches, 10", 0),
54
+ ("Frito-Lay Flavor Mix Variety Pack Snack Chips, 1oz Bags, 18 Count", 0),
55
+ ("State Fair Classic Corn Dogs, 42.7 oz, 16 Count", 0),
56
+ ("ASURION 2 Year Sporting Goods Protection Plan ($175 - $199.99)", 0),
57
+ ("6% Incline Walking Pad Treadmill 320+ lb Capacity, Under The Desk", 0),
58
+ ("Renpure Biotin & Collagen Thickening Conditioner for All Hair Types, 32 fl", 0),
59
+ ("Renpure Biotin & Collagen Thickening Hair Shampoo for All Hair Types, 32", 0),
60
+ ("eos Shea Better Body Lotion for Dry Skin, Vanilla Cashmere, 16 fl", 0),
61
+ ("Degree Ultra Clear Long Lasting Men's Antiperspirant Deodorant Dry Spray,", 0),
62
+ ("Tide PODS Liquid Laundry Detergent, Original Scent, HE Compatible, 42 Count", 0),
63
+ ("DEER PARK Brand 100% Natural Spring Water, 16.9-ounce", 0),
64
+ ("Great Value Milk Whole Vitamin D Gallon Plastic Jug", 0),
65
+ ("Jumbo Russet Potatoes Whole Fresh, 8 lb Bag", 0),
66
+ ("Great Value Butter Pecan Flavored Ice Cream, 16 fl oz", 0),
67
+ ("Beef Lean Stew Meat, 1.0 - 1.5 lb Tray", 0),
68
+ ("Great Value Spaghetti 16oz", 0),
69
+ ("Great Value Flavored with Meat Pasta Sauce, 24 oz", 0),
70
+ ("Kentucky Kernel Original Seasoned Flour, Coating Mix for Frying, Value Size", 0)
71
+ ]
72
+
73
+ dataset = restricted_dataset + normal_dataset
74
+
75
+ # Split dataset
76
+ train_texts, val_texts, train_labels, val_labels = train_test_split([item[0] for item in dataset], [item[1] for item in dataset], test_size=0.2)
77
+ import pdb; pdb.set_trace()
78
+
79
+ # Load pre-trained BERT tokenizer
80
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
81
+
82
+ # Tokenize data
83
+ train_encodings = tokenizer(train_texts, truncation=True, padding=True)
84
+ val_encodings = tokenizer(val_texts, truncation=True, padding=True)
85
+
86
+ # Convert to torch Dataset
87
+ class ShoppingCartDataset(torch.utils.data.Dataset):
88
+ def __init__(self, encodings, labels):
89
+ self.encodings = encodings
90
+ self.labels = labels
91
+
92
+ def __getitem__(self, idx):
93
+ item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
94
+ item['labels'] = torch.tensor(self.labels[idx])
95
+ return item
96
+
97
+ def __len__(self):
98
+ return len(self.labels)
99
+
100
+ train_dataset = ShoppingCartDataset(train_encodings, train_labels)
101
+ val_dataset = ShoppingCartDataset(val_encodings, val_labels)
102
+
103
+ # Load pre-trained BERT model
104
+ model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
105
+
106
+ # Training arguments
107
+ training_args = TrainingArguments(
108
+ output_dir='../results',
109
+ num_train_epochs=3,
110
+ per_device_train_batch_size=16,
111
+ per_device_eval_batch_size=16,
112
+ warmup_steps=500,
113
+ weight_decay=0.01,
114
+ logging_dir='./logs',
115
+ logging_steps=10,
116
+ )
117
+
118
+ # Trainer
119
+ trainer = Trainer(
120
+ model=model,
121
+ args=training_args,
122
+ train_dataset=train_dataset,
123
+ eval_dataset=val_dataset,
124
+ )
125
+
126
+ # Train model
127
+ trainer.train()
128
+
129
+ # Evaluate model
130
+ trainer.evaluate()
131
+
132
+ model.save_pretrained('trained_model')