Skip to content

Commit

Permalink
add Batching to React UI MusicGen (#281)
Browse files Browse the repository at this point in the history
  • Loading branch information
rsxdalv authored Mar 5, 2024
1 parent 52d0a56 commit 4a95eb6
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 32 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ https://rsxdalv.github.io/bark-speaker-directory/
https://github.com/rsxdalv/tts-generation-webui/discussions/186#discussioncomment-7291274

## Changelog
Mar 5:
* Add Batching to React UI MusicGen (#281), thanks to https://github.com/Aamir3d for requesting this and providing feedback

Mar 3:
* Add MMS demo as a notebook
* Add MultiBandDiffusion high VRAM disclaimer
Expand Down
17 changes: 11 additions & 6 deletions react-ui/src/hooks/useLocalStorage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,17 @@ export default function useLocalStorage<T>(

const setValue: Dispatch<SetStateAction<T>> = (value) => {
// Allow value to be a function so we have the same API as useState
const valueToStore = value instanceof Function ? value(storedValue) : value;

// update local storage
setLocalValue(valueToStore);
// Save state
setStoredValue(valueToStore);
// const valueToStore = value instanceof Function ? value(storedValue) : value;

// // update local storage
// setLocalValue(valueToStore);
// // Save state
// setStoredValue(valueToStore);
setStoredValue(x => {
const newValue = value instanceof Function ? value(x) : value;
setLocalValue(newValue);
return newValue;
});
};

// watch localStorage changes
Expand Down
215 changes: 190 additions & 25 deletions react-ui/src/pages/musicgen.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ const ModelSelector = ({
);
};

const initialMusicgenHyperParams = {
iterations: 1,
splitByLines: false,
};

const initialHistory = []; // prevent infinite loop
const MusicgenPage = () => {
const [data, setData] = useLocalStorage<Result | null>(
Expand All @@ -187,23 +192,57 @@ const MusicgenPage = () => {
musicgenId,
initialMusicgenParams
);
// hyperparameters
const [musicgenHyperParams, setMusicgenHyperParams] = useLocalStorage<
typeof initialMusicgenHyperParams
>("musicgenHyperParams", initialMusicgenHyperParams);
const [showLast, setShowLast] = useLocalStorage<number>(
"musicgenShowLast",
10
);
const interrupted = React.useRef(false);

async function musicgen() {
const body = JSON.stringify({
...musicgenParams,
melody: musicgenParams.model.includes("melody")
? musicgenParams.melody
: null,
model: musicgenParams.model,
});
const response = await fetch("/api/gradio/musicgen", {
method: "POST",
body,
});
const [progress, setProgress] = React.useState(0);
const [progressMax, setProgressMax] = React.useState(0);

const result: Result = await response.json();
setData(result);
setHistoryData((x) => [result, ...x]);
async function musicgen() {
interrupted.current = false;
const texts = musicgenHyperParams.splitByLines
? musicgenParams.text.split("\n")
: [musicgenParams.text];

const incrementNonRandomSeed = (seed: number, iteration: number) => {
return seed === -1 ? -1 : seed + iteration;
};

const musicgenIteration = async (text, iteration: number) => {
const result = await musicgenGenerate({
...musicgenParams,
text,
seed: incrementNonRandomSeed(musicgenParams.seed, iteration),
});
setData(result);
setHistoryData((x) => [result, ...x]);
};

setProgress(0);
setProgressMax(texts.length * musicgenHyperParams.iterations);
for (
let iteration = 0;
iteration < musicgenHyperParams.iterations;
iteration++
) {
for (const text of texts) {
if (interrupted.current) {
return;
}
await musicgenIteration(text, iteration);
setProgress((x) => x + 1);
}
}
interrupted.current = false;
setProgress(0);
setProgressMax(0);
}

const handleChange = (
Expand Down Expand Up @@ -272,6 +311,8 @@ const MusicgenPage = () => {
useParameters,
};

const interrupt = () => (interrupted.current = true);

return (
<Template>
<Head>
Expand Down Expand Up @@ -416,6 +457,24 @@ const MusicgenPage = () => {
className="border border-gray-300 p-2 rounded"
/>
</div>

<HyperParameters
params={musicgenHyperParams}
setParams={setMusicgenHyperParams}
progress={progress}
progressMax={progressMax}
interrupted={interrupted}
interrupt={interrupt}
/>
<button
className="border border-gray-300 p-2 rounded"
onClick={() => {
setMusicgenParams(initialMusicgenParams);
setMusicgenHyperParams(initialMusicgenHyperParams);
}}
>
Reset Parameters
</button>
</div>
</div>
</div>
Expand All @@ -438,19 +497,32 @@ const MusicgenPage = () => {

<div className="flex flex-col gap-y-2 border border-gray-300 p-2 rounded">
<label className="text-sm">History:</label>
{/* Clear history */}
<button
className="border border-gray-300 p-2 rounded"
onClick={() => {
setHistoryData([]);
}}
>
Clear History
</button>
<div className="flex gap-x-2 items-center">
<button
className="border border-gray-300 p-2 px-40 rounded"
onClick={() => {
setHistoryData([]);
}}
>
Clear History
</button>
<div className="flex gap-x-2 items-center">
<label className="text-sm">Show Last X entries:</label>
<input
type="number"
value={showLast}
onChange={(event) => setShowLast(Number(event.target.value))}
className="border border-gray-300 p-2 rounded"
min="0"
max="100"
step="1"
/>
</div>
</div>
<div className="flex flex-col gap-y-2">
{historyData &&
historyData
.slice(1, 6)
.slice(1, showLast + 1)
.map((item, index) => (
<AudioOutput
key={index}
Expand All @@ -469,3 +541,96 @@ const MusicgenPage = () => {
};

export default MusicgenPage;

async function musicgenGenerate(musicgenParams: MusicgenParams) {
const body = JSON.stringify({
...musicgenParams,
melody: musicgenParams.model.includes("melody")
? musicgenParams.melody
: null,
model: musicgenParams.model,
});
const response = await fetch("/api/gradio/musicgen", {
method: "POST",
body,
});

return (await response.json()) as Result;
}

const HyperParameters = ({
params: musicgenHyperParams,
setParams: setMusicgenHyperParams,
progress,
progressMax,
interrupted,
interrupt,
}: {
params: typeof initialMusicgenHyperParams;
setParams: React.Dispatch<
React.SetStateAction<typeof initialMusicgenHyperParams>
>;
progress: number;
progressMax: number;
interrupted: React.MutableRefObject<boolean>;
interrupt: () => void;
}) => (
<div className="flex flex-col gap-y-2 border border-gray-300 p-2 rounded">
<label className="text-sm">Hyperparameters:</label>
<div className="flex gap-x-2 items-center">
<label className="text-sm">Iterations:</label>
<input
type="number"
name="iterations"
value={musicgenHyperParams.iterations}
onChange={(event) => {
setMusicgenHyperParams({
...musicgenHyperParams,
iterations: Number(event.target.value),
});
}}
className="border border-gray-300 p-2 rounded"
min="1"
max="10"
step="1"
/>
</div>
<div className="flex gap-x-2 items-center">
<div className="text-sm">Each line as a separate prompt:</div>
<input
type="checkbox"
name="splitByLines"
checked={musicgenHyperParams.splitByLines}
onChange={(event) => {
setMusicgenHyperParams({
...musicgenHyperParams,
splitByLines: event.target.checked,
});
}}
className="border border-gray-300 p-2 rounded"
/>
</div>
<Progress progress={progress} progressMax={progressMax} />
<button className="border border-gray-300 p-2 rounded" onClick={interrupt}>
{interrupted.current ? "Interrupted..." : "Interrupt"}
</button>
</div>
);

const Progress = ({
progress,
progressMax,
}: {
progress: number;
progressMax: number;
}) => (
<div className="flex gap-x-2 items-center">
<label className="text-sm">Progress:</label>
<progress
value={progress}
max={progressMax}
className="[&::-webkit-progress-bar]:rounded [&::-webkit-progress-value]:rounded [&::-webkit-progress-bar]:bg-slate-300 [&::-webkit-progress-value]:bg-orange-400 [&::-moz-progress-bar]:bg-orange-400 [&::-webkit-progress-value]:transition-all [&::-webkit-progress-value]:duration-200"
/>
{progress}/{progressMax}
</div>
);
2 changes: 1 addition & 1 deletion react-ui/src/tabs/MusicgenParams.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export const initialMusicgenParams: MusicgenParams = {
text: "lofi hip hop beats to relax/study to",
melody: undefined,
// melody: "https://www.mfiles.co.uk/mp3-downloads/gs-cd-track2.mp3",
model: "Small",
model: "facebook/musicgen-small",
duration: 1,
topk: 250,
topp: 0,
Expand Down

0 comments on commit 4a95eb6

Please sign in to comment.