alexkueck commited on
Commit
57e67fc
·
1 Parent(s): ef63f44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -59
app.py CHANGED
@@ -37,65 +37,6 @@ def group_texts(examples):
37
  return result
38
 
39
 
40
- #neues Model testen nach dem Training
41
- ########################################################################
42
- #Zm Test einen Text zu generieren...
43
- def predict(text,
44
- history,
45
- top_p=0.3,
46
- temperature=0.9,
47
- max_length_tokens=1024,
48
- max_context_length_tokens=2048,):
49
- if text=="":
50
- return
51
- try:
52
- model
53
- except:
54
- return print("fehler model")
55
-
56
- inputs = generate_prompt_with_history(text,history, tokenizer,max_length=max_context_length_tokens)
57
- if inputs is None:
58
- return print("Fehler inputs")
59
- else:
60
- prompt,inputs=inputs
61
- begin_length = len(prompt)
62
-
63
- input_ids = inputs["input_ids"][:,-max_context_length_tokens:].to(device)
64
- torch.cuda.empty_cache()
65
-
66
- #torch.no_grad() bedeutet, dass für die betreffenden tensoren keine Ableitungen berechnet werden bei der backpropagation
67
- #hier soll das NN ja auch nicht geändert werden 8backprop ist nicht nötig), da es um interference-prompts geht!
68
- print("torch.no_grad")
69
- with torch.no_grad():
70
- ausgabe = ""
71
- #die vergangenen prompts werden alle als Tupel in history abgelegt sortiert nach 'Human' und 'AI'- dass sind daher auch die stop-words, die den jeweils nächsten Eintrag kennzeichnen
72
- for x in greedy_search(input_ids,model,tokenizer,stop_words=["[|Human|]", "[|AI|]"],max_length=max_length_tokens,temperature=temperature,top_p=top_p):
73
- if is_stop_word_or_prefix(x,["[|Human|]", "[|AI|]"]) is False:
74
- if "[|Human|]" in x:
75
- x = x[:x.index("[|Human|]")].strip()
76
- if "[|AI|]" in x:
77
- x = x[:x.index("[|AI|]")].strip()
78
- x = x.strip()
79
- ausgabe = ausgabe + x
80
- # a, b= [[y[0],convert_to_markdown(y[1])] for y in history]+[[text, convert_to_markdown(x)]],history + [[text,x]]
81
- print("Erzeuge")
82
- yield print(Ausgabe)
83
- if shared_state.interrupted:
84
- shared_state.recover()
85
- try:
86
- print("Erfolg")
87
- return ausgabe
88
- except:
89
- pass
90
- del input_ids
91
- gc.collect()
92
- torch.cuda.empty_cache()
93
-
94
- try:
95
- print("erfolg")
96
- return ausgabe
97
- except:
98
- pass
99
 
100
 
101
 
 
37
  return result
38
 
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
 
42