tastypear committed
Commit d456b01
1 Parent(s): 7313605

support gemma

Files changed (1):
  1. main.py +9 -1
main.py CHANGED
@@ -48,6 +48,14 @@ def proxy():
     model = json_data['model']
     chat_api = f"https://api-inference.huggingface.co/models/{model}/v1/chat/completions"
 
+    # Gemma does not support a system prompt,
+    # so prepend it to the first user message.
+    if model.startswith('google/gemma') and json_data["messages"][0]['role'] == 'system':
+        system_prompt = json_data["messages"][0]['content']
+        first_user_content = json_data["messages"][1]['content']
+        json_data["messages"][1]['content'] = f'System: {system_prompt}\n\n---\n\n{first_user_content}'
+        json_data["messages"] = json_data["messages"][1:]
+
     # Try to use the largest ctx
     if not 'max_tokens' in json_data:
         json_data['max_tokens'] = 2**32-1
@@ -59,7 +67,7 @@ def proxy():
         inputs = int(info.split("Given: ")[1].split("`")[0])
         json_data['max_tokens'] = max_ctx - inputs - 1
     except Exception as e:
-        print(e)
+        print(info)
 
     if not 'seed' in json_data:
         json_data['seed'] = random.randint(1,2**32)
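
For reference, the transformation this commit adds can be sketched as a standalone helper. The function name fold_system_prompt and the example payload below are illustrative, not part of the repo; the merge format (a "System:" prefix and a "---" separator) matches the patched code.

def fold_system_prompt(json_data):
    """If the first message is a system prompt, merge it into the first user message.

    Gemma chat templates reject a leading 'system' role, so the system text
    is prepended to the first user turn and the system message is dropped.
    """
    messages = json_data["messages"]
    if len(messages) > 1 and messages[0]['role'] == 'system':
        system_prompt = messages[0]['content']
        first_user_content = messages[1]['content']
        messages[1]['content'] = f'System: {system_prompt}\n\n---\n\n{first_user_content}'
        json_data["messages"] = messages[1:]
    return json_data

# Example (hypothetical payload):
# fold_system_prompt({"messages": [
#     {"role": "system", "content": "You are terse."},
#     {"role": "user", "content": "Hi"},
# ]})
# -> {"messages": [{"role": "user", "content": "System: You are terse.\n\n---\n\nHi"}]}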