File size: 1,263 Bytes
bd9da36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os
import subprocess
from typing import List

import pytest


@pytest.fixture
def download_weights(output_directory: str = "artifacts") -> None:
    """Download the SAM 2 Hiera checkpoints into ``output_directory``.

    Every checkpoint that is not already present is fetched with ``wget``.
    The transfer is written to a temporary ``<name>.part`` file and moved to
    its final name only after ``wget`` exits successfully, so a failed or
    interrupted download is never mistaken for a complete checkpoint on a
    later run.

    A failed download is reported on stdout and does not abort the loop;
    the remaining files are still attempted.

    Args:
        output_directory: Directory that receives the checkpoint files. It
            is created when it does not exist yet.

    Raises:
        FileNotFoundError: Propagated from ``subprocess.run`` when the
            ``wget`` executable cannot be found.
    """
    base_url: str = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/"
    file_names: List[str] = [
        "sam2_hiera_tiny.pt",
        "sam2_hiera_small.pt",
        "sam2_hiera_base_plus.pt",
        "sam2_hiera_large.pt",
    ]

    # exist_ok removes the gap between "does it exist?" and "create it", in
    # which another process could create the directory and make makedirs fail.
    os.makedirs(output_directory, exist_ok=True)

    for file_name in file_names:
        file_path = os.path.join(output_directory, file_name)
        if os.path.exists(file_path):
            print(f"{file_name} already exists. Skipping download.")
            continue

        url = f"{base_url}{file_name}"
        # Download under a temporary name: only a finished transfer may ever
        # appear under the final name that the existence check above trusts.
        partial_path = f"{file_path}.part"
        command = ["wget", url, "-O", partial_path]
        try:
            result = subprocess.run(
                command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
        except subprocess.CalledProcessError as e:
            print(f"An error occurred during the download of {file_name}.")
            # wget output is not guaranteed to be UTF-8; never let decoding
            # hide the actual download error.
            print(e.stderr.decode(errors="replace"))
            # Drop whatever was written so no truncated file is left behind.
            try:
                os.remove(partial_path)
            except FileNotFoundError:
                pass
        else:
            os.replace(partial_path, file_path)
            print(f"Download of {file_name} completed successfully.")
            print(result.stdout.decode(errors="replace"))