import os
import urllib.response

import pyclamd
import requests
from picklescan.scanner import (
    scan_url,
    scan_file_path,
    ScanResult,
    SafetyLevel,
)


def scan_file(file_path: str) -> dict:
    """Run picklescan on a local path or an http(s) URL and summarise the result.

    Args:
        file_path: A local filesystem path, or a URL (any string starting with
            "http"), which is handed to picklescan's ``scan_url``.

    Returns:
        A dict with:
          - ``url``: the input ``file_path``, unchanged.
          - ``fileExists``: always ``True``; no existence check is made here.
          - ``picklescanExitCode``: 1 if any referenced global has
            ``SafetyLevel.Dangerous``, otherwise 0.
          - ``picklescanGlobalImports``: every global the pickle references,
            formatted as ``"from <module> import <name>"``.
          - ``picklescanDangerousImports``: the subset flagged dangerous.
    """
    if file_path.startswith("http"):
        scan_result: ScanResult = scan_url(file_path)
    else:
        scan_result = scan_file_path(file_path)

    # NOTE(review): scan_result.scan_err is not inspected, so a failed scan
    # is reported as exit code 0 with no imports. Confirm whether callers
    # need to distinguish that case.
    global_imports = [fmt_import(g.module, g.name) for g in scan_result.globals]
    dangerous_imports = [
        fmt_import(g.module, g.name)
        for g in scan_result.globals
        if g.safety == SafetyLevel.Dangerous
    ]

    return {
        'url': file_path,
        'fileExists': True,
        'picklescanExitCode': 1 if dangerous_imports else 0,
        'picklescanGlobalImports': global_imports,
        'picklescanDangerousImports': dangerous_imports,
        # TODO(review): fields sketched in the original but not populated here:
        # 'clamscanExitCode' / 'clamscanOutput' (see clamd_file),
        # hashes: Record<ModelHashType, string>,
        # conversions: Record<'safetensors' | 'ckpt', ConversionResult>.
    }


def init_clamd():
    """Return a pyclamd connection using ``ClamdUnixSocket``'s default settings."""
    return pyclamd.ClamdUnixSocket()


headers = {
    "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/78.0.3904.108 Safari/537.36"
}


def _download(url: str, dest: str, chunk_size: int = 1 << 20) -> None:
    """Stream ``url`` to ``dest`` in chunks, failing on HTTP errors.

    Raises:
        requests.HTTPError: on a 4xx/5xx response (otherwise an error page
            would be scanned and reported as clean).
        requests.RequestException: on connection failures or timeouts.
    """
    # timeout bounds each connect/read wait so a stalled server cannot hang us.
    with requests.get(url, headers=headers, stream=True, timeout=60) as resp:
        resp.raise_for_status()
        with open(dest, "wb") as f:
            # Model files can be very large; avoid holding the body in memory.
            for chunk in resp.iter_content(chunk_size=chunk_size):
                f.write(chunk)


def clamd_file(file_path: str, clamd=None) -> dict:
    """Scan a local file or an http(s) URL with ClamAV through clamd.

    URLs (strings starting with "http") are downloaded to
    ``/tmp/clamd_<basename>``, with any query string stripped. The file is
    scanned and then removed. Local paths are scanned in place.

    Args:
        file_path: Local path, or URL starting with "http".
        clamd: A pyclamd connection. If ``None``, one is created with
            ``init_clamd()``.

    Returns:
        ``{'clamscanExitCode': 0, 'clamscanOutput': "No virus found"}`` when
        clamd reports nothing. Otherwise the exit code is 1 for a detection
        or 2 for a clamd "ERROR" entry (mirroring clamscan's exit codes), and
        the output is the reported status and detail joined by a space.

    Raises:
        requests.RequestException: if downloading a URL fails.
    """
    if clamd is None:
        clamd = init_clamd()

    downloaded = file_path.startswith("http")
    if downloaded:
        scan_path = f'/tmp/clamd_{file_path.split("/")[-1].split("?")[0]}'
        print("tmp_path ", scan_path)
    else:
        scan_path = file_path

    try:
        if downloaded:
            _download(file_path, scan_path)
        ret = clamd.scan_file(scan_path)
    finally:
        # Don't leave (possibly multi-GB, possibly partial) downloads in /tmp.
        if downloaded and os.path.exists(scan_path):
            os.remove(scan_path)

    if not ret:
        return {
            'clamscanExitCode': 0,
            'clamscanOutput': "No virus found",
        }

    # pyclamd reports {path: (status, detail)}, with status "FOUND" or "ERROR".
    # Prefer our own path, but fall back to whatever entry clamd returned
    # rather than silently returning None on a key mismatch.
    entry = ret.get(scan_path) or next(iter(ret.values()))
    return {
        'clamscanExitCode': 2 if entry[0] == "ERROR" else 1,
        'clamscanOutput': ' '.join(entry),
    }


def fmt_import(module: str, name: str) -> str:
    """Format a pickle global as a Python import statement.

    >>> fmt_import("torch", "FloatStorage")
    'from torch import FloatStorage'
    """
    return f"from {module} import {name}"


if __name__ == "__main__":
    url = "https://huggingface.co/yesyeahvh/bad-hands-5/resolve/main/bad-hands-5.pt"
    detail = scan_file(url)
    clamd_detail = clamd_file(url, init_clamd())
    print(detail)
    print(clamd_detail)
    # Example picklescan ScanResult for this URL:
    # ScanResult(
    #     globals=[Global(module='torch', name='FloatStorage', safety=<SafetyLevel.Innocuous: 'innocuous'>),
    #              Global(module='collections', name='OrderedDict', safety=<SafetyLevel.Innocuous: 'innocuous'>),
    #              Global(module='torch._utils', name='_rebuild_tensor_v2', safety=<SafetyLevel.Innocuous: 'innocuous'>)],
    #     scanned_files=1, issues_count=0, infected_files=0, scan_err=False)