klasocki committed on
Commit
f42ec01
·
1 Parent(s): a5fed35

Migrate to FastAPI from Flask, Docker works

Browse files
.dockerignore CHANGED
@@ -1,3 +1,7 @@
1
  .idea
2
  data/
3
  .pytest_cache
 
 
 
 
 
1
  .idea
2
  data/
3
  .pytest_cache
4
+ .gitignore
5
+ README.txt
6
+ openapi.yaml
7
+
Dockerfile CHANGED
@@ -5,9 +5,10 @@ WORKDIR /comma-fixer
5
  COPY requirements.txt .
6
  RUN pip install -r requirements.txt
7
 
8
- COPY . .
 
9
 
10
- COPY ~/.cache/huggingface/hub/models--oliverguhr--fullstop-punctuation-multilang-large/ ~/.cache/huggingface/hub/models--oliverguhr--fullstop-punctuation-multilang-large/
11
 
12
  EXPOSE 8000
13
- #CMD gunicorn "app:app"
 
5
  COPY requirements.txt .
6
  RUN pip install -r requirements.txt
7
 
8
+ COPY src/baseline.py src/baseline.py
9
+ RUN python src/baseline.py # This pre-downloads models and tokenizers
10
 
11
+ COPY . .
12
 
13
  EXPOSE 8000
14
+ CMD uvicorn "app:app" --port 8000 --host "0.0.0.0"
app.py CHANGED
@@ -1,32 +1,35 @@
1
- from flask import Flask, request, jsonify, make_response
2
- from src.baseline import fix_commas, create_baseline_pipeline
 
3
  import logging
4
 
5
  logger = logging.Logger(__name__)
6
  logging.basicConfig(level=logging.INFO)
7
 
8
- app = Flask(__name__)
9
- logging.info('Loading the baseline model...')
10
- app.baseline_pipeline = create_baseline_pipeline()
11
 
12
 
13
- @app.route('/', methods=['GET'])
14
- def root():
15
  return ("Welcome to the comma fixer. Send a POST request to /fix-commas or /baseline/fix-commas with a string "
16
  "'s' in the JSON body to try "
17
  "out the functionality.")
18
 
19
 
20
- @app.route('/baseline/fix-commas/', methods=['POST'])
21
- def fix_commas_with_baseline():
22
  json_field_name = 's'
23
- data = request.get_json()
24
  if json_field_name in data:
25
- return make_response(jsonify({json_field_name: fix_commas(app.baseline_pipeline, data['s'])}), 200)
 
26
  else:
27
- return make_response(f"Parameter '{json_field_name}' missing", 400)
 
 
28
 
29
 
30
  if __name__ == '__main__':
31
- app.run(debug=True)
32
 
 
1
+ import uvicorn
2
+ from fastapi import FastAPI, HTTPException
3
+ from src.baseline import BaselineCommaFixer
4
  import logging
5
 
6
  logger = logging.Logger(__name__)
7
  logging.basicConfig(level=logging.INFO)
8
 
9
+ app = FastAPI() #TODO router?
10
+ logger.info('Loading the baseline model...')
11
+ app.baseline_model = BaselineCommaFixer()
12
 
13
 
14
+ @app.get('/')
15
+ async def root():
16
  return ("Welcome to the comma fixer. Send a POST request to /fix-commas or /baseline/fix-commas with a string "
17
  "'s' in the JSON body to try "
18
  "out the functionality.")
19
 
20
 
21
+ @app.post('/baseline/fix-commas/')
22
+ async def fix_commas_with_baseline(data: dict):
23
  json_field_name = 's'
 
24
  if json_field_name in data:
25
+ logger.debug('Fixing commas.')
26
+ return {json_field_name: app.baseline_model.fix_commas(data['s'])}
27
  else:
28
+ msg = f"Text '{json_field_name}' missing"
29
+ logger.debug(msg)
30
+ raise HTTPException(status_code=400, detail=msg)
31
 
32
 
33
  if __name__ == '__main__':
34
+ uvicorn.run("app:app", reload=True, port=8000)
35
 
docker-compose.yml CHANGED
@@ -1,28 +1,30 @@
 
 
1
  services:
2
- nginx:
3
- image: nginx:latest
4
- container_name: nginx
5
- volumes:
6
- - ./:/comma-fixer
7
- - ./nginx.conf:/etc/nginx/conf.d/default.conf
8
- ports:
9
- - 8001:80
10
- networks:
11
- - my-network
12
- depends_on:
13
- - flask
14
- flask:
15
  build:
16
  context: ./
17
  dockerfile: Dockerfile
18
  container_name: comma-fixer
19
- command: gunicorn --bind 0.0.0.0:8000 "app:app" --timeout 300 #--workers 4
20
  volumes:
21
  - ./:/comma-fixer
22
- networks:
23
- my-network:
24
- aliases:
25
- - flask-app
26
-
27
- networks:
28
- my-network:
 
1
+ version: '3.1'
2
+
3
  services:
