valentin-ub commited on
Commit
86e4d05
·
verified ·
1 Parent(s): 04bd3a7

Update chess_board.py

Browse files
Files changed (1) hide show
  1. chess_board.py +8 -2
chess_board.py CHANGED
@@ -26,7 +26,13 @@ class Game:
26
 
27
  def compile_model(self):
28
  self.model.compile(sampler=self.sampler)
29
-
 
 
 
 
 
 
30
  def call_gemma(self, opening_move):
31
  template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
32
 
@@ -39,7 +45,7 @@ class Game:
39
  instruction=f"Predict the next chess move in the sequence {str(self.sequence)}",
40
  response="",)
41
 
42
- output = self.model.generate(prompt, max_length=256)
43
 
44
  gemma_move = output.split(' ')[-1].strip("'")
45
 
 
26
 
27
  def compile_model(self):
28
  self.model.compile(sampler=self.sampler)
29
+
30
+ @spaces.GPU
31
+ def inference_gemma(self, prompt, max_length=256):
32
+ """Inference requires GPU"""
33
+ response = self.model.generate(prompt, max_length)
34
+ return response
35
+
36
  def call_gemma(self, opening_move):
37
  template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
38
 
 
45
  instruction=f"Predict the next chess move in the sequence {str(self.sequence)}",
46
  response="",)
47
 
48
+ output = self.inference_gemma(prompt, max_length=256) #self.model.generate(prompt, max_length=256)
49
 
50
  gemma_move = output.split(' ')[-1].strip("'")
51