ljw20180420 committed (verified)
Commit 32ba7fe · 1 Parent(s): 7227e45

Upload AI_models/FOREcasT/inference.py with huggingface_hub
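The commit message indicates the file was pushed with the huggingface_hub client rather than a plain git push. A minimal sketch of that kind of upload follows; the repo_id is a hypothetical placeholder, and the call assumes an access token has already been configured via `huggingface-cli login`.

from huggingface_hub import HfApi

api = HfApi()  # picks up the saved access token by default
api.upload_file(
    path_or_fileobj="AI_models/FOREcasT/inference.py",  # local file to push
    path_in_repo="AI_models/FOREcasT/inference.py",     # destination path inside the repo
    repo_id="ljw20180420/FOREcasT",                     # hypothetical repo id; substitute the real one
    commit_message="Upload AI_models/FOREcasT/inference.py with huggingface_hub",
)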

Files changed (1)
  1. AI_models/FOREcasT/inference.py +41 -0
AI_models/FOREcasT/inference.py ADDED
@@ -0,0 +1,41 @@
+ import torch
+ from datasets import load_dataset, Features, Value
+ from torch.utils.data import DataLoader
+ from diffusers import DiffusionPipeline
+ from tqdm import tqdm
+ from ..config import get_config, get_logger
+ from .load_data import data_collector
+
+ args = get_config(config_file="config_FOREcasT.ini")
+ logger = get_logger(args)
+
+ @torch.no_grad()
+ def data_collector_inference(examples):
+     for example in examples:
+         ref, cut = example["ref"], example["cut"]
+         assert len(ref) >= args.ref1len and len(ref) >= args.ref2len, f"ref of length {len(ref)} is too short, please decrease ref1len={args.ref1len} and/or ref2len={args.ref2len} in inference arguments"
+         assert cut <= args.ref1len and len(ref) - cut <= args.ref2len, f"ref1len={args.ref1len} and/or ref2len={args.ref2len} is too short, please increase them to cover cut site {cut}"
+         assert cut >= args.FOREcasT_MAX_DEL_SIZE, f"ref upstream to cut ({cut}) is less than FOREcasT_MAX_DEL_SIZE ({args.FOREcasT_MAX_DEL_SIZE}), extend ref to upstream"
+         assert len(ref) - cut >= args.FOREcasT_MAX_DEL_SIZE, f"ref downstream to cut ({len(ref) - cut}) is less than FOREcasT_MAX_DEL_SIZE ({args.FOREcasT_MAX_DEL_SIZE}), extend ref to downstream"
+     return data_collector(examples, output_count=False)
+
+ @torch.no_grad()
+ def inference(data_name=args.data_name, data_files="inference.json.gz"):
+     logger.info("load inference data")
+     ds = load_dataset('json', data_files=data_files, features=Features({
+         'ref': Value('string'),
+         'cut': Value('int16')
+     }))["train"]
+
+     inference_dataloader = DataLoader(
+         dataset=ds,
+         batch_size=args.batch_size,
+         collate_fn=data_collector_inference
+     )
+
+     logger.info("setup pipeline")
+     pipe = DiffusionPipeline.from_pretrained(f"{args.owner}/{data_name}_FOREcasT", trust_remote_code=True, custom_pipeline=f"{args.owner}/{data_name}_FOREcasT", MAX_DEL_SIZE=args.FOREcasT_MAX_DEL_SIZE)
+     pipe.FOREcasT_model.to(args.device)
+
+     for batch in tqdm(inference_dataloader):
+         yield pipe(batch)
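For reference, a hedged sketch of how this entry point might be exercised end to end. It assumes the package layout above makes `inference` importable, and that `config_FOREcasT.ini` sets `ref1len`, `ref2len`, `FOREcasT_MAX_DEL_SIZE`, `batch_size`, `owner`, and `device` to values compatible with the toy records below; the sequences, cut sites, and file name are hypothetical. Each input record must carry a `ref` string and an integer `cut`, matching the Features schema declared in the committed file.

import gzip
import json

from AI_models.FOREcasT.inference import inference

# Hypothetical input: load_dataset('json', ...) reads JSON Lines, i.e. one
# {"ref": ..., "cut": ...} object per line, and decompresses .gz transparently.
# The sequences below are toy values; real refs must satisfy the length
# assertions in data_collector_inference.
records = [
    {"ref": "ACGT" * 50, "cut": 100},
    {"ref": "TGCA" * 50, "cut": 98},
]
with gzip.open("inference.json.gz", "wt") as fh:
    for record in records:
        fh.write(json.dumps(record) + "\n")

# inference() is a generator: it yields one pipeline output per batch, so
# results stream instead of accumulating in memory.
for output in inference(data_files="inference.json.gz"):
    print(output)  # the output structure is defined by the custom pipeline

Note that `inference()` lazily loads the pipeline on first iteration; nothing runs until the generator is consumed.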