4
+ # nginx:
5
+ # image: nginx:latest
6
+ # container_name: nginx
7
+ # volumes:
8
+ # - ./:/comma-fixer
9
+ # - ./nginx.conf:/etc/nginx/conf.d/default.conf
10
+ # ports:
11
+ # - 8001:80
12
+ # networks:
13
+ # - my-network
14
+ # depends_on:
15
+ # - flask
16
+ comma-fixer:
17
  build:
18
  context: ./
19
  dockerfile: Dockerfile
20
  container_name: comma-fixer
21
+ command: uvicorn --host 0.0.0.0 --port 8000 "app:app"
22
  volumes:
23
  - ./:/comma-fixer
24
+ # networks:
25
+ # my-network:
26
+ # aliases:
27
+ # - comma-fixer
28
+ #
29
+ #networks:
30
+ # my-network:
requirements.txt CHANGED
@@ -1,9 +1,11 @@
1
- flask == 2.2.2
2
- gunicorn == 21.2.0
 
3
  pytest
4
- torch == 2.0.1
5
- transformers == 4.31.0
 
6
 
7
  # for the tokenizer of the baseline model
8
- protobuf == 4.24.0
9
  sentencepiece==0.1.99
 
1
+ fastapi==0.101.1
2
+ gunicorn==21.2.0
3
+ uvicorn==0.23.2
4
  pytest
5
+ httpx
6
+ torch==2.0.1
7
+ transformers==4.31.0
8
 
9
  # for the tokenizer of the baseline model
10
+ protobuf==4.24.0
11
  sentencepiece==0.1.99
src/baseline.py CHANGED
@@ -1,19 +1,23 @@
1
  from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, NerPipeline
2
 
3
 
4
- def create_baseline_pipeline(model_name="oliverguhr/fullstop-punctuation-multilang-large") -> NerPipeline:
 
 
 
 
 
 
 
 
 
 
 
5
  tokenizer = AutoTokenizer.from_pretrained(model_name)
6
  model = AutoModelForTokenClassification.from_pretrained(model_name)
7
  return pipeline('ner', model=model, tokenizer=tokenizer)
8
 
9
 
10
- def fix_commas(ner_pipeline: NerPipeline, s: str) -> str:
11
- return _fix_commas_based_on_pipeline_output(
12
- ner_pipeline(_remove_punctuation(s)),
13
- s
14
- )
15
-
16
-
17
  def _remove_punctuation(s: str) -> str:
18
  to_remove = ".,?-:"
19
  for char in to_remove:
@@ -29,7 +33,7 @@ def _fix_commas_based_on_pipeline_output(pipeline_json: list[dict], original_s:
29
  current_offset = _find_current_token(current_offset, i, pipeline_json, result)
30
  if _should_insert_comma(i, pipeline_json):
31
  result = result[:current_offset] + ',' + result[current_offset:]
32
- current_offset += 1
33
  return result
34
 
35
 
@@ -43,3 +47,7 @@ def _find_current_token(current_offset, i, pipeline_json, result, new_word_indic
43
  # Find the current word in the result string, starting looking at current offset
44
  current_offset = result.find(current_word, current_offset) + len(current_word)
45
  return current_offset
 
 
 
 
 
1
  from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, NerPipeline
2
 
3
 
4
+ class BaselineCommaFixer:
5
+ def __init__(self):
6
+ self._ner = _create_baseline_pipeline()
7
+
8
+ def fix_commas(self, s: str) -> str:
9
+ return _fix_commas_based_on_pipeline_output(
10
+ self._ner(_remove_punctuation(s)),
11
+ s
12
+ )
13
+
14
+
15
+ def _create_baseline_pipeline(model_name="oliverguhr/fullstop-punctuation-multilang-large") -> NerPipeline:
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
  model = AutoModelForTokenClassification.from_pretrained(model_name)
18
  return pipeline('ner', model=model, tokenizer=tokenizer)
19
 
20
 
 
 
 
 
 
 
 
21
  def _remove_punctuation(s: str) -> str:
22
  to_remove = ".,?-:"
23
  for char in to_remove:
 
33
  current_offset = _find_current_token(current_offset, i, pipeline_json, result)
34
  if _should_insert_comma(i, pipeline_json):
35
  result = result[:current_offset] + ',' + result[current_offset:]
36
+ current_offset += 1
37
  return result
38
 
39
 
 
47
  # Find the current word in the result string, starting looking at current offset
48
  current_offset = result.find(current_word, current_offset) + len(current_word)
49
  return current_offset
50
+
51
+
52
+ if __name__ == "__main__":
53
+ BaselineCommaFixer() # to pre-download the model and tokenizer
tests/test_baseline.py CHANGED
@@ -1,10 +1,10 @@
1
  import pytest
2
- from baseline import create_baseline_pipeline, fix_commas, _remove_punctuation
3
 
4
 
5
  @pytest.fixture()
6
- def baseline_pipeline():
7
- yield create_baseline_pipeline()
8
 
9
 
10
  @pytest.mark.parametrize(
@@ -14,8 +14,8 @@ def baseline_pipeline():
14
  'This test string should not have any commas inside it.',
15
  'aAaalLL the.. weird?~! punctuation.should also . be kept-as is! Only fixing-commas.']
16
  )
17
- def test_fix_commas_leaves_correct_strings_unchanged(baseline_pipeline, test_input):
18
- result = fix_commas(baseline_pipeline, s=test_input)
19
  assert result == test_input
20
 
21
 
@@ -32,8 +32,8 @@ def test_fix_commas_leaves_correct_strings_unchanged(baseline_pipeline, test_inp
32
  ['I had no Creativity left, therefore, I come here, and write useless examples, for this test.',
33
  'I had no Creativity left therefore, I come here and write useless examples for this test.']]
34
  )
