Abhinowww commited on
Commit
90b9490
·
1 Parent(s): d28b09a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -13
app.py CHANGED
@@ -75,9 +75,13 @@ original_generator.load_state_dict(ckpt["g_ema"], strict=False)
75
  mean_latent = original_generator.mean_latent(10000)
76
  # print(ckpt.keys())
77
 
78
- generatorjoker = deepcopy(original_generator)
79
 
80
- generatorvoldemort = deepcopy(original_generator)
 
 
 
 
81
 
82
  generatorpushpa = deepcopy(original_generator)
83
 
@@ -103,15 +107,21 @@ transform = transforms.Compose(
103
  )
104
 
105
 
106
- modeljoker = hf_hub_download(repo_id="Abhinowww/Capstone", filename="JokerEightHundredFalse.pt")
107
- ckptjoker = torch.load(modeljoker, map_location=lambda storage, loc: storage)
108
- generatorjoker.load_state_dict(ckptjoker, strict=False)
109
 
 
 
 
110
 
111
- modelvoldemort = hf_hub_download(repo_id="Abhinowww/Capstone", filename="VoldemortFourHundredFalse.pt")
112
- ckptvoldemort = torch.load(modelvoldemort, map_location=lambda storage, loc: storage)
113
- generatorvoldemort.load_state_dict(ckptvoldemort, strict=False)
114
 
 
 
 
115
 
116
  modelpushpa = hf_hub_download(repo_id="Abhinowww/Capstone", filename="PushpaFourHundredFalse.pt")
117
  ckptpushpa = torch.load(modelpushpa, map_location=lambda storage, loc: storage)
@@ -121,11 +131,11 @@ modelgiga = hf_hub_download(repo_id="Abhinowww/Capstone", filename="GigachadFour
121
  ckptgiga = torch.load(modelgiga, map_location=lambda storage, loc: storage)
122
  generatorgiga.load_state_dict(ckptgiga, strict=False)
123
 
124
- modelsketchtrue = hf_hub_download(repo_id="Abhinowww/Capstone", filename="SketchFourHundredTrue.pt")
125
  ckptsketchtrue = torch.load(modelsketchtrue, map_location=lambda storage, loc: storage)
126
  generatorsketchtrue.load_state_dict(ckptsketchtrue, strict=False)
127
 
128
- modelsketchfalse = hf_hub_download(repo_id="Abhinowww/Capstone", filename="SketchFourHundredFalse.pt")
129
  ckptsketchfalse = torch.load(modelsketchfalse, map_location=lambda storage, loc: storage)
130
  generatorsketchfalse.load_state_dict(ckptsketchfalse, strict=False)
131
 
@@ -138,10 +148,16 @@ def inference(img, model):
138
  my_w = projection(aligned_face, "test.pt", device).unsqueeze(0)
139
  if model == 'Joker':
140
  with torch.no_grad():
141
- my_sample = generatorjoker(my_w, input_is_latent=True)
 
 
 
142
  elif model == 'Voldemort':
143
  with torch.no_grad():
144
- my_sample = generatorvoldemort(my_w, input_is_latent=True)
 
 
 
145
  elif model == 'Pushpa':
146
  with torch.no_grad():
147
  my_sample = generatorpushpa(my_w, input_is_latent=True)
@@ -180,4 +196,4 @@ description = "Capstone Project. To use it, simply upload your image, or click o
180
  # css_code='body{background-image:url("https://picsum.photos/seed/picsum/200/300");}'
181
  # gr.Interface(lambda x:x, "textbox", "textbox", css=css_code).launch(debug=True)
182
 
183
- gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['Joker', 'Voldemort', 'Pushpa', 'Gigachad', 'Sketch', 'Sketch Preserve'], type="value", default='Joker', label="Model")], gr.outputs.Image(type="pil"),title=title,description=description,allow_flagging=False,allow_screenshot=False).launch()
 
75
  mean_latent = original_generator.mean_latent(10000)
76
  # print(ckpt.keys())
77
 
78
+ generatorjokerfalse = deepcopy(original_generator)
79
 
80
+ generatorjokertrue = deepcopy(original_generator)
81
+
82
+ generatorvoldemortfalse = deepcopy(original_generator)
83
+
84
+ generatorvoldemorttrue = deepcopy(original_generator)
85
 
