widget:
  - src: https://huggingface.co/IAMJB/interpret-cxr-impression-baseline/resolve/main/effusions-bibasal.jpg
---

[Evaluation on chexpert-plus](https://github.com/Stanford-AIMI/chexpert-plus)

Usage:

```python
import requests
import torch
from PIL import Image
from transformers import BertTokenizer, ViTImageProcessor, VisionEncoderDecoderModel, GenerationConfig

mode = "findings"  # or "impression" for the companion impression model

# Model, tokenizer, and image processor
model = VisionEncoderDecoderModel.from_pretrained(f"IAMJB/chexpert-mimic-cxr-{mode}-baseline").eval()
tokenizer = BertTokenizer.from_pretrained(f"IAMJB/chexpert-mimic-cxr-{mode}-baseline")
image_processor = ViTImageProcessor.from_pretrained(f"IAMJB/chexpert-mimic-cxr-{mode}-baseline")

# Generation settings
generation_args = {
    "bos_token_id": model.config.bos_token_id,
    "eos_token_id": model.config.eos_token_id,
    "pad_token_id": model.config.pad_token_id,
    "num_return_sequences": 1,
    "max_length": 128,
    "use_cache": True,
    "num_beams": 2,
}

# Inference
refs = []  # reference reports, if scoring against a dataset
hyps = []  # generated reports
with torch.no_grad():
    url = "https://huggingface.co/IAMJB/interpret-cxr-impression-baseline/resolve/main/effusions-bibasal.jpg"
    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")  # ensure 3-channel input
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    # Generate predictions with beam search, decoding from the [CLS] token
    generated_ids = model.generate(
        pixel_values,
        generation_config=GenerationConfig(
            **{**generation_args, "decoder_start_token_id": tokenizer.cls_token_id}
        ),
    )
    generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    hyps.extend(generated_texts)
    print(generated_texts)
```
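
The `refs`/`hyps` lists above are collection buffers for scoring generations against reference reports. Below is a minimal sketch of how such a comparison might look, using the Hugging Face `evaluate` library with ROUGE; the reference string and metric choice are illustrative assumptions, and the chexpert-plus repository linked above hosts the actual evaluation suite:

```python
import evaluate  # pip install evaluate rouge_score

# Hypothetical reference for the sample image; real references come from your dataset
refs = ["small bilateral pleural effusions."]
hyps = ["there are small bibasal pleural effusions."]

# Compute ROUGE overlap between generated and reference reports
rouge = evaluate.load("rouge")
print(rouge.compute(predictions=hyps, references=refs))
```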