cfli commited on
Commit
05926ac
1 Parent(s): 7346880

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +71 -0
README.md ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModel, AutoTokenizer, LlamaModel
3
+
4
+ def get_query_inputs(query, tokenizer, max_length=512):
5
+ prefix = '"'
6
+ suffix = '", predict the following passage within eight words: <s9><s10><s11><s12><s13><s14><s15><s16>'
7
+ prefix_ids = tokenizer(prefix, return_tensors=None)['input_ids']
8
+ suffix_ids = tokenizer(suffix, return_tensors=None)['input_ids'][1:]
9
+ inputs = tokenizer(query,
10
+ return_tensors=None,
11
+ max_length=max_length,
12
+ truncation=True,
13
+ add_special_tokens=False)
14
+ inputs['input_ids'] = prefix_ids + inputs['input_ids'] + suffix_ids
15
+ inputs['attention_mask'] = [1] * len(inputs['input_ids'])
16
+ return tokenizer.pad(
17
+ [inputs],
18
+ padding=True,
19
+ max_length=max_length,
20
+ pad_to_multiple_of=8,
21
+ return_tensors='pt',
22
+ )
23
+
24
+ def get_passage_inputs(passage, tokenizer, max_length=512):
25
+ prefix = '"'
26
+ suffix = '", summarize the above passage within eight words: <s1><s2><s3><s4><s5><s6><s7><s8>'
27
+ prefix_ids = tokenizer(prefix, return_tensors=None)['input_ids']
28
+ suffix_ids = tokenizer(suffix, return_tensors=None)['input_ids'][1:]
29
+ inputs = tokenizer(passage,
30
+ return_tensors=None,
31
+ max_length=max_length,
32
+ truncation=True,
33
+ add_special_tokens=False)
34
+ inputs['input_ids'] = prefix_ids + inputs['input_ids'] + suffix_ids
35
+ inputs['attention_mask'] = [1] * len(inputs['input_ids'])
36
+ return tokenizer.pad(
37
+ [inputs],
38
+ padding=True,
39
+ max_length=max_length,
40
+ pad_to_multiple_of=8,
41
+ return_tensors='pt',
42
+ )
43
+
44
+ # Load the tokenizer and model
45
+ tokenizer = AutoTokenizer.from_pretrained('cfli/LLARA-beir')
46
+ model = AutoModel.from_pretrained('cfli/LLARA-beir')
47
+
48
+ # Define query and passage inputs
49
+ query = "What is llama?"
50
+ title = "Llama"
51
+ passage = "The llama is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era."
52
+ query_input = get_query_inputs(query, tokenizer)
53
+ passage_input = get_passage_inputs(passage, tokenizer)
54
+
55
+
56
+ with torch.no_grad():
57
+ # compute query embedding
58
+ query_outputs = model(**query_input, return_dict=True, output_hidden_states=True)
59
+ query_embedding = query_outputs.hidden_states[-1][:, -8:, :]
60
+ query_embedding = torch.mean(query_embedding, dim=1)
61
+ query_embedding = torch.nn.functional.normalize(query_embedding, dim=-1)
62
+
63
+ # compute passage embedding
64
+ passage_outputs = model(**passage_input, return_dict=True, output_hidden_states=True)
65
+ passage_embeddings = passage_outputs.hidden_states[-1][:, -8:, :]
66
+ passage_embeddings = torch.mean(passage_embeddings, dim=1)
67
+ passage_embeddings = torch.nn.functional.normalize(passage_embeddings, dim=-1)
68
+
69
+ # compute similarity score
70
+ score = query_embedding @ passage_embeddings.T
71
+ print(score)