henryz2004 commited on
Commit
8ca8e46
1 Parent(s): d44e56d

custom handler

Browse files
.gitignore ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+
5
+ # Distribution / packaging
6
+ build/
7
+ dist/
8
+ *.egg-info/
9
+
10
+ # Environments
11
+ .env
12
+ .venv
13
+ env/
14
+ venv/
15
+
16
+ # PyCharm
17
+ ./.idea/
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.idea/distilbert-base-uncased-emotion-custominf.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="jdk" jdkName="dialignment" jdkType="Python SDK" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="Black">
4
+ <option name="sdkName" value="dialignment" />
5
+ </component>
6
+ <component name="ProjectRootManager" version="2" project-jdk-name="dialignment" project-jdk-type="Python SDK" />
7
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/distilbert-base-uncased-emotion-custominf.iml" filepath="$PROJECT_DIR$/.idea/distilbert-base-uncased-emotion-custominf.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
handler.py CHANGED
@@ -2,29 +2,18 @@ from typing import Dict, List, Any
2
  from transformers import pipeline
3
  import holidays
4
 
5
-
6
- class EndpointHandler:
7
  def __init__(self, path=""):
8
  self.pipeline = pipeline("text-classification", model=path)
9
  self.holidays = holidays.US()
10
 
11
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
12
- """
13
- data args:
14
- inputs (:obj: `str`)
15
- date (:obj: `str`)
16
- Return:
17
- A :obj:`list` | `dict`: will be serialized and returned
18
- """
19
- # get inputs
20
- inputs = data.pop("inputs", data)
21
- # get additional date field
22
- date = data.pop("date", None)
23
 
24
- # check if date exists and if it is a holiday
25
- if date is not None and date in self.holidays:
26
- return [{"label": "happy", "score": 1}]
27
 
28
- # run normal prediction
 
 
29
  prediction = self.pipeline(inputs)
30
- return prediction
 
2
  from transformers import pipeline
3
  import holidays
4
 
5
+ class EndpointHandler():
 
6
  def __init__(self, path=""):
7
  self.pipeline = pipeline("text-classification", model=path)
8
  self.holidays = holidays.US()
9
 
10
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ inputs = data.get("inputs", data)
13
+ date = data.get("date", None)
 
14
 
15
+ if date is not None and date in self.holidays:
16
+ return [{"label":"happy", "score":1}]
17
+
18
  prediction = self.pipeline(inputs)
19
+ return prediction
test_handler.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from handler import EndpointHandler
2
+
3
+ handler = EndpointHandler(".")
4
+
5
+ non_holiday_payload = {"inputs": "I am so sad", "date": "2022-01-04"}
6
+ holiday_payload = {"inputs": "I am so sad", "date": "2022-07-04"}
7
+
8
+ non_holiday_response = handler(non_holiday_payload)
9
+ holiday_response = handler(holiday_payload)
10
+
11
+ print("non_holiday_response:", non_holiday_response)
12
+ print("holiday_response:", holiday_response)