fffiloni KingNish commited on
Commit
4ab7724
1 Parent(s): 639cf35

Making it better and user centric (#3)

Browse files

- Added 1:1 image option. (cb28aaad9d15356c2e7bae1450bc9204f6dd9895)
- Most OP update (c5920e9e86bd07191f3eb35f95cec69113f536d9)


Co-authored-by: Nishith Jain <[email protected]>

Files changed (1) hide show
  1. app.py +85 -180
app.py CHANGED
@@ -47,184 +47,69 @@ pipe = StableDiffusionXLFillPipeline.from_pretrained(
47
 
48
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
49
 
50
- prompt = "high quality"
51
- (
52
- prompt_embeds,
53
- negative_prompt_embeds,
54
- pooled_prompt_embeds,
55
- negative_pooled_prompt_embeds,
56
- ) = pipe.encode_prompt(prompt, "cuda", True)
57
 
58
-
59
-
60
- """
61
- def fill_image(image, model_selection):
62
-
63
- margin = 256
64
- overlap = 24
65
- # Open the original image
66
- source = image # Changed from image["background"] to match new input format
67
-
68
- # Calculate new output size
69
- output_size = (source.width + 2*margin, source.height + 2*margin)
70
-
71
- # Create a white background
72
- background = Image.new('RGB', output_size, (255, 255, 255))
73
-
74
- # Calculate position to paste the original image
75
- position = (margin, margin)
76
-
77
- # Paste the original image onto the white background
78
- background.paste(source, position)
79
-
80
- # Create the mask
81
- mask = Image.new('L', output_size, 255) # Start with all white
 
 
 
82
  mask_draw = ImageDraw.Draw(mask)
83
  mask_draw.rectangle([
84
- (position[0] + overlap, position[1] + overlap),
85
- (position[0] + source.width - overlap, position[1] + source.height - overlap)
86
  ], fill=0)
87
-
88
- # Prepare the image for ControlNet
89
  cnet_image = background.copy()
90
  cnet_image.paste(0, (0, 0), mask)
91
 
 
 
 
 
 
 
 
 
 
 
 
92
  for image in pipe(
93
  prompt_embeds=prompt_embeds,
94
  negative_prompt_embeds=negative_prompt_embeds,
95
  pooled_prompt_embeds=pooled_prompt_embeds,
96
  negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
97
  image=cnet_image,
 
98
  ):
99
- yield image, cnet_image
100
 
101
  image = image.convert("RGBA")
102
  cnet_image.paste(image, (0, 0), mask)
103
 
104
  yield background, cnet_image
105
- """
106
 
107
- @spaces.GPU
108
- def infer(image, model_selection, ratio_choice, overlap_width):
109
-
110
- source = image
111
-
112
- if ratio_choice == "16:9":
113
- target_ratio = (16, 9) # Set the new target ratio to 16:9
114
- target_width = 1280 # Adjust target width based on desired resolution
115
- overlap = overlap_width
116
- #fade_width = 24
117
- max_height = 720 # Adjust max height instead of width
118
-
119
- # Resize the image if it's taller than max_height
120
- if source.height > max_height:
121
- scale_factor = max_height / source.height
122
- new_height = max_height
123
- new_width = int(source.width * scale_factor)
124
- source = source.resize((new_width, new_height), Image.LANCZOS)
125
-
126
- # Calculate the required width for the 16:9 ratio
127
- target_width = (source.height * target_ratio[0]) // target_ratio[1]
128
-
129
- # Calculate margins (now left and right)
130
- margin_x = (target_width - source.width) // 2
131
-
132
- # Calculate new output size
133
- output_size = (target_width, source.height)
134
-
135
- # Create a white background
136
- background = Image.new('RGB', output_size, (255, 255, 255))
137
-
138
- # Calculate position to paste the original image
139
- position = (margin_x, 0)
140
-
141
- # Paste the original image onto the white background
142
- background.paste(source, position)
143
-
144
- # Create the mask
145
- mask = Image.new('L', output_size, 255) # Start with all white
146
- mask_draw = ImageDraw.Draw(mask)
147
- mask_draw.rectangle([
148
- (margin_x + overlap, overlap),
149
- (margin_x + source.width - overlap, source.height - overlap)
150
- ], fill=0)
151
-
152
- # Prepare the image for ControlNet
153
- cnet_image = background.copy()
154
- cnet_image.paste(0, (0, 0), mask)
155
-
156
- for image in pipe(
157
- prompt_embeds=prompt_embeds,
158
- negative_prompt_embeds=negative_prompt_embeds,
159
- pooled_prompt_embeds=pooled_prompt_embeds,
160
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
161
- image=cnet_image,
162
- ):
163
- yield cnet_image, image
164
-
165
- image = image.convert("RGBA")
166
- cnet_image.paste(image, (0, 0), mask)
167
-
168
- yield background, cnet_image
169
-
170
- elif ratio_choice == "9:16":
171
-
172
- target_ratio=(9, 16)
173
- target_height=1280
174
- overlap=overlap_width
175
- #fade_width=24
176
- max_width = 720
177
- # Resize the image if it's wider than max_width
178
- if source.width > max_width:
179
- scale_factor = max_width / source.width
180
- new_width = max_width
181
- new_height = int(source.height * scale_factor)
182
- source = source.resize((new_width, new_height), Image.LANCZOS)
183
-
184
- # Calculate the required height for 9:16 ratio
185
- target_height = (source.width * target_ratio[1]) // target_ratio[0]
186
-
187
- # Calculate margins (only top and bottom)
188
- margin_y = (target_height - source.height) // 2
189
-
190
- # Calculate new output size
191
- output_size = (source.width, target_height)
192
-
193
- # Create a white background
194
- background = Image.new('RGB', output_size, (255, 255, 255))
195
-
196
- # Calculate position to paste the original image
197
- position = (0, margin_y)
198
-
199
- # Paste the original image onto the white background
200
- background.paste(source, position)
201
-
202
- # Create the mask
203
- mask = Image.new('L', output_size, 255) # Start with all white
204
- mask_draw = ImageDraw.Draw(mask)
205
- mask_draw.rectangle([
206
- (overlap, margin_y + overlap),
207
- (source.width - overlap, margin_y + source.height - overlap)
208
- ], fill=0)
209
-
210
- # Prepare the image for ControlNet
211
- cnet_image = background.copy()
212
- cnet_image.paste(0, (0, 0), mask)
213
-
214
- for image in pipe(
215
- prompt_embeds=prompt_embeds,
216
- negative_prompt_embeds=negative_prompt_embeds,
217
- pooled_prompt_embeds=pooled_prompt_embeds,
218
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
219
- image=cnet_image,
220
- ):
221
- yield cnet_image, image
222
-
223
- image = image.convert("RGBA")
224
- cnet_image.paste(image, (0, 0), mask)
225
-
226
- yield background, cnet_image
227
-
228
 
229
  def clear_result():
230
  return gr.update(value=None)
@@ -243,50 +128,61 @@ title = """<h1 align="center">Diffusers Image Outpaint</h1>
243
 
244
  with gr.Blocks(css=css) as demo:
245
  with gr.Column():
246
-
247
  gr.HTML(title)
248
 
249
  with gr.Row():
250
-
251
  with gr.Column():
252
-
253
  input_image = gr.Image(
254
  type="pil",
255
  label="Input Image",
256
  sources=["upload"],
257
  )
258
-
 
 
 
 
 
 
259
  with gr.Row():
260
- ratio = gr.Radio(
261
- label="Expected ratio",
262
- choices=["9:16", "16:9"],
263
- value = "9:16"
 
 
 
 
 
 
 
 
 
264
  )
265
  model_selection = gr.Dropdown(
266
  choices=list(MODELS.keys()),
267
  value="RealVisXL V5.0 Lightning",
268
  label="Model",
269
  )
 
270
 
271
  overlap_width = gr.Slider(
272
  label="Mask overlap width",
273
- minimum = 1,
274
- maximum = 50,
275
- value = 42,
276
- step = 1
277
  )
278
-
279
- run_button = gr.Button("Generate")
280
 
281
  gr.Examples(
282
- examples = [
283
- ["./examples/example_1.webp", "RealVisXL V5.0 Lightning", "16:9"],
284
- ["./examples/example_2.jpg", "RealVisXL V5.0 Lightning", "16:9"],
285
- ["./examples/example_3.jpg", "RealVisXL V5.0 Lightning", "9:16"]
286
  ],
287
- inputs = [input_image, model_selection, ratio]
288
  )
289
-
290
  with gr.Column():
291
  result = ImageSlider(
292
  interactive=False,
@@ -299,9 +195,18 @@ with gr.Blocks(css=css) as demo:
299
  outputs=result,
300
  ).then(
301
  fn=infer,
302
- inputs=[input_image, model_selection, ratio, overlap_width],
303
  outputs=result,
304
  )
305
 
 
 
 
 
 
 
 
 
 
306
 
307
- demo.launch(share=False)
 
47
 
48
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
49
 
 
 
 
 
 
 
 
50
 
51
+ @spaces.GPU
52
+ def infer(image, model_selection, width, height, overlap_width, num_inference_steps, prompt_input=None):
53
+ source = image
54
+ target_size = (width, height)
55
+ target_ratio = (width, height) # Calculate aspect ratio from width and height
56
+ overlap = overlap_width
57
+
58
+ # Upscale if source is smaller than target in both dimensions
59
+ if source.width < target_size[0] and source.height < target_size[1]:
60
+ scale_factor = min(target_size[0] / source.width, target_size[1] / source.height)
61
+ new_width = int(source.width * scale_factor)
62
+ new_height = int(source.height * scale_factor)
63
+ source = source.resize((new_width, new_height), Image.LANCZOS)
64
+
65
+ if source.width > target_size[0] or source.height > target_size[1]:
66
+ scale_factor = min(target_size[0] / source.width, target_size[1] / source.height)
67
+ new_width = int(source.width * scale_factor)
68
+ new_height = int(source.height * scale_factor)
69
+ source = source.resize((new_width, new_height), Image.LANCZOS)
70
+
71
+ margin_x = (target_size[0] - source.width) // 2
72
+ margin_y = (target_size[1] - source.height) // 2
73
+
74
+ background = Image.new('RGB', target_size, (255, 255, 255))
75
+ background.paste(source, (margin_x, margin_y))
76
+
77
+ mask = Image.new('L', target_size, 255)
78
  mask_draw = ImageDraw.Draw(mask)
79
  mask_draw.rectangle([
80
+ (margin_x + overlap, margin_y + overlap),
81
+ (margin_x + source.width - overlap, margin_y + source.height - overlap)
82
  ], fill=0)
83
+
 
84
  cnet_image = background.copy()
85
  cnet_image.paste(0, (0, 0), mask)
86
 
87
+ final_prompt = "high quality"
88
+ if prompt_input.strip() != "":
89
+ final_prompt += ", " + prompt_input
90
+
91
+ (
92
+ prompt_embeds,
93
+ negative_prompt_embeds,
94
+ pooled_prompt_embeds,
95
+ negative_pooled_prompt_embeds,
96
+ ) = pipe.encode_prompt(final_prompt, "cuda", True)
97
+
98
  for image in pipe(
99
  prompt_embeds=prompt_embeds,
100
  negative_prompt_embeds=negative_prompt_embeds,
101
  pooled_prompt_embeds=pooled_prompt_embeds,
102
  negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
103
  image=cnet_image,
104
+ num_inference_steps=num_inference_steps
105
  ):
106
+ yield cnet_image, image
107
 
108
  image = image.convert("RGBA")
109
  cnet_image.paste(image, (0, 0), mask)
110
 
111
  yield background, cnet_image
 
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  def clear_result():
115
  return gr.update(value=None)
 
128
 
129
  with gr.Blocks(css=css) as demo:
130
  with gr.Column():
 
131
  gr.HTML(title)
132
 
133
  with gr.Row():
 
134
  with gr.Column():
 
135
  input_image = gr.Image(
136
  type="pil",
137
  label="Input Image",
138
  sources=["upload"],
139
  )
140
+
141
+ with gr.Row():
142
+ with gr.Column(scale=2):
143
+ prompt_input = gr.Textbox(label="Prompt (Optional)")
144
+ with gr.Column(scale=1):
145
+ run_button = gr.Button("Generate")
146
+
147
  with gr.Row():
148
+ width_slider = gr.Slider(
149
+ label="Width",
150
+ minimum=720,
151
+ maximum=1440,
152
+ step=8,
153
+ value=1440, # Set a default value
154
+ )
155
+ height_slider = gr.Slider(
156
+ label="Height",
157
+ minimum=720,
158
+ maximum=1440,
159
+ step=8,
160
+ value=1024, # Set a default value
161
  )
162
  model_selection = gr.Dropdown(
163
  choices=list(MODELS.keys()),
164
  value="RealVisXL V5.0 Lightning",
165
  label="Model",
166
  )
167
+ num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8 )
168
 
169
  overlap_width = gr.Slider(
170
  label="Mask overlap width",
171
+ minimum=1,
172
+ maximum=50,
173
+ value=42,
174
+ step=1
175
  )
 
 
176
 
177
  gr.Examples(
178
+ examples=[
179
+ ["./examples/example_1.webp", "RealVisXL V5.0 Lightning", 1280, 720],
180
+ ["./examples/example_2.jpg", "RealVisXL V5.0 Lightning", 720, 1280],
181
+ ["./examples/example_3.jpg", "RealVisXL V5.0 Lightning", 1024, 1024],
182
  ],
183
+ inputs=[input_image, model_selection, width_slider, height_slider],
184
  )
185
+
186
  with gr.Column():
187
  result = ImageSlider(
188
  interactive=False,
 
195
  outputs=result,
196
  ).then(
197
  fn=infer,
198
+ inputs=[input_image, model_selection, width_slider, height_slider, overlap_width, num_inference_steps, prompt_input],
199
  outputs=result,
200
  )
201
 
202
+ prompt_input.submit(
203
+ fn=clear_result,
204
+ inputs=None,
205
+ outputs=result,
206
+ ).then(
207
+ fn=infer,
208
+ inputs=[input_image, model_selection, width_slider, height_slider, overlap_width, num_inference_steps, prompt_input],
209
+ outputs=result,
210
+ )
211
 
212
+ demo.launch(share=False)