Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

馃殌feat: WIP Custom Parameters for Stable Diffusion Profiles #2433

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 31 additions & 1 deletion api/app/clients/tools/structured/StableDiffusion.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ const sharp = require('sharp');
const { v4: uuidv4 } = require('uuid');
const { StructuredTool } = require('langchain/tools');
const { FileContext } = require('librechat-data-provider');
const { Preference } = require('~/models');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const paths = require('~/config/paths');
const { logger } = require('~/config');

Expand Down Expand Up @@ -75,7 +77,19 @@ class StableDiffusionAPI extends StructuredTool {
}

async _call(data) {
const url = this.url;
const customConfig = await getCustomConfig();
const sdConfig = customConfig.tools.stableDiffusion;
//const sdProfileName = "SDXL Turbo";
const prefName = 'sdProfile';
const prefUser = this.userId;
const sdPreference = await Preference.findOne({ prefUser, prefName }).lean();
console.log('sdProfile',sdPreference);
//const sdProfileName = sdProfile.name
//console.log('sdProfileName',sdProfileName);
const sdProfileObject = sdConfig.filter(obj => obj.name === sdPreference.value);
const sdProfileUrl = sdProfileObject[0].webUI;
console.log('sdProfileUrl:',sdProfileUrl);
const payloadParameters = Object.keys(sdProfileObject[0].parameters);
const { prompt, negative_prompt } = data;
const payload = {
prompt,
Expand All @@ -86,6 +100,22 @@ class StableDiffusionAPI extends StructuredTool {
width: 1024,
height: 1024,
};
for (const parameter of payloadParameters) {
payload[parameter] = sdProfileObject[0].parameters[parameter];
}
console.log('payload',payload);
//console.log('paylaod',payloadParameters);
console.log('################ CHECK #####################');
// Check if webUI is defined, if so use that over env variable
let url = '';
if (sdProfileUrl) {
url = sdProfileUrl;
console.log('url is from librechat.yaml');
} else {
url = this.url;
console.log('url is from env variable');
}

const generationResponse = await axios.post(`${url}/sdapi/v1/txt2img`, payload);
const image = generationResponse.data.images[0];

Expand Down
4 changes: 4 additions & 0 deletions api/models/Preference.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
const mongoose = require('mongoose');
const preferenceSchema = require('./schema/preference');

module.exports = mongoose.model('Preference', preferenceSchema);
2 changes: 2 additions & 0 deletions api/models/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@ const {
updateFileUsage,
} = require('./File');
const Key = require('./Key');
const Preference = require('./Preference');
const User = require('./User');
const Session = require('./Session');
const Balance = require('./Balance');

module.exports = {
User,
Key,
Preference,
Session,
Balance,

Expand Down
19 changes: 19 additions & 0 deletions api/models/schema/preference.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
const mongoose = require('mongoose');

const preferenceSchema = mongoose.Schema({
userId: {
type: mongoose.Schema.Types.ObjectId,
ref: 'User',
required: true,
},
name: {
type: String,
required: true,
},
value: {
type: String,
required: true,
},
});

module.exports = preferenceSchema;
1 change: 1 addition & 0 deletions api/server/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ const startServer = async () => {
// API Endpoints
app.use('/api/auth', routes.auth);
app.use('/api/keys', routes.keys);
app.use('/api/preferences', routes.preferences);
app.use('/api/user', routes.user);
app.use('/api/search', routes.search);
app.use('/api/ask', routes.ask);
Expand Down
1 change: 1 addition & 0 deletions api/server/routes/config.js
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ router.get('/', async function (req, res) {
process.env.SHOW_BIRTHDAY_ICON === '',
helpAndFaqURL: process.env.HELP_AND_FAQ_URL || 'https://librechat.ai',
interface: req.app.locals.interface,
tools: req.app.locals.tools,
};

if (typeof process.env.CUSTOM_FOOTER === 'string') {
Expand Down
2 changes: 2 additions & 0 deletions api/server/routes/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ const search = require('./search');
const tokenizer = require('./tokenizer');
const auth = require('./auth');
const keys = require('./keys');
const preferences = require('./preferences');
const oauth = require('./oauth');
const endpoints = require('./endpoints');
const balance = require('./balance');
Expand All @@ -29,6 +30,7 @@ module.exports = {
prompts,
auth,
keys,
preferences,
oauth,
user,
tokenizer,
Expand Down
21 changes: 21 additions & 0 deletions api/server/routes/preferences.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
const express = require('express');
const router = express.Router();
const { updatePreference, getPreference } = require('../services/UserService');
const { requireJwtAuth } = require('../middleware/');

router.put('/', requireJwtAuth, async (req, res) => {
console.log('pref.js PUT BEFORE',req.user.id, req.body);
await updatePreference({ userId: req.user.id, ...req.body });
console.log('pref.js PUT AFTER',req.user.id, req.body);
res.status(201).send();
});

router.get('/', requireJwtAuth, async (req, res) => {
const { name } = req.query;
console.log('pref.js GET BEFORE', name);
const response = await getPreference({ userId: req.user.id, name });
console.log('pref.js GET AFTER', name);
res.status(200).send(response);
});

module.exports = router;
1 change: 1 addition & 0 deletions api/server/services/AppService.js
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ const AppService = async (app) => {
fileStrategy,
fileConfig: config?.fileConfig,
interface: config?.interface,
tools: config?.tools,
paths,
...endpointLocals,
};
Expand Down
25 changes: 24 additions & 1 deletion api/server/services/UserService.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
const { User, Key } = require('~/models');
const { User, Key, Preference } = require('~/models');
const { encrypt, decrypt } = require('~/server/utils');
const { logger } = require('~/config');

Expand Down Expand Up @@ -68,11 +68,34 @@ const checkUserKeyExpiry = (expiresAt, message) => {
}
};

const getPreference = async ({ userId, name }) => {
const preferenceValue = await Preference.findOne({ userId, name }).lean();
console.log('getPrefName',name);
console.log('getPrefValue',preferenceValue);
if (preferenceValue) {
return preferenceValue.value;
}
};

const updatePreference = async ({ userId, name, value }) => {
return await Preference.findOneAndUpdate(
{ userId, name },
{
userId,
name,
value,
},
{ upsert: true, new: true },
).lean();
};

module.exports = {
updateUserPluginsService,
getUserKey,
getUserKeyExpiry,
updateUserKey,
deleteUserKey,
checkUserKeyExpiry,
getPreference,
updatePreference,
};
19 changes: 17 additions & 2 deletions client/src/components/Nav/Settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ import * as Tabs from '@radix-ui/react-tabs';
import { SettingsTabValues } from 'librechat-data-provider';
import type { TDialogProps } from '~/common';
import { Dialog, DialogContent, DialogHeader, DialogTitle } from '~/components/ui';
import { GearIcon, DataIcon, UserIcon, ExperimentIcon } from '~/components/svg';
import { General, Beta, Data, Account } from './SettingsTabs';
import { GearIcon, DataIcon, UserIcon, ExperimentIcon, MinimalPlugin } from '~/components/svg';
import { General, Beta, Data, Account, Tools } from './SettingsTabs';
import { useMediaQuery, useLocalize } from '~/hooks';
import { cn } from '~/utils';

Expand Down Expand Up @@ -96,11 +96,26 @@ export default function Settings({ open, onOpenChange }: TDialogProps) {
<UserIcon />
{localize('com_nav_setting_account')}
</Tabs.Trigger>
<Tabs.Trigger
className={cn(
'group m-1 flex items-center justify-start gap-2 rounded-md px-2 py-1.5 text-sm text-black radix-state-active:bg-white radix-state-active:text-black dark:text-white dark:radix-state-active:bg-gray-600',
isSmallScreen
? 'flex-1 flex-col items-center justify-center text-sm dark:text-gray-500 dark:radix-state-active:text-white'
: 'bg-white radix-state-active:bg-gray-200',
isSmallScreen ? '' : 'dark:bg-gray-700',
)}
value={SettingsTabValues.TOOLS}
style={{ userSelect: 'none' }}
>
<MinimalPlugin className="icon-sm" />
{localize('com_nav_setting_tools')}
</Tabs.Trigger>
</Tabs.List>
<General />
<Beta />
<Data />
<Account />
<Tools />
</Tabs.Root>
</div>
</DialogContent>
Expand Down
82 changes: 82 additions & 0 deletions client/src/components/Nav/SettingsTabs/Tools/StableDiffusion.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import * as Tabs from '@radix-ui/react-tabs';
import React, { useCallback, useEffect, useState } from 'react';
import {
useGetStartupConfig,
usePreferenceQuery,
useUpdatePreferenceMutation,
} from 'librechat-data-provider/react-query';
import {
useLocalize,
useLocalStorage,
} from '~/hooks';
import { Dropdown } from '~/components/ui';
import { Spinner } from '~/components/svg';
//import store from '~/store';

export const SDProfileSelector = ({
sdprofile,
onChange,
}: {
sdprofile: string;
onChange: (value: string) => void;
}) => {
const localize = useLocalize();

// Create an array of options for the Dropdown
const { data: config } = useGetStartupConfig();
const sdConfig = config?.tools?.stableDiffusion;
let sdProfileOptions = [];
if (sdConfig) {
sdProfileOptions = sdConfig.map(obj => {
return {
value: obj.name,
display: obj.name
};
});
} else {
sdProfileOptions = [
{ value: 'Default', display: localize('com_nav_sd_default') },
];
}

return (
<div className="flex items-center justify-between">
<div> {localize('com_nav_sd_profile')} </div>
<Dropdown value={sdprofile} onChange={onChange} options={sdProfileOptions} />
</div>
);
};

export default function StableDiffusion() {
const { data: sdProfilePref, isLoading } = usePreferenceQuery('sdProfile');
const [selectedSDProfile, setSelectedSDProfile] = React.useState(sdProfilePref || 'Default');

useEffect(() => {
if (sdProfilePref) setSelectedSDProfile(sdProfilePref);
}, [sdProfilePref, isLoading]);

const updatePreference = useUpdatePreferenceMutation();
const changeSDProfile = useCallback(
(value: string) => {
setSelectedSDProfile(value);
console.log('callback',value);
updatePreference.mutate({
name: 'sdProfile',
value: value,
});
},
[updatePreference],
);

return (
<div className="flex flex-col gap-3 text-sm text-gray-600 dark:text-gray-50">
<div className="border-b pb-3 last-of-type:border-b-0 dark:border-gray-700">
{isLoading ? (
<Spinner className="opacity-0" />
) : (
<SDProfileSelector sdprofile={selectedSDProfile} onChange={changeSDProfile} />
)}
</div>
</div>
);
}
23 changes: 23 additions & 0 deletions client/src/components/Nav/SettingsTabs/Tools/Tools.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import React from 'react';
import * as Tabs from '@radix-ui/react-tabs';
import { SettingsTabValues } from 'librechat-data-provider';
import StableDiffusion from './StableDiffusion';

function Tools() {
return (
<Tabs.Content
value={SettingsTabValues.TOOLS}
role="tabpanel"
className="w-full md:min-h-[300px]"
>
<div className="flex flex-col gap-3 text-sm text-gray-600 dark:text-gray-50">
<div className="border-b pb-3 last-of-type:border-b-0 dark:border-gray-700">
<StableDiffusion />
</div>
</div>
<div className="border-b pb-3 last-of-type:border-b-0 dark:border-gray-700"></div>
</Tabs.Content>
);
}

export default React.memo(Tools);
1 change: 1 addition & 0 deletions client/src/components/Nav/SettingsTabs/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ export { default as Data } from './Data/Data';
export { default as Beta } from './Beta/Beta';
export { RevokeKeysButton } from './Data/Data';
export { default as Account } from './Account/Account';
export { default as Tools } from './Tools/Tools';
3 changes: 3 additions & 0 deletions client/src/localization/languages/Eng.ts
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,9 @@ export default {
com_nav_setting_beta: 'Beta features',
com_nav_setting_data: 'Data controls',
com_nav_setting_account: 'Account',
com_nav_setting_tools: 'Tools',
com_nav_sd_default: 'Default',
com_nav_sd_profile: 'Stable Diffusion Profile',
com_nav_language: 'Language',
com_nav_lang_auto: 'Auto detect',
com_nav_lang_english: 'English',
Expand Down
6 changes: 6 additions & 0 deletions packages/data-provider/src/api-endpoints.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ export const revokeUserKey = (name: string) => `${keysEndpoint}/${name}`;

export const revokeAllUserKeys = () => `${keysEndpoint}?all=true`;

const preferencesEndpoint = '/api/preferences';

export const preferences = () => preferencesEndpoint;

export const preferencesQuery = (name: string) => `${preferencesEndpoint}?name=${name}`;

export const abortRequest = (endpoint: string) => `/api/ask/${endpoint}/abort`;

export const conversations = (pageNumber: string) => `/api/convos?pageNumber=${pageNumber}`;
Expand Down