woods-today commited on
Commit
50a1127
·
1 Parent(s): de8b3c6

Working on it

Browse files
Files changed (1) hide show
  1. routers/training.py +25 -8
routers/training.py CHANGED
@@ -6,12 +6,14 @@ from routers.donut_evaluate import run_evaluate_donut
6
  from routers.donut_training import run_training_donut
7
  import utils
8
  import torch
9
- from diffusers import StableDiffusionPipeline
 
 
10
 
11
- from diffusers import DiffusionPipeline
12
 
13
- pipe = DiffusionPipeline.from_pretrained("radames/stable-diffusion-v1-5-img2img")
14
- # pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16)
15
  pipe = pipe.to("cuda")
16
 
17
 
@@ -19,7 +21,22 @@ router = APIRouter()
19
 
20
  @router.get("/hi")
21
  async def hifunction():
22
- prompt = "a photograph of an astronaut riding a horse"
23
- image = pipe(prompt).images[0]
24
- print(image)
25
- return ["done"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from routers.donut_training import run_training_donut
7
  import utils
8
  import torch
9
+ import requests
10
+ from PIL import Image
11
+ from io import BytesIO
12
 
13
+ from diffusers import StableDiffusionImg2ImgPipeline
14
 
15
+ model_id_or_path = "runwayml/stable-diffusion-v1-5"
16
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
17
  pipe = pipe.to("cuda")
18
 
19
 
 
21
 
22
  @router.get("/hi")
23
  async def hifunction():
24
+
25
+ url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
26
+ response = requests.get(url)
27
+ init_image = Image.open(BytesIO(response.content)).convert("RGB")
28
+ init_image = init_image.resize((768, 512))
29
+ prompt = "A fantasy landscape, trending on artstation"
30
+ images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
31
+ print(images)
32
+ print(images[0])
33
+
34
+ buffered = BytesIO()
35
+ image.save(buffered, format="JPEG")
36
+ img_str = base64.b64encode(buffered.getvalue())
37
+
38
+ # images[0].save("fantasy_landscape.png")
39
+
40
+ return {
41
+ "image": img_str
42
+ }