gugarosa committed on
Commit
769684a
1 Parent(s): 1f890f7

Upload README.md

Files changed (1)
  1. README.md +10 -4
README.md CHANGED
@@ -74,9 +74,9 @@ The model is licensed under the [Research License](https://huggingface.co/micros
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-torch.set_default_device('cuda')
-model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1", trust_remote_code=True, torch_dtype="auto")
-tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1", trust_remote_code=True, torch_dtype="auto")
+torch.set_default_device("cuda")
+model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1", trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1", trust_remote_code=True)
 inputs = tokenizer('''def print_prime(n):
     """
     Print all primes between 1 and n
@@ -87,8 +87,14 @@ text = tokenizer.batch_decode(outputs)[0]
 print(text)
 ```
 
+If you need to use the model in a lower precision (e.g., FP16), please wrap the model's forward pass with `torch.autocast()`, as follows:
+```python
+with torch.autocast(model.device.type, dtype=torch.float16, enabled=True):
+    outputs = model.generate(**inputs, max_length=200)
+```
+
 **Remark.** In the generation function, our model currently does not support beam search (`num_beams` >1).
-Furthermore, in the forward pass of the model, we currently do not support outputing hidden states or attention values, or using custom input embeddings (instead of the model's).
+Furthermore, in the forward pass of the model, we currently do not support attention mask during training, outputting hidden states or attention values, or using custom input embeddings (instead of the model's).
 
 
 ### Citation
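
The `torch.autocast()` pattern this commit adds to the README can be illustrated without downloading phi-1. The sketch below is a minimal, hedged example of autocast semantics using a toy `nn.Linear` layer as a stand-in for the model's forward pass; it runs on CPU with `bfloat16`, since the commit's CUDA/`float16` path requires a GPU and FP16 autocast support on CPU varies by PyTorch version.

```python
import torch

# Toy stand-in for a model forward pass; phi-1 itself is not loaded here.
layer = torch.nn.Linear(4, 3)
x = torch.randn(2, 4)

# Outside autocast, matmul-backed ops run in the default float32.
assert layer(x).dtype == torch.float32

# Inside autocast, eligible ops (linear/matmul) run in the lower precision.
# The README wraps model.generate(...) the same way, with float16 on CUDA.
with torch.autocast("cpu", dtype=torch.bfloat16):
    y = layer(x)

print(y.dtype)  # torch.bfloat16
```

The same context-manager wrapping applies unchanged to `model.generate(**inputs, max_length=200)` in the README snippet; autocast downcasts only the ops it considers numerically safe at lower precision, which is why the model weights themselves can stay in full precision.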