62 lines
1.9 KiB
Python
62 lines
1.9 KiB
Python
import torch
|
|
import torchvision.transforms.functional as tf
|
|
import torchvision.transforms.v2 as v2
|
|
from .mmaker_color_enhance_core import color_enhance, color_blend
|
|
|
|
class ColorEnhanceComfyNode:
    """ComfyUI node that applies ``color_enhance`` to every frame of an IMAGE batch."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs: an IMAGE batch and a ``strength`` in [0, 1]."""
        return {
            "required": {
                "image": ("IMAGE",),
                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    # Must name the method below that the node executes.
    FUNCTION = "apply_color_enhance"
    CATEGORY = "postprocessing/Effects"

    def apply_color_enhance(self, image: torch.Tensor, strength: float):
        """Run ``color_enhance`` on each frame and restack the results.

        Args:
            image: batch of frames, iterated along dim 0. Presumably ComfyUI's
                (batch, height, width, channels) float layout in [0, 1] --
                the final ``permute(0, 2, 3, 1)`` assumes channels-last output.
            strength: passed through unchanged to ``color_enhance``.

        Returns:
            A one-element tuple holding the processed batch as a float tensor,
            channels-last, on the CPU.
        """
        # Saturate out-of-range floats first: ToDtype(scale=True) only
        # multiplies and casts, so values outside [0, 1] would overflow the
        # uint8 range instead of clipping to 0/255.
        if image.is_floating_point():
            image = image.clamp(0.0, 1.0)
        # Convert the whole batch once (one cast, one device transfer)
        # rather than once per frame.
        frames = v2.ToDtype(dtype=torch.uint8, scale=True)(image).detach().cpu().numpy()

        # Frames are taken straight from the batch without .squeeze():
        # squeezing would also drop a genuine size-1 height/width/channel
        # axis and hand color_enhance an array of the wrong rank.
        images = [tf.to_tensor(color_enhance(frame, strength)) for frame in frames]
        # to_tensor yields (C, H, W) per frame; permute the stack to (B, H, W, C).
        return (torch.stack(images).permute(0, 2, 3, 1),)
|
|
|
|
class ColorBlendComfyNode:
    """ComfyUI node that applies ``color_blend`` to each frame of ``image`` using ``image_blend``."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs: the IMAGE batch to process and the IMAGE to blend from."""
        return {
            "required": {
                "image": ("IMAGE",),
                "image_blend": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    # Must name the method below that the node executes; the method keeps its
    # historical "apply_color_enhance" name so the two stay in sync.
    FUNCTION = "apply_color_enhance"
    CATEGORY = "postprocessing/Effects"

    def apply_color_enhance(self, image: torch.Tensor, image_blend: torch.Tensor):
        """Run ``color_blend`` on each frame of ``image`` and restack the results.

        Args:
            image: batch of frames, iterated along dim 0. Presumably ComfyUI's
                (batch, height, width, channels) float layout in [0, 1].
            image_blend: reference frames. Frame ``i`` of ``image`` is paired
                with frame ``i`` of ``image_blend``; once ``image_blend`` runs
                out, its last frame is reused, so a single-frame batch applies
                to every frame. A 3-D tensor is treated as a single frame.

        Returns:
            A one-element tuple holding the processed batch as a float tensor,
            channels-last, on the CPU.

        Raises:
            ValueError: if ``image_blend`` contains no frames.
        """
        if image_blend.dim() == 3:
            image_blend = image_blend.unsqueeze(0)
        if image_blend.shape[0] == 0:
            raise ValueError("image_blend must contain at least one frame")

        to_uint8 = v2.ToDtype(dtype=torch.uint8, scale=True)

        def as_uint8_frames(batch):
            # Saturate out-of-range floats first: ToDtype(scale=True) only
            # multiplies and casts, so values outside [0, 1] would overflow
            # the uint8 range instead of clipping to 0/255.
            if batch.is_floating_point():
                batch = batch.clamp(0.0, 1.0)
            return to_uint8(batch).detach().cpu().numpy()

        frames = as_uint8_frames(image)
        # Blend frames are indexed explicitly instead of .squeeze()-ing the
        # whole batch, which only yields a single frame when the batch has
        # exactly one image and otherwise passes a 4-D array to color_blend.
        blend_frames = as_uint8_frames(image_blend)
        last_blend = len(blend_frames) - 1

        images = [
            tf.to_tensor(color_blend(frame, blend_frames[min(i, last_blend)]))
            for i, frame in enumerate(frames)
        ]
        # to_tensor yields (C, H, W) per frame; permute the stack to (B, H, W, C).
        return (torch.stack(images).permute(0, 2, 3, 1),)
|