master
MMaker 2023-05-26 19:04:46 -04:00
parent cf8753896e
commit 62d199cba9
Signed by: mmaker
GPG Key ID: CCE79B8FEDA40FB2
1 changed files with 7 additions and 10 deletions

17
app.py
View File

@ -1,6 +1,5 @@
import argparse
import operator
import time
import warnings
from pathlib import Path
@ -39,7 +38,6 @@ def get_tags(
file: Path,
pipes: list[Pipeline]=[]
):
start = time.time()
tags = []
img = Image.open(file)
@ -56,12 +54,11 @@ def get_tags(
).keys())
tags = [rating] + pp_tags
classes = [get_class(pipe, img) for pipe in pipes]
classes = [max(c.items(), key=operator.itemgetter(1))[0].replace('_', ' ') for c in classes] # noqa: E501
tags += classes
if len(pipes) > 0:
classes = [get_class(pipe, img) for pipe in pipes]
classes = [max(c.items(), key=operator.itemgetter(1))[0].replace('_', ' ') for c in classes] # noqa: E501
tags += classes
end = time.time()
# print(f"{end - start:.2f}s - {str(file)}")
return tags
if __name__ == '__main__':
@ -82,9 +79,9 @@ if __name__ == '__main__':
if not args.no_classify:
print('Loading pipelines (ignore all config related errors)')
pipes += [
pipeline("image-classification", "cafeai/cafe_aesthetic", device="cuda:0"),
pipeline("image-classification", "cafeai/cafe_style", device="cuda:0"),
pipeline("image-classification", "cafeai/cafe_waifu", device="cuda:0"),
pipeline("image-classification", "cafeai/cafe_aesthetic", device='cuda:0'),
pipeline("image-classification", "cafeai/cafe_style", device='cuda:0'),
pipeline("image-classification", "cafeai/cafe_waifu", device='cuda:0'),
]
for file in tqdm(files):