Add classifier pipelines
parent
356c8e85f7
commit
cf8753896e
|
@ -1,2 +1,5 @@
|
|||
*.dll
|
||||
__pycache__
|
||||
|
||||
__pycache__/
|
||||
.vscode/
|
||||
test/
|
49
app.py
49
app.py
|
@ -1,12 +1,18 @@
|
|||
import argparse
|
||||
import operator
|
||||
import time
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from transformers import Pipeline, pipeline
|
||||
|
||||
import interrogator
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
EXTS = [
|
||||
'.jpg',
|
||||
'.jpeg',
|
||||
|
@ -18,13 +24,26 @@ EXTS = [
|
|||
|
||||
# Command-line options, parsed as soon as this module is executed.
#   -p / --path    : required string option.
#   --no-classify  : optional switch; args.no_classify is False unless given.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-p', '--path',
    type=str,
    required=True,
)
parser.add_argument(
    '--no-classify',
    dest='no_classify',
    default=False,
    action='store_true',
)
args = parser.parse_args()
|
||||
|
||||
def get_tags(model: interrogator.Interrogator, file: Path):
|
||||
def get_class(pipeline: Pipeline, file: Image.Image) -> dict[str, float]:
    """Classify one image and return its predictions as a label -> score map.

    Args:
        pipeline: Classifier to run; it is invoked as
            ``pipeline(file, top_k=5)`` and must yield an iterable of
            mappings that each carry a ``"label"`` and a ``"score"`` key.
        file: The image handed to the pipeline.

    Returns:
        A dict mapping each prediction's label to its score. If the same
        label occurs more than once, the last occurrence wins. An empty
        prediction list yields an empty dict.
    """
    predictions = pipeline(file, top_k=5)
    # The pipeline's declared return type is a broad union, hence the ignore.
    return {p["label"]: p["score"] for p in predictions}  # type: ignore
|
||||
|
||||
def get_tags(
|
||||
model: interrogator.Interrogator,
|
||||
file: Path,
|
||||
pipes: list[Pipeline]=[]
|
||||
):
|
||||
start = time.time()
|
||||
|
||||
tags = []
|
||||
res = model.interrogate(Image.open(file))
|
||||
img = Image.open(file)
|
||||
res = model.interrogate(img)
|
||||
if res:
|
||||
rating = max(res[0], key=res[0].get) # type: ignore
|
||||
pp_tags = list(model.postprocess_tags(
|
||||
|
@ -37,23 +56,39 @@ def get_tags(model: interrogator.Interrogator, file: Path):
|
|||
).keys())
|
||||
tags = [rating] + pp_tags
|
||||
|
||||
classes = [get_class(pipe, img) for pipe in pipes]
|
||||
classes = [max(c.items(), key=operator.itemgetter(1))[0].replace('_', ' ') for c in classes] # noqa: E501
|
||||
tags += classes
|
||||
|
||||
end = time.time()
|
||||
print(f"{end - start:.2f}s - {str(file)}")
|
||||
# print(f"{end - start:.2f}s - {str(file)}")
|
||||
return tags
|
||||
|
||||
if __name__ == '__main__':
|
||||
files = [p for p in Path(args.path).rglob('*') if p.suffix.lower() in EXTS]
|
||||
if not len(files) > 0:
|
||||
print('No files found, exiting')
|
||||
exit(1)
|
||||
if not click.confirm(f'Found {len(files)} files. Continue?', default=False):
|
||||
exit(1)
|
||||
|
||||
model = interrogator.WaifuDiffusionInterrogator(
|
||||
tagger = interrogator.WaifuDiffusionInterrogator(
|
||||
'wd14-swinv2-v2-git',
|
||||
repo_id='SmilingWolf/wd-v1-4-swinv2-tagger-v2'
|
||||
)
|
||||
model.load()
|
||||
tagger.load()
|
||||
pipes = []
|
||||
|
||||
for file in files:
|
||||
tags = get_tags(model, file)
|
||||
if not args.no_classify:
|
||||
print('Loading pipelines (ignore all config related errors)')
|
||||
pipes += [
|
||||
pipeline("image-classification", "cafeai/cafe_aesthetic", device="cuda:0"),
|
||||
pipeline("image-classification", "cafeai/cafe_style", device="cuda:0"),
|
||||
pipeline("image-classification", "cafeai/cafe_waifu", device="cuda:0"),
|
||||
]
|
||||
|
||||
for file in tqdm(files):
|
||||
tags = get_tags(tagger, file, pipes)
|
||||
if not len(tags) > 0:
|
||||
continue
|
||||
with open(file.parent.joinpath(f"{file.stem}.txt"), 'w+', encoding='utf8') as f:
|
||||
|
|
|
@ -7,3 +7,4 @@ pandas==2.0.1
|
|||
Pillow==9.5.0
|
||||
protobuf==4.23.2
|
||||
tqdm==4.65.0
|
||||
transformers==4.29.2
|
Loading…
Reference in New Issue