35
- def test_fix_commas_fixes_incorrect_commas(baseline_pipeline, test_input, expected):
36
- result = fix_commas(baseline_pipeline, s=test_input)
37
  assert result == expected
38
 
39
 
 
1
  import pytest
2
+ from baseline import BaselineCommaFixer, _remove_punctuation
3
 
4
 
5
  @pytest.fixture()
6
+ def baseline_fixer():
7
+ yield BaselineCommaFixer()
8
 
9
 
10
  @pytest.mark.parametrize(
 
14
  'This test string should not have any commas inside it.',
15
  'aAaalLL the.. weird?~! punctuation.should also . be kept-as is! Only fixing-commas.']
16
  )
17
+ def test_fix_commas_leaves_correct_strings_unchanged(baseline_fixer, test_input):
18
+ result = baseline_fixer.fix_commas(s=test_input)
19
  assert result == test_input
20
 
21
 
 
32
  ['I had no Creativity left, therefore, I come here, and write useless examples, for this test.',
33
  'I had no Creativity left therefore, I come here and write useless examples for this test.']]
34
  )
35
+ def test_fix_commas_fixes_incorrect_commas(baseline_fixer, test_input, expected):
36
+ result = baseline_fixer.fix_commas(s=test_input)
37
  assert result == expected
38
 
39
 
tests/test_integration.py CHANGED
@@ -1,21 +1,17 @@
1
- from flask import json
2
  import pytest
3
 
4
  from app import app
5
- from baseline import create_baseline_pipeline
6
 
7
 
8
  @pytest.fixture()
9
  def client():
10
- app.config["DEBUG"] = True
11
- app.config["TESTING"] = True
12
- app.baseline_pipeline = create_baseline_pipeline()
13
- yield app.test_client()
14
 
15
 
16
  def test_fix_commas_fails_on_no_parameter(client):
17
  response = client.post('/baseline/fix-commas/')
18
- assert response.status_code == 400
19
 
20
 
21
  def test_fix_commas_fails_on_wrong_parameters(client):
@@ -33,7 +29,7 @@ def test_fix_commas_correct_string_unchanged(client, test_input: str):
33
  response = client.post('/baseline/fix-commas/', json={'s': test_input})
34
 
35
  assert response.status_code == 200
36
- assert response.get_json().get('s') == test_input
37
 
38
 
39
  @pytest.mark.parametrize(
@@ -46,7 +42,7 @@ def test_fix_commas_fixes_wrong_commas(client, test_input: str, expected: str):
46
  response = client.post('/baseline/fix-commas/', json={'s': test_input})
47
 
48
  assert response.status_code == 200
49
- assert response.get_json().get('s') == expected
50
 
51
 
52
  def test_with_a_very_long_string(client):
@@ -54,4 +50,4 @@ def test_with_a_very_long_string(client):
54
  response = client.post('/baseline/fix-commas/', json={'s': s})
55
 
56
  assert response.status_code == 200
57
- assert response.get_json().get('s') == s
 
1
+ from fastapi.testclient import TestClient
2
  import pytest
3
 
4
  from app import app
 
5
 
6
 
7
  @pytest.fixture()
8
  def client():
9
+ yield TestClient(app)
 
 
 
10
 
11
 
12
  def test_fix_commas_fails_on_no_parameter(client):
13
  response = client.post('/baseline/fix-commas/')
14
+ assert response.status_code == 422
15
 
16
 
17
  def test_fix_commas_fails_on_wrong_parameters(client):
 
29
  response = client.post('/baseline/fix-commas/', json={'s': test_input})
30
 
31
  assert response.status_code == 200
32
+ assert response.json().get('s') == test_input
33
 
34
 
35
  @pytest.mark.parametrize(
 
42
  response = client.post('/baseline/fix-commas/', json={'s': test_input})
43
 
44
  assert response.status_code == 200
45
+ assert response.json().get('s') == expected
46
 
47
 
48
  def test_with_a_very_long_string(client):
 
50
  response = client.post('/baseline/fix-commas/', json={'s': s})
51
 
52
  assert response.status_code == 200
53
+ assert response.json().get('s') == s