gonced8 commited on
Commit
1eb1bfd
1 Parent(s): a3e88d2

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +56 -0
README.md CHANGED
@@ -1,3 +1,59 @@
1
  ---
2
  license: gpl-3.0
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: gpl-3.0
3
+ datasets:
4
+ - multi_woz_v22
5
+ language:
6
+ - en
7
+ metrics:
8
+ - bleu
9
+ - rouge
10
  ---
11
+
12
+ Pretrained model: [GODEL-v1_1-base-seq2seq](https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq/)
13
+
14
+ Fine-tuning dataset: [MultiWOZ 2.2](https://github.com/budzianowski/multiwoz/tree/master/data/MultiWOZ_2.2)
15
+
16
+ # How to use:
17
+
18
+ ```python
19
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
20
+
21
+ # Load tokenizer and model
22
+ tokenizer = AutoTokenizer.from_pretrained("gonced8/godel-multiwoz")
23
+ model = AutoModelForSeq2SeqLM.from_pretrained("gonced8/godel-multiwoz")
24
+
25
+ # Encoder input
26
+ context = [
27
+ "USER: I need train reservations from norwich to cambridge",
28
+ "SYSTEM: I have 133 trains matching your request. Is there a specific day and time you would like to travel?",
29
+ "USER: I'd like to leave on Monday and arrive by 18:00.",
30
+ ]
31
+
32
+ input_text = " EOS ".join(context) + " => "
33
+
34
+ model_inputs = tokenizer(
35
+ input_text, max_length=512, truncation=True, return_tensors="pt"
36
+ )["input_ids"]
37
+
38
+ # Decoder input
39
+ answer_start = "SYSTEM: "
40
+
41
+ decoder_input_ids = tokenizer(
42
+ "<pad>" + answer_start,
43
+ max_length=256,
44
+ truncation=True,
45
+ add_special_tokens=False,
46
+ return_tensors="pt",
47
+ )["input_ids"]
48
+
49
+ # Generate
50
+ output = model.generate(
51
+ model_inputs, decoder_input_ids=decoder_input_ids, max_length=256
52
+ )
53
+ output = tokenizer.decode(
54
+ output[0], clean_up_tokenization_spaces=True, skip_special_tokens=True
55
+ )
56
+
57
+ print(output)
58
+ # SYSTEM: TR4634 arrives at 17:35. Would you like me to book that for you?
59
+ ```