hqms commited on
Commit
944c27e
·
1 Parent(s): bb2aaa7

initial commit

Browse files
Files changed (3) hide show
  1. Dockerfile +14 -0
  2. app.py +56 -0
  3. requirements.txt +120 -0
Dockerfile ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ FROM python:3.9
3
+
4
+ RUN useradd -m -u 1000 user
5
+ USER user
6
+ ENV PATH="/home/user/.local/bin:$PATH"
7
+
8
+ WORKDIR /app
9
+
10
+ COPY --chown=user ./requirements.txt requirements.txt
11
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
12
+
13
+ COPY --chown=user . /app
14
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ from transformers import BertTokenizer, EncoderDecoderModel
5
+
6
+ tokenizer = BertTokenizer.from_pretrained("cahya/bert2gpt-indonesian-summarization")
7
+ tokenizer.bos_token = tokenizer.cls_token
8
+ tokenizer.eos_token = tokenizer.sep_token
9
+ model = EncoderDecoderModel.from_pretrained("cahya/bert2gpt-indonesian-summarization")
10
+
11
+
12
+ class Generate(BaseModel):
13
+ text: str
14
+
15
+ class Prompt(BaseModel):
16
+ text: str
17
+
18
+ def generate(prompt: str):
19
+ if prompt == '':
20
+ return Generate(text='Prompt not provided')
21
+ else:
22
+ # generate summary
23
+ input_ids = tokenizer.encode(prompt, return_tensors='pt')
24
+ summary_ids = model.generate(input_ids,
25
+ min_length=20,
26
+ max_length=80,
27
+ num_beams=10,
28
+ repetition_penalty=2.5,
29
+ length_penalty=1.0,
30
+ early_stopping=True,
31
+ no_repeat_ngram_size=2,
32
+ use_cache=True,
33
+ do_sample = True,
34
+ temperature = 0.8,
35
+ top_k = 50,
36
+ top_p = 0.95)
37
+
38
+ summary_text = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
39
+ return Generate(text=summary_text)
40
+
41
+ app = FastAPI()
42
+ app.add_middleware(
43
+ CORSMiddleware,
44
+ allow_origins=["*"],
45
+ allow_credentials=True,
46
+ allow_methods=["*"],
47
+ allow_headers=["*"],
48
+ )
49
+
50
+ @app.get('/')
51
+ def home():
52
+ return {'app':'Summarization', 'version': .1}
53
+
54
+ @app.post('/generate', response_model=Generate)
55
+ def inference(prompt: Prompt):
56
+ return generate(prompt.text)
requirements.txt ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ accelerate==0.32.1
3
+ alembic==1.13.2
4
+ aniso8601==9.0.1
5
+ annotated-types==0.7.0
6
+ anyio==4.4.0
7
+ blinker==1.8.2
8
+ cachetools==5.3.3
9
+ certifi==2024.7.4
10
+ charset-normalizer==3.3.2
11
+ chex==0.1.82
12
+ click==8.1.7
13
+ cloudpickle==3.0.0
14
+ contourpy==1.2.1
15
+ cycler==0.12.1
16
+ Deprecated==1.2.14
17
+ dnspython==2.6.1
18
+ docker==7.1.0
19
+ email_validator==2.2.0
20
+ entrypoints==0.4
21
+ etils==1.5.2
22
+ exceptiongroup==1.2.1
23
+ fastapi==0.111.0
24
+ fastapi-cli==0.0.4
25
+ filelock==3.15.4
26
+ Flask==3.0.3
27
+ flax==0.7.0
28
+ fonttools==4.53.1
29
+ fsspec==2024.6.1
30
+ gitdb==4.0.11
31
+ GitPython==3.1.43
32
+ graphene==3.3
33
+ graphql-core==3.2.3
34
+ graphql-relay==3.2.0
35
+ gunicorn==22.0.0
36
+ h11==0.14.0
37
+ httpcore==1.0.5
38
+ httptools==0.6.1
39
+ httpx==0.27.0
40
+ huggingface-hub==0.23.4
41
+ idna==3.7
42
+ importlib_metadata==7.1.0
43
+ importlib_resources==6.4.0
44
+ itsdangerous==2.2.0
45
+ jax==0.4.13
46
+ jaxlib==0.4.13
47
+ Jinja2==3.1.4
48
+ joblib==1.4.2
49
+ kiwisolver==1.4.5
50
+ Mako==1.3.5
51
+ Markdown==3.6
52
+ markdown-it-py==3.0.0
53
+ MarkupSafe==2.1.5
54
+ matplotlib==3.9.1
55
+ mdurl==0.1.2
56
+ ml-dtypes==0.4.0
57
+ mlflow==2.14.2
58
+ mpmath==1.3.0
59
+ msgpack==1.0.8
60
+ nest-asyncio==1.6.0
61
+ networkx==3.2.1
62
+ numpy==1.26.4
63
+ opentelemetry-api==1.25.0
64
+ opentelemetry-sdk==1.25.0
65
+ opentelemetry-semantic-conventions==0.46b0
66
+ opt-einsum==3.3.0
67
+ optax==0.1.4
68
+ orbax-checkpoint==0.5.16
69
+ orjson==3.10.6
70
+ packaging==24.1
71
+ pandas==2.2.2
72
+ pillow==10.4.0
73
+ protobuf==4.25.3
74
+ psutil==6.0.0
75
+ pyarrow==15.0.2
76
+ pydantic==2.8.2
77
+ pydantic_core==2.20.1
78
+ Pygments==2.18.0
79
+ pyparsing==3.1.2
80
+ python-dateutil==2.9.0.post0
81
+ python-dotenv==1.0.1
82
+ python-multipart==0.0.9
83
+ pytz==2024.1
84
+ PyYAML==6.0.1
85
+ querystring-parser==1.2.4
86
+ regex==2024.5.15
87
+ requests==2.32.3
88
+ rich==13.7.1
89
+ safetensors==0.4.3
90
+ scikit-learn==1.5.1
91
+ scipy==1.12.0
92
+ shellingham==1.5.4
93
+ six==1.16.0
94
+ smmap==5.0.1
95
+ sniffio==1.3.1
96
+ SQLAlchemy==2.0.31
97
+ sqlparse==0.5.0
98
+ starlette==0.37.2
99
+ sympy==1.12.1
100
+ tensorstore==0.1.63
101
+ threadpoolctl==3.5.0
102
+ tokenizers==0.19.1
103
+ toolz==0.12.1
104
+ torch==2.3.1
105
+ torchaudio==2.3.1
106
+ torchvision==0.18.1
107
+ tqdm==4.66.4
108
+ transformers==4.42.3
109
+ typer==0.12.3
110
+ typing_extensions==4.12.2
111
+ tzdata==2024.1
112
+ ujson==5.10.0
113
+ urllib3==2.2.2
114
+ uvicorn[standard]
115
+ uvloop==0.19.0
116
+ watchfiles==0.22.0
117
+ websockets==12.0
118
+ Werkzeug==3.0.3
119
+ wrapt==1.16.0
120
+ zipp==3.19.2