dblasko commited on
Commit
7d75e22
1 Parent(s): 9b9b1dc

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +46 -0
README.md CHANGED
@@ -11,6 +11,52 @@ tags:
11
  # MIRNet low-light image enhancement
12
  MIRNet-based low-light image enhancer specialized on restoring dark images from events (concerts, parties, clubs...).
13
 
 
14
  Documentation about pre-training, fine-tuning, model architecture, usage and all source code used for building and inference can be found in the [GitHub repository of the project](https://github.com/dblasko/low-light-event-img-enhancer/).
15
  This currently stores the PyTorch model weights and model definition, a HuggingFace pipeline will be implemented in the future.
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # MIRNet low-light image enhancement
12
  MIRNet-based low-light image enhancer specialized on restoring dark images from events (concerts, parties, clubs...).
13
 
14
+ ## Project source-code and further documentation
15
  Documentation about pre-training, fine-tuning, model architecture, usage and all source code used for building and inference can be found in the [GitHub repository of the project](https://github.com/dblasko/low-light-event-img-enhancer/).
16
  This currently stores the PyTorch model weights and model definition, a HuggingFace pipeline will be implemented in the future.
17
 
18
+ ## Using the model
19
+ To use the model, you need to have the `model` folder, that you can dowload from this repository as well as on [GitHub](https://github.com/dblasko/low-light-event-img-enhancer/), present in your project folder.
20
+
21
+ Then, the following code can be used to download the model weights from HuggingFace and load them in PyTorch for downstream use of the model:
22
+ ```python
23
+ import torch
24
+ import torchvision.transforms as T
25
+ from PIL import Image
26
+ from huggingface_hub import hf_hub_url, cached_download
27
+ from model.MIRNet.model import MIRNet
28
+
29
+ device = (
30
+ torch.device("cuda")
31
+ if torch.cuda.is_available()
32
+ else torch.device("mps")
33
+ if torch.backends.mps.is_available()
34
+ else torch.device("cpu")
35
+ )
36
+
37
+ # Download the model weights from the Hugging Face Hub
38
+ model_url = hf_hub_url(
39
+ repo_id="dblasko/mirnet-low-light-img-enhancement", filename="mirnet_finetuned.pth" # or mirnet_pretrained.pth
40
+ )
41
+ model_path = cached_download(model_url)
42
+
43
+ # Load the model
44
+ model = MIRNet().to(device)
45
+ model.load_state_dict(torch.load(model_path, map_location=device)["model_state_dict"])
46
+
47
+ # Use the model, for example for inference on an image
48
+ model.eval()
49
+ with torch.no_grad():
50
+ img = Image.open("image_path.png").convert("RGB")
51
+ img_tensor = T.Compose(
52
+ [
53
+ T.Resize(400), # Adjust image resizing depending on hardware
54
+ T.ToTensor(),
55
+ T.Normalize([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
56
+ ]
57
+ )(img).unsqueeze(0)
58
+ img_tensor = img_tensor.to(device)
59
+
60
+ output = model(img_tensor)
61
+
62
+ ```