wd-tagger/app.py

92 lines
2.6 KiB
Python

import argparse
import operator
import warnings
from pathlib import Path
import click
from PIL import Image
from tqdm import tqdm
from transformers import Pipeline, pipeline
import interrogator
# Suppress every Python warning for the whole process.
warnings.filterwarnings("ignore")

# Image extensions (lower-case, leading dot) accepted when scanning the input
# directory; candidates are matched via Path.suffix.lower().
EXTS = ['.jpg', '.jpeg', '.webp', '.png', '.tif', '.tiff']
parser = argparse.ArgumentParser(
    description='Tag every image found recursively under a directory and '
                'write the tags to a .txt file next to each image.',
)
parser.add_argument(
    '-p', '--path',
    type=str,
    required=True,
    help='directory to scan recursively for images',
)
parser.add_argument(
    '--no-classify',
    dest='no_classify',
    default=False,
    action='store_true',
    help='skip the additional image-classification pipelines',
)
# NOTE: arguments are parsed at import time, so importing this module from
# anywhere other than the command line still requires a valid argv.
args = parser.parse_args()
def get_class(pipeline: "Pipeline", file: "Image.Image", top_k: int = 5) -> dict:
    """Run an image-classification pipeline on one image and map labels to scores.

    Args:
        pipeline: A callable classification pipeline, invoked as
            ``pipeline(file, top_k=...)`` and expected to return an iterable
            of mappings with ``"label"`` and ``"score"`` keys. (Inside this
            function the name shadows the module-level ``pipeline`` factory.)
        file: The already-opened image to classify.
        top_k: Number of highest-scoring labels to request from the pipeline.
            Defaults to 5, the value that was previously hard-coded.

    Returns:
        A dict ``{label: score}`` built from the pipeline output. If a label
        appears more than once, the later score wins.
    """
    data = pipeline(file, top_k=top_k)
    return {d["label"]: d["score"] for d in data}  # type: ignore
def get_tags(
    model: "interrogator.Interrogator",
    file: Path,
    pipes: "list[Pipeline] | None" = None,
    threshold: float = 0.35,
):
    """Compute the list of tags for a single image file.

    Args:
        model: Interrogator used for tagging; must provide ``interrogate(img)``
            and ``postprocess_tags(...)``.
        file: Path of the image to open.
        pipes: Optional classification pipelines. For each one, the label with
            the highest score (underscores replaced by spaces) is appended.
            ``None`` (the default) means no pipelines.
        threshold: Value passed as ``threshold`` to ``model.postprocess_tags``.
            Defaults to 0.35, the value that was previously hard-coded.

    Returns:
        A list of tag strings, in this order:

        1. The highest-valued key of the first interrogation result
           (presumably the rating — TODO confirm against ``interrogator``).
        2. The post-processed tag names.
        3. One top label per pipeline.

        If ``interrogate`` returns a falsy result, only the pipeline labels
        are returned, which may be an empty list.
    """
    # Avoid a shared mutable default argument.
    if pipes is None:
        pipes = []

    tags = []
    # Use a context manager so the underlying file handle is released.
    # A bare Image.open() can keep it open (e.g. multi-frame images), and
    # handles pile up when tagging a whole directory tree.
    with Image.open(file) as img:
        res = model.interrogate(img)
        if res:
            # res[0] is used as a mapping: pick its key with the largest value.
            rating = max(res[0], key=res[0].get)  # type: ignore
            pp_tags = list(model.postprocess_tags(
                res[1],
                threshold=threshold,
                additional_tags=[],
                sort_by_alphabetical_order=False,
                replace_underscore=True,
                escape_tag=False,
            ).keys())
            tags = [rating] + pp_tags

        for pipe in pipes:
            scores = get_class(pipe, img)
            # An empty result has no top label; skip it rather than letting
            # max() raise ValueError and abort the whole run.
            if not scores:
                continue
            best = max(scores.items(), key=operator.itemgetter(1))[0]
            tags.append(best.replace('_', ' '))
    return tags
if __name__ == '__main__':
    import sys

    # Keep only regular files with a supported extension. is_file() skips
    # directories whose names end in an image extension (Image.open cannot
    # read those). Sorting makes the processing order reproducible.
    files = sorted(
        p for p in Path(args.path).rglob('*')
        if p.suffix.lower() in EXTS and p.is_file()
    )
    if not files:
        print('No files found, exiting')
        sys.exit(1)
    if not click.confirm(f'Found {len(files)} files. Continue?', default=False):
        sys.exit(1)

    tagger = interrogator.WaifuDiffusionInterrogator(
        'wd14-swinv2-v2-git',
        repo_id='SmilingWolf/wd-v1-4-swinv2-tagger-v2',
    )
    tagger.load()

    # Extra classifiers whose top label is appended to each image's tags.
    pipes = []
    if not args.no_classify:
        print('Loading pipelines (ignore all config related errors)')
        pipes = [
            pipeline("image-classification", model_id, device='cuda:0')
            for model_id in (
                "cafeai/cafe_aesthetic",
                "cafeai/cafe_style",
                "cafeai/cafe_waifu",
            )
        ]

    for file in tqdm(files):
        try:
            tags = get_tags(tagger, file, pipes)
        except OSError as err:
            # Unreadable or corrupt image (Pillow's UnidentifiedImageError is
            # an OSError subclass). Report it and continue instead of
            # aborting a long batch run.
            tqdm.write(f'Skipping {file}: {err}')
            continue
        if not tags:
            continue
        # Write the tags, one per line, to <image stem>.txt next to the image.
        with open(file.with_suffix('.txt'), 'w', encoding='utf8') as f:
            f.write('\n'.join(tags))