92 lines
2.6 KiB
Python
92 lines
2.6 KiB
Python
import argparse
import operator
import sys
import warnings
from pathlib import Path
from typing import Optional

import click
from PIL import Image
from tqdm import tqdm
from transformers import Pipeline, pipeline

import interrogator

# Suppress every Python warning for the whole process (presumably to keep the
# noisy model-loading output readable).
warnings.filterwarnings("ignore")

# Image file extensions (lower-case, with leading dot) that will be tagged;
# files are matched by comparing against ``Path.suffix.lower()``.
EXTS = [
    '.jpg',
    '.jpeg',
    '.webp',
    '.png',
    '.tif',
    '.tiff',
]

# NOTE: arguments are parsed at import time, so importing this module from
# other code will also consume sys.argv.
parser = argparse.ArgumentParser(
    description='Write a .txt tag file next to every image found under PATH.',
)
parser.add_argument(
    '-p', '--path', type=str, required=True,
    help='directory to search recursively for images',
)
parser.add_argument(
    '--no-classify', dest='no_classify', action='store_true',
    help='skip the extra image-classification pipelines',
)
args = parser.parse_args()


def get_class(pipeline: "Pipeline", file: "Image.Image", top_k: int = 5) -> dict:
    """Run an image-classification pipeline and map each label to its score.

    Args:
        pipeline: An image-classification pipeline (any callable accepting
            ``(image, top_k=...)`` and returning a list of
            ``{"label": ..., "score": ...}`` dicts). The name shadows the
            imported ``pipeline`` factory inside this function; it is kept
            for backward compatibility with keyword callers.
        file: The image to classify.
        top_k: Number of highest-scoring labels to request from the pipeline.
            Defaults to 5, the previously hard-coded value.

    Returns:
        A dict mapping each returned label to its score. If a label repeats,
        the last occurrence wins.
    """
    predictions = pipeline(file, top_k=top_k)
    return {pred["label"]: pred["score"] for pred in predictions}  # type: ignore


def get_tags(
    model: interrogator.Interrogator,
    file: Path,
    pipes: Optional[list[Pipeline]] = None,
    threshold: float = 0.35,
) -> list:
    """Tag one image with the interrogator and optional classifier pipelines.

    Args:
        model: A loaded interrogator. Its ``interrogate`` result is indexed
            as ``res[0]`` (rating label -> score) and ``res[1]`` (raw tags
            passed to ``postprocess_tags``). This shape is inferred from
            usage below; confirm against ``interrogator``.
        file: Path of the image to open.
        pipes: Image-classification pipelines. For each one, the
            highest-scoring label (underscores replaced by spaces) is
            appended to the tags. ``None`` (the default) means no
            classification.
        threshold: Score threshold forwarded to ``postprocess_tags``.
            Defaults to the previously hard-coded 0.35.

    Returns:
        ``[rating, *tags, *classes]`` when the interrogator returns a result,
        otherwise just the classifier labels (possibly an empty list).
    """
    # Avoid a shared mutable default argument.
    if pipes is None:
        pipes = []

    tags = []
    # Keep the image open for both interrogation and classification, then
    # close the underlying file handle deterministically.
    with Image.open(file) as img:
        res = model.interrogate(img)
        if res:
            # Pick the rating label with the highest score.
            rating = max(res[0], key=res[0].get)  # type: ignore
            pp_tags = list(model.postprocess_tags(
                res[1],
                threshold=threshold,
                additional_tags=[],
                sort_by_alphabetical_order=False,
                replace_underscore=True,
                escape_tag=False,
            ).keys())
            tags = [rating] + pp_tags

        if pipes:
            classes = [get_class(pipe, img) for pipe in pipes]
            tags += [
                max(c.items(), key=operator.itemgetter(1))[0].replace('_', ' ')
                for c in classes
            ]

    return tags


if __name__ == '__main__':
    # Collect every image file under --path (recursively) in a stable order;
    # is_file() excludes directories whose names happen to end in ".png" etc.
    files = sorted(
        p for p in Path(args.path).rglob('*')
        if p.is_file() and p.suffix.lower() in EXTS
    )
    if not files:
        print('No files found, exiting')
        sys.exit(1)
    if not click.confirm(f'Found {len(files)} files. Continue?', default=False):
        sys.exit(1)

    tagger = interrogator.WaifuDiffusionInterrogator(
        'wd14-swinv2-v2-git',
        repo_id='SmilingWolf/wd-v1-4-swinv2-tagger-v2',
    )
    tagger.load()

    pipes = []
    if not args.no_classify:
        print('Loading pipelines (ignore all config related errors)')
        # Order matters: it determines the order of the class tags written.
        pipes += [
            pipeline("image-classification", model_id, device='cuda:0')
            for model_id in (
                "cafeai/cafe_aesthetic",
                "cafeai/cafe_style",
                "cafeai/cafe_waifu",
            )
        ]

    for file in tqdm(files):
        tags = get_tags(tagger, file, pipes)
        if not tags:
            continue
        # Write tags next to the image, one per line (foo.png -> foo.txt).
        with open(file.with_suffix('.txt'), 'w', encoding='utf8') as f:
            f.write('\n'.join(tags))