Add classifier pipelines

master
MMaker 2023-05-26 18:57:59 -04:00
parent 356c8e85f7
commit cf8753896e
Signed by: mmaker
GPG Key ID: CCE79B8FEDA40FB2
3 changed files with 47 additions and 8 deletions

5
.gitignore vendored
View File

@@ -1,2 +1,5 @@
*.dll
__pycache__
__pycache__/
.vscode/
test/

49
app.py
View File

@@ -1,12 +1,18 @@
import argparse
import operator
import time
import warnings
from pathlib import Path
import click
from PIL import Image
from tqdm import tqdm
from transformers import Pipeline, pipeline
import interrogator
warnings.filterwarnings("ignore")
EXTS = [
'.jpg',
'.jpeg',
@@ -18,13 +24,26 @@ EXTS = [
# Command-line interface. The parsed result is kept in the module-level `args`,
# which the `__main__` section reads (`args.path`, `args.no_classify`).
parser = argparse.ArgumentParser()
# Required: root directory that is searched recursively (via rglob) for image files.
parser.add_argument('-p', '--path', type=str, required=True)
# Optional switch: when given, the image-classification pipelines are not loaded.
parser.add_argument('--no-classify', dest='no_classify', default=False, action='store_true') # noqa: E501
# NOTE(review): this appears to run at module level, i.e. on import as well as
# when executed as a script - confirm against the real file's indentation.
args = parser.parse_args()
def get_tags(model: interrogator.Interrogator, file: Path):
def get_class(pipeline: Pipeline, file: Image.Image):
    """Run one classification pipeline on an image and map labels to scores.

    The pipeline is called as ``pipeline(file, top_k=5)``. Each entry it
    yields is indexed by ``"label"`` and ``"score"``; the result is a dict
    keyed by label. If a label occurs more than once, the later score
    replaces the earlier one.
    """
    predictions = pipeline(file, top_k=5)
    return {
        entry["label"]: entry["score"]  # type: ignore
        for entry in predictions  # type: ignore
    }
def get_tags(
model: interrogator.Interrogator,
file: Path,
pipes: list[Pipeline]=[]
):
start = time.time()
tags = []
res = model.interrogate(Image.open(file))
img = Image.open(file)
res = model.interrogate(img)
if res:
rating = max(res[0], key=res[0].get) # type: ignore
pp_tags = list(model.postprocess_tags(
@@ -37,23 +56,39 @@ def get_tags(model: interrogator.Interrogator, file: Path):
).keys())
tags = [rating] + pp_tags
classes = [get_class(pipe, img) for pipe in pipes]
classes = [max(c.items(), key=operator.itemgetter(1))[0].replace('_', ' ') for c in classes] # noqa: E501
tags += classes
end = time.time()
print(f"{end - start:.2f}s - {str(file)}")
# print(f"{end - start:.2f}s - {str(file)}")
return tags
if __name__ == '__main__':
files = [p for p in Path(args.path).rglob('*') if p.suffix.lower() in EXTS]
if not len(files) > 0:
print('No files found, exiting')
exit(1)
if not click.confirm(f'Found {len(files)} files. Continue?', default=False):
exit(1)
model = interrogator.WaifuDiffusionInterrogator(
tagger = interrogator.WaifuDiffusionInterrogator(
'wd14-swinv2-v2-git',
repo_id='SmilingWolf/wd-v1-4-swinv2-tagger-v2'
)
model.load()
tagger.load()
pipes = []
for file in files:
tags = get_tags(model, file)
if not args.no_classify:
print('Loading pipelines (ignore all config related errors)')
pipes += [
pipeline("image-classification", "cafeai/cafe_aesthetic", device="cuda:0"),
pipeline("image-classification", "cafeai/cafe_style", device="cuda:0"),
pipeline("image-classification", "cafeai/cafe_waifu", device="cuda:0"),
]
for file in tqdm(files):
tags = get_tags(tagger, file, pipes)
if not len(tags) > 0:
continue
with open(file.parent.joinpath(f"{file.stem}.txt"), 'w+', encoding='utf8') as f:

View File

@@ -7,3 +7,4 @@ pandas==2.0.1
Pillow==9.5.0
protobuf==4.23.2
tqdm==4.65.0
transformers==4.29.2