makaveli commited on
Commit
855f306
1 Parent(s): 91de2f7

Update llm_service.py to tensorrt-llm 0.7.1

Browse files
Files changed (1) hide show
  1. llm_service.py +3 -3
llm_service.py CHANGED
@@ -14,7 +14,7 @@ if PYTHON_BINDINGS:
14
 
15
 
16
  def read_model_name(engine_dir: str):
17
- engine_version = tensorrt_llm.builder.get_engine_version(engine_dir)
18
 
19
  with open(Path(engine_dir) / "config.json", 'r') as f:
20
  config = json.load(f)
@@ -128,7 +128,7 @@ class MistralTensorRTLLM:
128
  batch_input_ids.append(input_ids)
129
 
130
  batch_input_ids = [
131
- torch.tensor(x, dtype=torch.int32).unsqueeze(0) for x in batch_input_ids
132
  ]
133
  return batch_input_ids
134
 
@@ -188,7 +188,7 @@ class MistralTensorRTLLM:
188
  pad_id=None,
189
  )
190
 
191
- input_lengths = [x.size(1) for x in batch_input_ids]
192
  with torch.no_grad():
193
  outputs = self.runner.generate(
194
  batch_input_ids,
 
14
 
15
 
16
  def read_model_name(engine_dir: str):
17
+ engine_version = tensorrt_llm.runtime.engine.get_engine_version(engine_dir)
18
 
19
  with open(Path(engine_dir) / "config.json", 'r') as f:
20
  config = json.load(f)
 
128
  batch_input_ids.append(input_ids)
129
 
130
  batch_input_ids = [
131
+ torch.tensor(x, dtype=torch.int32) for x in batch_input_ids
132
  ]
133
  return batch_input_ids
134
 
 
188
  pad_id=None,
189
  )
190
 
191
+ input_lengths = [x.size(0) for x in batch_input_ids]
192
  with torch.no_grad():
193
  outputs = self.runner.generate(
194
  batch_input_ids,