master
MMaker 2023-02-02 23:04:15 -05:00
commit 960587a813
Signed by: mmaker
GPG Key ID: CCE79B8FEDA40FB2
7 changed files with 347 additions and 0 deletions

1
.gitignore vendored 100644
View File

@ -0,0 +1 @@
__pycache__

5
README.md 100644
View File

@ -0,0 +1,5 @@
# Additional Networks API
Extension for [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) that interfaces with [kohya-ss/sd-webui-additional-networks](https://github.com/kohya-ss/sd-webui-additional-networks) to add an API layer.
Currently this is only a barebones implementation to display informational previews for LoRA model cards in the extra networks menu on hover.

42
addnet_api/api.py 100644
View File

@ -0,0 +1,42 @@
# Borrowed from https://github.com/toriato/stable-diffusion-webui-wd14-tagger
from typing import Callable
from threading import Lock
from secrets import compare_digest
from modules import shared
from fastapi import FastAPI, Depends, HTTPException
from fastapi.security import HTTPBasic, HTTPBasicCredentials
class Api:
def __init__(self, app: FastAPI, queue_lock: Lock, prefix: str = '') -> None:
if shared.cmd_opts.api_auth:
self.credentials = dict()
for auth in shared.cmd_opts.api_auth.split(","):
user, password = auth.split(":")
self.credentials[user] = password
self.app = app
self.queue_lock = queue_lock
self.prefix = prefix
def auth(self, creds: HTTPBasicCredentials = Depends(HTTPBasic())):
if creds.username in self.credentials:
if compare_digest(creds.password, self.credentials[creds.username]):
return True
raise HTTPException(
status_code=401,
detail="Incorrect username or password",
headers={
"WWW-Authenticate": "Basic"
})
def add_api_route(self, path: str, endpoint: Callable, **kwargs):
if self.prefix:
path = f'{self.prefix}/{path}'
if shared.cmd_opts.api_auth:
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
return self.app.add_api_route(path, endpoint, **kwargs)

View File

@ -0,0 +1,8 @@
from pydantic import BaseModel, Field
class AddnetApiModelRequest(BaseModel):
model_path: str = Field(
title='Model',
description='The absolute path of the model.'
)

View File

@ -0,0 +1,200 @@
ADDNET_API_SELECTOR_LORA_CARDS = 'div[id$="_lora_cards"]';
ADDNET_API_ROOT = window.location.origin + '/addnet-api/v1';
ADDNET_API_ENDPOINT_LORA_METADATA = ADDNET_API_ROOT + '/lora/metadata';
BASE64_IMG = 'data:image/png;charset=utf-8;base64,';
const addNetAPI = async () => {
class AddnetTooltip {
constructor() {
this.tooltip = document.createElement('div');
this.tooltip.id = 'addnet_api_tooltip';
document.addEventListener('mousemove', (evt) => {
this.hover(evt);
})
const title = document.createElement('span');
title.id = 'addnet_api_tooltip_title';
const content = document.createElement('span');
content.id = 'addnet_api_tooltip_content';
const img = document.createElement('img');
const desc = document.createElement('span');
desc.id = 'addnet_api_tooltip_description';
const table = document.createElement('div')
table.id = 'addnet_api_tooltip_table';
const tableRoot = document.createElement('table');
for (const row of ['Keywords', 'Source', 'Rating', 'Tags', 'Author']) {
const tr = document.createElement('tr');
tr.id = 'addnet_api_tooltip_table_' + row.toLowerCase();
const rowTitle = document.createElement('td');
const rowValue = document.createElement('td');
rowTitle.textContent = row;
tr.appendChild(rowTitle);
tr.appendChild(rowValue);
tableRoot.appendChild(tr);
}
table.appendChild(tableRoot);
this.content = content;
this.title = title;
this.img = img;
this.desc = desc;
this.table = table;
this.tooltip.appendChild(title);
this.content.appendChild(img);
this.content.appendChild(desc);
this.content.appendChild(table);
this.tooltip.appendChild(content);
gradioApp().appendChild(this.tooltip);
}
hover(evt) {
// Adapted from https://github.com/ccd0/4chan-x/blob/master/src/General/UI.coffee
const height = this.tooltip.offsetHeight;
const width = this.tooltip.offsetWidth;
const top = Math.max(0, evt.clientY * (window.innerHeight - height) / window.innerHeight);
let threshold = window.innerWidth / 2;
let marginX = (evt.clientX <= threshold ? evt.clientX : window.innerWidth - evt.clientX) + 45;
marginX = Math.min(marginX, window.innerWidth - width);
marginX += "px";
if (evt.clientX <= threshold) {
this.tooltip.style.left = marginX;
this.tooltip.style.right = '';
} else {
this.tooltip.style.left = '';
this.tooltip.style.right = marginX;
}
this.tooltip.style.top = top + "px"
}
displayMetadata(metadata) {
this.title.textContent = metadata['ssmd_display_name'] || '';
this.img.src = `${BASE64_IMG}${metadata['ssmd_cover_images'][0]}` || '';
this.desc.textContent = metadata['ssmd_description'] || '';
for (const [key, value] of Object.entries(metadata)) {
if (key in ['ssmd_display_name', 'ssmd_cover_images', 'ssmd_description']) {
continue;
}
const name = key.split('ssmd_')[1];
const row = this.table.querySelector(`#addnet_api_tooltip_table_${name}`);
if (row && value && value !== '0') {
row.style.display = 'table-row';
row.children[1].textContent = value;
} else if (row) {
row.style.display = 'none';
}
}
this.content.style.display = 'inline-block';
}
displayNoMetadata() {
this.title.textContent = 'No metadata.'
this.content.style.display = 'none';
}
show() {
this.tooltip.style.display = 'block';
}
hide() {
this.tooltip.style.display = 'none';
}
}
async function fetchLoraMetadata(card, modelPath) {
const res = await fetch(
ADDNET_API_ENDPOINT_LORA_METADATA,
{
method: 'POST',
headers: {
'Accept': 'application/json',
'Content-Type': 'application/json'
},
body: JSON.stringify({
'model_path': modelPath
})
}
)
if (res.ok) {
const content = await res.json();
const metadata = Object.fromEntries(Object.entries(content).filter(([key]) => key.startsWith('ss_')));
const userMetadata = Object.fromEntries(Object.entries(content).filter(([key]) => key.startsWith('ssmd_')));
if (Object.keys(userMetadata).length > 0) {
if ('ssmd_cover_images' in userMetadata && userMetadata['ssmd_cover_images'].length > 6) {
userMetadata['ssmd_cover_images'] = JSON.parse(userMetadata['ssmd_cover_images']);
if (!card.style.backgroundImage) {
card.style.backgroundImage = `url("${BASE64_IMG}${userMetadata['ssmd_cover_images'][0]}")`
}
}
return userMetadata;
}
}
}
// Start
let timeout;
const tooltip = new AddnetTooltip();
const cards = gradioApp().querySelectorAll(`${ADDNET_API_SELECTOR_LORA_CARDS} .card`)
for (const card of cards) {
let modelPath = card.querySelector('.search_term').textContent;
if (modelPath) {
modelPath = modelPath.replace(/^[\\\/]+/,"");
card.addEventListener('mouseover', (evt) => {
timeout = setTimeout(async () => {
const metadata = await fetchLoraMetadata(card, modelPath);
if (metadata) {
tooltip.displayMetadata(metadata);
tooltip.show();
} else {
tooltip.displayNoMetadata();
tooltip.show();
}
}, 100);
})
card.addEventListener('mouseout', (evt) => {
tooltip.hide();
clearTimeout(timeout);
})
}
}
console.debug(`Added event listeners to ${cards.length} cards`);
};
document.addEventListener("DOMContentLoaded", () => {
const onload = () => {
if (gradioApp().querySelectorAll(ADDNET_API_SELECTOR_LORA_CARDS).length >= 2) {
addNetAPI();
} else {
setTimeout(onload, 100);
}
};
onload();
});

View File

@ -0,0 +1,49 @@
import json
import os
import sys
from pathlib import Path
from addnet_api import api_models as models
from addnet_api.api import Api
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from modules import script_callbacks, scripts, shared
from modules.call_queue import queue_lock
# NOTE: This is a really jank setup to import the model_util module,
# and to make sure we don't rescan a bogus models path,
# since model_util calls update_models at the top level.
extensions_dir = Path(os.path.abspath(__file__)).parents[2]
addnet_ext_dir = extensions_dir / 'sd-webui-additional-networks'
if addnet_ext_dir.is_dir():
current_basedir = scripts.current_basedir
scripts.current_basedir = str(addnet_ext_dir)
sys.path.append(str(extensions_dir))
from scripts import model_util
scripts.current_basedir = current_basedir
sys.path.remove(str(extensions_dir))
def get_lora_metadata(req: models.AddnetApiModelRequest):
if req.model_path:
model = Path(shared.cmd_opts.lora_dir) / req.model_path
if model.is_file():
metadata = model_util.read_model_metadata(str(model), 'LoRA')
return Response(content=json.dumps(metadata), media_type='application/json')
raise HTTPException(404, 'LoRA model not found')
def on_app_started(_, app: FastAPI):
api = Api(app, queue_lock, '/addnet-api/v1')
api.add_api_route(
'lora/metadata',
get_lora_metadata,
methods=['POST']
)
script_callbacks.on_app_started(on_app_started)

42
style.css 100644
View File

@ -0,0 +1,42 @@
#addnet_api_tooltip {
font-family: Source Sans Pro,ui-sans-serif,system-ui,-apple-system,BlinkMacSystemFont,Segoe UI,Roboto,Helvetica Neue,Arial,Noto Sans,sans-serif,"Apple Color Emoji","Segoe UI Emoji",Segoe UI Symbol,"Noto Color Emoji";
position: fixed;
display: none;
background: rgba(0,0,0,0.9);
border: 1px solid white;
padding: 0.5em;
z-index: 10000;
color: white;
user-select: none !important;
pointer-events: none !important;
}
#addnet_api_tooltip img {
display: block;
max-width: 480px;
max-height: 480px;
margin: 0 auto 0.5em;
}
#addnet_api_tooltip_title {
display: block;
font-size: 1.5em;
margin-bottom: 0.25em;
}
#addnet_api_tooltip_description {
margin-bottom: 1em;
}
td:nth-child(2n+1) {
font-weight: bold;
text-align: right;
}
td:nth-child(2n) {
padding-left: 1em;
}
tr.addnet_api_tooltip_table_source td:nth-child(2) {
line-break: anywhere;
}