from typing import List, Optional, Tuple

import torch

from SciAssist import ReferenceStringParsing

# Select "gpu" when CUDA is available, otherwise fall back to "cpu".
device = "gpu" if torch.cuda.is_available() else "cpu"
rsp_pipeline = ReferenceStringParsing(os_name="nt")


def rsp_for_str(input, dehyphen=False) -> List[Tuple[str, Optional[str]]]:
    """Parse a raw reference string into (token, tag) pairs."""
    results = rsp_pipeline.predict(input, type="str", dehyphen=dehyphen)
    output = []
    for res in results:
        # Flatten each parsed reference into (token, tag) pairs.
        for token, tag in zip(res["tokens"], res["tags"]):
            output.append((token, tag))
        # Untagged blank line to separate consecutive references.
        output.append(("\n\n", None))
    return output


def rsp_for_file(input, dehyphen=False) -> Optional[List[Tuple[str, Optional[str]]]]:
    """Parse reference strings from an uploaded .txt or .pdf file."""
    if input is None:
        return None
    filename = input.name

    # Pick the prediction mode from the file extension.
    if filename.endswith(".txt"):
        results = rsp_pipeline.predict(filename, type="txt", dehyphen=dehyphen, save_results=False)
    elif filename.endswith(".pdf"):
        results = rsp_pipeline.predict(filename, dehyphen=dehyphen, save_results=False)
    else:
        return [("File Format Error!", None)]

    output = []
    for res in results:
        # Flatten each parsed reference into (token, tag) pairs.
        for token, tag in zip(res["tokens"], res["tags"]):
            output.append((token, tag))
        # Untagged blank line to separate consecutive references.
        output.append(("\n\n", None))
    return output
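

# A minimal usage sketch: the citation below is a made-up, illustrative string,
# not a real reference. It exercises the string-level parser defined above and
# prints each token alongside its predicted tag (the blank separators have no tag).
if __name__ == "__main__":
    sample = (
        "Doe, J. and Smith, A. (2020). A sample paper title. "
        "In Proceedings of the Example Conference, pages 1-10."
    )
    for token, tag in rsp_for_str(sample, dehyphen=False):
        print(f"{token}\t{tag}")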