Upload AI_models/FOREcasT/inference.py with huggingface_hub
AI_models/FOREcasT/inference.py
ADDED
import torch
from datasets import load_dataset, Features, Value
from torch.utils.data import DataLoader
from diffusers import DiffusionPipeline
from tqdm import tqdm

from ..config import get_config, get_logger
from .load_data import data_collector

args = get_config(config_file="config_FOREcasT.ini")
logger = get_logger(args)


@torch.no_grad()
def data_collector_inference(examples):
    # Validate every example before delegating to the shared collator: the
    # reference must cover both flanks of the cut site (ref1len upstream,
    # ref2len downstream) and leave FOREcasT_MAX_DEL_SIZE room on each side.
    for example in examples:
        ref, cut = example["ref"], example["cut"]
        assert len(ref) >= args.ref1len and len(ref) >= args.ref2len, f"ref of length {len(ref)} is too short, please decrease ref1len={args.ref1len} and/or ref2len={args.ref2len} in inference arguments"
        assert cut <= args.ref1len and len(ref) - cut <= args.ref2len, f"ref1len={args.ref1len} and/or ref2len={args.ref2len} is too short, please increase them to cover cut site {cut}"
        assert cut >= args.FOREcasT_MAX_DEL_SIZE, f"ref upstream of cut ({cut}) is shorter than FOREcasT_MAX_DEL_SIZE ({args.FOREcasT_MAX_DEL_SIZE}), extend ref upstream"
        assert len(ref) - cut >= args.FOREcasT_MAX_DEL_SIZE, f"ref downstream of cut ({len(ref) - cut}) is shorter than FOREcasT_MAX_DEL_SIZE ({args.FOREcasT_MAX_DEL_SIZE}), extend ref downstream"
    return data_collector(examples, output_count=False)


@torch.no_grad()
def inference(data_name=args.data_name, data_files="inference.json.gz"):
    logger.info("load inference data")
    # Each record holds a reference sequence and the cut-site index within it.
    ds = load_dataset('json', data_files=data_files, features=Features({
        'ref': Value('string'),
        'cut': Value('int16')
    }))["train"]

    inference_dataloader = DataLoader(
        dataset=ds,
        batch_size=args.batch_size,
        collate_fn=data_collector_inference
    )

    logger.info("setup pipeline")
    # Load the custom FOREcasT pipeline from the Hub; trust_remote_code is
    # needed because the pipeline class is defined in the model repository.
    pipe = DiffusionPipeline.from_pretrained(
        f"{args.owner}/{data_name}_FOREcasT",
        trust_remote_code=True,
        custom_pipeline=f"{args.owner}/{data_name}_FOREcasT",
        MAX_DEL_SIZE=args.FOREcasT_MAX_DEL_SIZE
    )
    pipe.FOREcasT_model.to(args.device)

    # Lazily yield one batch of predictions at a time.
    for batch in tqdm(inference_dataloader):
        yield pipe(batch)
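For context, a minimal driver sketch follows. It assumes the repository layout above (so the generator is importable as AI_models.FOREcasT.inference.inference) and that the datasets JSON loader is given a gzip-compressed JSON Lines file, which it decompresses transparently. The record values are hypothetical and must satisfy the assertions in data_collector_inference (sufficient flank length around the cut site for ref1len, ref2len, and FOREcasT_MAX_DEL_SIZE in config_FOREcasT.ini):

import gzip
import json

from AI_models.FOREcasT.inference import inference

# Hypothetical inputs: each record is a reference sequence plus the index
# of the cut site within it; flank lengths around the cut must pass the
# checks in data_collector_inference for the configured arguments.
records = [
    {"ref": "ACGT" * 40, "cut": 80},
    {"ref": "TTGACCGA" * 20, "cut": 75},
]
with gzip.open("inference.json.gz", "wt") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")

# inference() is a generator: the pipeline only runs while iterating.
for batch_output in inference(data_files="inference.json.gz"):
    print(batch_output)  # one pipeline output per batch

Because inference() yields one batch of predictions at a time, memory stays flat for large inputs; a consumer can stream results to disk instead of collecting them in a list.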