# LLMCalc / utils/filter_utils.py
# Last commit: 68e6513 "working first draft" (by carbonnnnn)
# File size at that commit: 805 bytes
# Utility functions for filtering the dataframe
def filter_cols(df):
    """Project *df* onto a fixed subset of columns.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain every column named in ``wanted`` below; pandas raises
        ``KeyError`` if any of them is absent.

    Returns
    -------
    pandas.DataFrame
        A frame holding only the ``wanted`` columns, in the order listed.
        All other columns (e.g. ``languages``) are dropped.
    """
    wanted = (
        'model_name',
        'input_price',
        'output_price',
        'release_date',
        'context_size',
        'average_clemscore',
        'average_latency',
        'parameter_size',
    )
    return df[list(wanted)]
def filter(df, language_list, clemscore, input_price, output_price):
    """Return the rows of *df* that satisfy every criterion, trimmed by ``filter_cols``.

    NOTE: this name shadows the built-in ``filter`` within this module; it is
    kept as-is because callers import it under this name.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``languages``, ``average_clemscore``, ``input_price`` and
        ``output_price``, plus every column selected by ``filter_cols``.
    language_list : iterable or None
        Every kept row must contain each of these values in its ``languages``
        cell (tested with ``in``, so a string cell gets a substring test).
        Presumably the cells are lists of language names; TODO confirm
        against the data loader. An empty iterable or ``None`` applies no
        language constraint. A missing (None/NaN) cell counts as having no
        languages.
    clemscore, input_price, output_price : sequence
        Inclusive ``(low, high)`` bounds, read from items ``[0]`` and ``[1]``.
        They apply to ``average_clemscore``, ``input_price`` and
        ``output_price`` respectively. Rows whose value is NaN fail the check
        and are dropped.

    Returns
    -------
    pandas.DataFrame
        The matching rows, restricted to the columns chosen by ``filter_cols``.
    """
    import math

    required = [] if language_list is None else list(language_list)

    def _has_all_languages(langs):
        # A missing cell has no languages. It passes only when nothing is
        # required, mirroring all() over an empty list.
        if langs is None or (isinstance(langs, float) and math.isnan(langs)):
            return not required
        return all(lang in langs for lang in required)

    # astype(bool): on an empty frame, Series.apply keeps the column's (object)
    # dtype. pandas then treats the empty mask as a list of column labels and
    # drops every column instead of selecting zero rows.
    language_mask = df['languages'].apply(_has_all_languages).astype(bool)
    df = df[language_mask]

    # Series.between is inclusive on both ends, like the original >= / <= pair.
    for column, bounds in (
        ('average_clemscore', clemscore),
        ('input_price', input_price),
        ('output_price', output_price),
    ):
        df = df[df[column].between(bounds[0], bounds[1])]

    return filter_cols(df)