86
  generatorpushpa = deepcopy(original_generator)
87
 
 
107
  )
108
 
109
 
110
+ modeljokerfalse = hf_hub_download(repo_id="Abhinowww/Capstone", filename="JokerEightHundredFalse.pt")
111
+ ckptjokerfalse = torch.load(modeljokerfalse, map_location=lambda storage, loc: storage)
112
+ generatorjokerfalse.load_state_dict(ckptjokerfalse, strict=False)
113
 
114
+ modeljokertrue = hf_hub_download(repo_id="Abhinowww/Capstone", filename="JokerTwoHundredFiftyTrue.pt")
115
+ ckptjokertrue = torch.load(modeljokertrue, map_location=lambda storage, loc: storage)
116
+ generatorjokertrue.load_state_dict(ckptjokertrue, strict=False)
117
 
118
+ modelvoldemortfalse = hf_hub_download(repo_id="Abhinowww/Capstone", filename="VoldemortFourHundredFalse.pt")
119
+ ckptvoldemortfalse = torch.load(modelvoldemortfalse, map_location=lambda storage, loc: storage)
120
+ generatorvoldemortfalse.load_state_dict(ckptvoldemortfalse, strict=False)
121
 
122
+ modelvoldemorttrue = hf_hub_download(repo_id="Abhinowww/Capstone", filename="VoldemortThreeHundredTrue.pt")
123
+ ckptvoldemorttrue = torch.load(modelvoldemorttrue, map_location=lambda storage, loc: storage)
124
+ generatorvoldemorttrue.load_state_dict(ckptvoldemorttrue, strict=False)
125
 
126
  modelpushpa = hf_hub_download(repo_id="Abhinowww/Capstone", filename="PushpaFourHundredFalse.pt")
127
  ckptpushpa = torch.load(modelpushpa, map_location=lambda storage, loc: storage)
 
131
  ckptgiga = torch.load(modelgiga, map_location=lambda storage, loc: storage)
132
  generatorgiga.load_state_dict(ckptgiga, strict=False)
133
 
134
+ modelsketchtrue = hf_hub_download(repo_id="Abhinowww/Capstone", filename="OGSketchFourHundredTrue.pt")
135
  ckptsketchtrue = torch.load(modelsketchtrue, map_location=lambda storage, loc: storage)
136
  generatorsketchtrue.load_state_dict(ckptsketchtrue, strict=False)
137
 
138
+ modelsketchfalse = hf_hub_download(repo_id="Abhinowww/Capstone", filename="OGSketchFourHundredFalse.pt")
139
  ckptsketchfalse = torch.load(modelsketchfalse, map_location=lambda storage, loc: storage)
140
  generatorsketchfalse.load_state_dict(ckptsketchfalse, strict=False)
141
 
 
148
  my_w = projection(aligned_face, "test.pt", device).unsqueeze(0)
149
  if model == 'Joker':
150
  with torch.no_grad():
151
+ my_sample = generatorjokerfalse(my_w, input_is_latent=True)
152
+ elif model == 'Joker Preserve':
153
+ with torch.no_grad():
154
+ my_sample = generatorjokertrue(my_w, input_is_latent=True)
155
  elif model == 'Voldemort':
156
  with torch.no_grad():
157
+ my_sample = generatorvoldemortfalse(my_w, input_is_latent=True)
158
+ elif model == 'Voldemort Preserve':
159
+ with torch.no_grad():
160
+ my_sample = generatorvoldemorttrue(my_w, input_is_latent=True)
161
  elif model == 'Pushpa':
162
  with torch.no_grad():
163
  my_sample = generatorpushpa(my_w, input_is_latent=True)
 
196
  # css_code='body{background-image:url("https://picsum.photos/seed/picsum/200/300");}'
197
  # gr.Interface(lambda x:x, "textbox", "textbox", css=css_code).launch(debug=True)
198
 
199
+ gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['Joker', 'Joker Preserve', 'Voldemort', 'Voldemort Preserve', 'Pushpa', 'Gigachad', 'Sketch', 'Sketch Preserve'], type="value", default='Joker', label="Model")], gr.outputs.Image(type="pil"),title=title,description=description,allow_flagging=False,allow_screenshot=False).launch()