|
|
|
@ -0,0 +1,158 @@
|
|
|
|
|
import copy
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
import gradio as gr
|
|
|
|
|
from torch import nn
|
|
|
|
|
|
|
|
|
|
from modules import scripts, shared
|
|
|
|
|
from modules.processing import StableDiffusionProcessing
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BlessVaeScript(scripts.Script):
|
|
|
|
|
def __init__(self) -> None:
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.conv_out_weight_original = None
|
|
|
|
|
self.conv_out_bias_original = None
|
|
|
|
|
|
|
|
|
|
def title(self):
|
|
|
|
|
return 'Bless VAE'
|
|
|
|
|
|
|
|
|
|
def show(self, is_img2img):
|
|
|
|
|
return scripts.AlwaysVisible
|
|
|
|
|
|
|
|
|
|
def ui(self, is_img2img):
|
|
|
|
|
enabled, contrast_op, contrast_value, brightness_op, brightness_value = self._create_ui()
|
|
|
|
|
|
|
|
|
|
self.infotext_fields = (
|
|
|
|
|
(enabled, lambda x: gr.Checkbox.update(value='Bless VAE enabled' in x)),
|
|
|
|
|
(contrast_op, 'Bless VAE contrast op'),
|
|
|
|
|
(contrast_value, 'Bless VAE contrast value'),
|
|
|
|
|
(brightness_op, 'Bless VAE brightness op'),
|
|
|
|
|
(brightness_value, 'Bless VAE contrast value'),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return [
|
|
|
|
|
enabled,
|
|
|
|
|
contrast_op,
|
|
|
|
|
contrast_value,
|
|
|
|
|
brightness_op,
|
|
|
|
|
brightness_value
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
def process(
|
|
|
|
|
self,
|
|
|
|
|
p: StableDiffusionProcessing,
|
|
|
|
|
bless_enabled: bool,
|
|
|
|
|
bless_contrast_op: str,
|
|
|
|
|
bless_contrast_value: float,
|
|
|
|
|
bless_brightness_op: str,
|
|
|
|
|
bless_brightness_value: float,
|
|
|
|
|
**kwargs
|
|
|
|
|
):
|
|
|
|
|
if bless_enabled and p.sd_model:
|
|
|
|
|
self.conv_out_weight_original = copy.deepcopy(p.sd_model.first_stage_model.decoder.conv_out.weight)
|
|
|
|
|
self.conv_out_bias_original = copy.deepcopy(p.sd_model.first_stage_model.decoder.conv_out.bias)
|
|
|
|
|
|
|
|
|
|
contrast_op = getattr(p, 'bless_contast_op', bless_contrast_op)
|
|
|
|
|
contrast_value = getattr(p, 'bless_contrast_value', bless_contrast_value)
|
|
|
|
|
brightness_op = getattr(p, 'bless_brightness_op', bless_brightness_op)
|
|
|
|
|
brightness_value = getattr(p, 'bless_brightness_value', bless_brightness_value)
|
|
|
|
|
|
|
|
|
|
p.extra_generation_params["Bless VAE enabled"] = True
|
|
|
|
|
p.extra_generation_params["Bless VAE contrast op"] = contrast_op
|
|
|
|
|
p.extra_generation_params["Bless VAE contrast value"] = contrast_value
|
|
|
|
|
p.extra_generation_params["Bless VAE brightness op"] = brightness_op
|
|
|
|
|
p.extra_generation_params["Bless VAE brightness value"] = brightness_value
|
|
|
|
|
|
|
|
|
|
if contrast_op == 'Add':
|
|
|
|
|
p.sd_model.first_stage_model.decoder.conv_out.weight = nn.Parameter(
|
|
|
|
|
self.conv_out_weight_original + contrast_value
|
|
|
|
|
)
|
|
|
|
|
elif contrast_op == 'Multiply':
|
|
|
|
|
p.sd_model.first_stage_model.decoder.conv_out.weight = nn.Parameter(
|
|
|
|
|
self.conv_out_weight_original * contrast_value
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if brightness_op == 'Add':
|
|
|
|
|
p.sd_model.first_stage_model.decoder.conv_out.bias = nn.Parameter(
|
|
|
|
|
self.conv_out_bias_original + brightness_value
|
|
|
|
|
)
|
|
|
|
|
elif brightness_op == 'Multiply':
|
|
|
|
|
p.sd_model.first_stage_model.decoder.conv_out.bias = nn.Parameter(
|
|
|
|
|
self.conv_out_bias_original * brightness_value
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def postprocess(
|
|
|
|
|
self,
|
|
|
|
|
p: StableDiffusionProcessing,
|
|
|
|
|
bless_enabled: bool,
|
|
|
|
|
bless_contrast_op: str,
|
|
|
|
|
bless_contrast_value: float,
|
|
|
|
|
bless_brightness_op: str,
|
|
|
|
|
bless_brightness_value: float,
|
|
|
|
|
*args
|
|
|
|
|
):
|
|
|
|
|
if bless_enabled and p.sd_model is not None and self.conv_out_weight_original is not None and self.conv_out_bias_original is not None:
|
|
|
|
|
p.sd_model.first_stage_model.decoder.conv_out.weight = self.conv_out_weight_original
|
|
|
|
|
p.sd_model.first_stage_model.decoder.conv_out.bias = self.conv_out_bias_original
|
|
|
|
|
|
|
|
|
|
def _create_ui(self):
|
|
|
|
|
with gr.Group():
|
|
|
|
|
with gr.Accordion('Bless VAE', open=False):
|
|
|
|
|
with gr.Row():
|
|
|
|
|
enabled = gr.Checkbox(label='Enabled', value=False)
|
|
|
|
|
|
|
|
|
|
with gr.Row():
|
|
|
|
|
with gr.Column():
|
|
|
|
|
contrast_op = gr.Radio(['Add', 'Multiply'], value='Multiply', label='Contrast Operation', interactive=True)
|
|
|
|
|
contrast_value = gr.Slider(minimum=-2, maximum=2, value=1, step=0.1, label='Contrast Value')
|
|
|
|
|
with gr.Column():
|
|
|
|
|
brightness_op = gr.Radio(['Add', 'Multiply'], value='Add', label='Brightness Operation', interactive=True)
|
|
|
|
|
brightness_value = gr.Slider(minimum=-2, maximum=2, value=0, step=0.1, label='Brightness Value')
|
|
|
|
|
|
|
|
|
|
return enabled, contrast_op, contrast_value, brightness_op, brightness_value
|
|
|
|
|
|
|
|
|
|
def xyz():
|
|
|
|
|
for scriptDataTuple in scripts.scripts_data:
|
|
|
|
|
if os.path.basename(scriptDataTuple.path) == "xyz_grid.py":
|
|
|
|
|
xy_grid = scriptDataTuple.module
|
|
|
|
|
|
|
|
|
|
def confirm_mode(p, xs):
|
|
|
|
|
for x in xs:
|
|
|
|
|
if x not in ['Add', 'Multiply']:
|
|
|
|
|
raise RuntimeError(f'Invalid op: {x}')
|
|
|
|
|
|
|
|
|
|
contrast_op = xy_grid.AxisOption(
|
|
|
|
|
'[Bless VAE] Contrast Operation',
|
|
|
|
|
str,
|
|
|
|
|
xy_grid.apply_field('bless_contrast_op'),
|
|
|
|
|
confirm=confirm_mode
|
|
|
|
|
)
|
|
|
|
|
contrast_value = xy_grid.AxisOption(
|
|
|
|
|
'[Bless VAE] Contrast Value',
|
|
|
|
|
float,
|
|
|
|
|
xy_grid.apply_field('bless_contrast_value')
|
|
|
|
|
)
|
|
|
|
|
brightness_op = xy_grid.AxisOption(
|
|
|
|
|
'[Bless VAE] Brightness Operation',
|
|
|
|
|
str,
|
|
|
|
|
xy_grid.apply_field('bless_brightness_op'),
|
|
|
|
|
confirm=confirm_mode
|
|
|
|
|
)
|
|
|
|
|
brightness_value = xy_grid.AxisOption(
|
|
|
|
|
'[Bless VAE] Brightness Value',
|
|
|
|
|
float,
|
|
|
|
|
xy_grid.apply_field('bless_brightness_value')
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
xy_grid.axis_options.extend([
|
|
|
|
|
contrast_op,
|
|
|
|
|
contrast_value,
|
|
|
|
|
brightness_op,
|
|
|
|
|
brightness_value
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
xyz()
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print(f'Error trying to add XYZ plot options for Bless VAE', e)
|