
Commit

Merge branch 'oobabooga:main' into DualModel
Ph0rk0z authored Sep 19, 2023
2 parents b8d2b14 + 029da95 commit 15229d2
Showing 14 changed files with 92 additions and 55 deletions.
6 changes: 3 additions & 3 deletions js/main.js
@@ -113,7 +113,7 @@ let isScrolled = false;

targetElement.addEventListener('scroll', function() {
let diff = targetElement.scrollHeight - targetElement.clientHeight;
-if(Math.abs(targetElement.scrollTop - diff) <= 1 || diff == 0) {
+if(Math.abs(targetElement.scrollTop - diff) <= 10 || diff == 0) {
isScrolled = false;
} else {
isScrolled = true;
@@ -161,7 +161,7 @@ let notebookScrolled = false;

notebookElement.addEventListener('scroll', function() {
let diff = notebookElement.scrollHeight - notebookElement.clientHeight;
-if(Math.abs(notebookElement.scrollTop - diff) <= 1 || diff == 0) {
+if(Math.abs(notebookElement.scrollTop - diff) <= 10 || diff == 0) {
notebookScrolled = false;
} else {
notebookScrolled = true;
@@ -186,7 +186,7 @@ let defaultScrolled = false;

defaultElement.addEventListener('scroll', function() {
let diff = defaultElement.scrollHeight - defaultElement.clientHeight;
-if(Math.abs(defaultElement.scrollTop - diff) <= 1 || diff == 0) {
+if(Math.abs(defaultElement.scrollTop - diff) <= 10 || diff == 0) {
defaultScrolled = false;
} else {
defaultScrolled = true;
2 changes: 1 addition & 1 deletion modules/chat.py
@@ -530,7 +530,7 @@ def load_character_memoized(character, name1, name2, instruct=False):


def upload_character(file, img, tavern=False):
-decoded_file = file if type(file) == str else file.decode('utf-8')
+decoded_file = file if isinstance(file, str) else file.decode('utf-8')
try:
data = json.loads(decoded_file)
except:
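upload_character() (and load_session() later in this diff) accepts either an already-decoded str or raw bytes from the upload widget, and the isinstance() form is the idiomatic way to branch on that. A minimal sketch of the same pattern, with hypothetical sample inputs:

import json

def to_text(file):
    # Accept both str and bytes, as upload_character()/load_session() do
    return file if isinstance(file, str) else file.decode('utf-8')

print(json.loads(to_text('{"name": "Example"}')))   # already a str
print(json.loads(to_text(b'{"name": "Example"}')))  # raw bytes from an upload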
32 changes: 16 additions & 16 deletions modules/exllama.py
@@ -96,6 +96,22 @@ def from_pretrained(self, path_to_model):
result.generator = generator
return result, result

+def encode(self, string, **kwargs):
+return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len, add_bos=True)
+
+def decode(self, ids, **kwargs):
+if isinstance(ids, list):
+ids = torch.tensor([ids])
+elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
+ids = ids.view(1, -1)
+
+return self.tokenizer.decode(ids)[0]
+
+def get_logits(self, token_ids, **kwargs):
+self.cache.current_seq_len = 0
+self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
+return self.model.forward(token_ids[:, -1:], self.cache, **kwargs).float().cpu()

def generate_with_streaming(self, prompt, state):

# The cache batch size must be 2 for CFG and 1 otherwise
@@ -211,19 +227,3 @@ def generate(self, prompt, state):
pass

return output

-def encode(self, string, **kwargs):
-return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len, add_bos=True)
-
-def decode(self, ids, **kwargs):
-if isinstance(ids, list):
-ids = torch.tensor([ids])
-elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
-ids = ids.view(1, -1)
-
-return self.tokenizer.decode(ids)[0]
-
-def get_logits(self, token_ids, **kwargs):
-self.cache.current_seq_len = 0
-self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
-return self.model.forward(token_ids[:, -1:], self.cache, **kwargs).float().cpu()
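Together, encode(), get_logits() and decode() give this loader a tokenize, forward, detokenize path: get_logits() takes a 2-D [1, seq_len] id tensor (it slices token_ids[:, :-1]), resets the ExLlama cache, pre-fills everything but the last token with preprocess_only=True, then runs a real forward pass for the final position. A rough usage sketch for picking a greedy next token; model is assumed to be a loaded ExllamaModel instance, and the helper name is illustrative:

import torch

def greedy_next_token(model, prompt):
    token_ids = model.encode(prompt)              # 2-D tensor of token ids
    logits = model.get_logits(token_ids)          # logits for the final position
    next_id = torch.argmax(logits[..., -1, :]).item()
    return model.decode([next_id])                # decode() also accepts a plain list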
24 changes: 20 additions & 4 deletions modules/exllama_hf.py
@@ -77,17 +77,33 @@ def __call__(self, *args, **kwargs):
seq = past_key_values + seq

seq_tensor = torch.tensor(seq)
+reset = True

# Make the forward call
if labels is None:
-if past_seq is None or not torch.equal(past_seq, seq_tensor[:-1]):
+if past_seq is not None:
+min_length = min(past_seq.shape[0], seq_tensor.shape[0])
+indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
+if len(indices) > 0:
+longest_prefix = indices[0].item()
+else:
+longest_prefix = min_length

+if longest_prefix > 0:
+reset = False
+ex_cache.current_seq_len = longest_prefix
+if len(seq_tensor) - longest_prefix > 1:
+self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, lora=self.lora)

+if reset:
ex_cache.current_seq_len = 0
-self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), ex_cache, preprocess_only=True, lora=self.lora)
+if len(seq_tensor) > 1:
+self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, lora=self.lora)

-logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), ex_cache, lora=self.lora).to(input_ids.device)
+logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, lora=self.lora).to(input_ids.device)
else:
ex_cache.current_seq_len = 0
-logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), ex_cache, last_id_only=False, lora=self.lora)
+logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, lora=self.lora)

if is_negative:
self.past_seq_negative = seq_tensor
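The rewritten branch replaces the old all-or-nothing check (re-evaluate the whole prompt unless past_seq matches seq_tensor[:-1] exactly) with a longest-common-prefix search: find the first index where the cached and new sequences differ, rewind ex_cache.current_seq_len to that point, and only re-evaluate the tokens after it. A standalone sketch of that prefix computation, using only torch and the same variable names as the diff:

import torch

def longest_common_prefix(past_seq: torch.Tensor, seq_tensor: torch.Tensor) -> int:
    # Compare only the overlapping region of the two 1-D token-id sequences
    min_length = min(past_seq.shape[0], seq_tensor.shape[0])
    # Positions where the two sequences disagree
    indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
    # First mismatch, or the full overlap if they agree everywhere
    return indices[0].item() if len(indices) > 0 else min_length

# A cache built for [1, 2, 3, 4] can be reused up to index 3 for [1, 2, 3, 9, 10]
print(longest_common_prefix(torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 3, 9, 10])))  # 3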
32 changes: 16 additions & 16 deletions modules/exllamav2.py
@@ -62,6 +62,22 @@ def from_pretrained(self, path_to_model):
result.generator = generator
return result, result

+def encode(self, string, **kwargs):
+return self.tokenizer.encode(string, add_bos=True)
+
+def decode(self, ids, **kwargs):
+if isinstance(ids, list):
+ids = torch.tensor([ids])
+elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
+ids = ids.view(1, -1)
+
+return self.tokenizer.decode(ids)[0]
+
+def get_logits(self, token_ids, **kwargs):
+self.cache.current_seq_len = 0
+self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
+return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, **kwargs).float().cpu()

def generate_with_streaming(self, prompt, state):
settings = ExLlamaV2Sampler.Settings()
settings.temperature = state['temperature']
@@ -114,19 +130,3 @@ def generate(self, prompt, state):
pass

return output

-def encode(self, string, **kwargs):
-return self.tokenizer.encode(string, add_bos=True)
-
-def decode(self, ids, **kwargs):
-if isinstance(ids, list):
-ids = torch.tensor([ids])
-elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
-ids = ids.view(1, -1)
-
-return self.tokenizer.decode(ids)[0]
-
-def get_logits(self, token_ids, **kwargs):
-self.cache.current_seq_len = 0
-self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
-return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, **kwargs).float().cpu()
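As in the ExLlama v1 loader, decode() first normalizes its input so the tokenizer always receives a batched 2-D tensor: a plain Python list becomes [1, n], and a single-element tensor is reshaped to [1, 1]. A small sketch of just that normalization step (the helper name is illustrative):

import torch

def normalize_ids(ids):
    # Mirrors the shape handling in decode(): always end up with a [1, n] tensor
    if isinstance(ids, list):
        ids = torch.tensor([ids])
    elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
        ids = ids.view(1, -1)
    return ids

print(normalize_ids([5, 6, 7]).shape)         # torch.Size([1, 3])
print(normalize_ids(torch.tensor(42)).shape)  # torch.Size([1, 1])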
25 changes: 20 additions & 5 deletions modules/exllamav2_hf.py
@@ -81,18 +81,33 @@ def __call__(self, *args, **kwargs):
seq = past_key_values + seq

seq_tensor = torch.tensor(seq)
+reset = True

# Make the forward call
if labels is None:
-if past_seq is None or not torch.equal(past_seq, seq_tensor[:-1]):
+if past_seq is not None:
+min_length = min(past_seq.shape[0], seq_tensor.shape[0])
+indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
+if len(indices) > 0:
+longest_prefix = indices[0].item()
+else:
+longest_prefix = min_length

+if longest_prefix > 0:
+reset = False
+ex_cache.current_seq_len = longest_prefix
+if len(seq_tensor) - longest_prefix > 1:
+self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True)

+if reset:
ex_cache.current_seq_len = 0
-self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), ex_cache, preprocess_only=True)
+if len(seq_tensor) > 1:
+self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True)

-logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), ex_cache).to(input_ids.device)
+logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache).to(input_ids.device)
else:
ex_cache.current_seq_len = 0
-# logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), ex_cache, last_id_only=False)
-logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), ex_cache)
+logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False)

if is_negative:
self.past_seq_negative = seq_tensor
5 changes: 3 additions & 2 deletions modules/llamacpp_hf.py
@@ -131,9 +131,10 @@ def __call__(self, *args, **kwargs):
longest_prefix = min_length

if longest_prefix > 0:
-self.model.n_tokens = longest_prefix
-self.model.eval(seq[longest_prefix:])
reset = False
+self.model.n_tokens = longest_prefix
+if len(seq_tensor) - longest_prefix > 0:
+self.model.eval(seq[longest_prefix:])

if reset:
self.model.reset()
4 changes: 2 additions & 2 deletions modules/metadata_gguf.py
@@ -70,8 +70,8 @@ def load_metadata(fname):
GGUF_VERSION = struct.unpack("<I", file.read(4))[0]
ti_data_count = struct.unpack("<Q", file.read(8))[0]
kv_data_count = struct.unpack("<Q", file.read(8))[0]
-if GGUF_VERSION == 1:
+
+if GGUF_VERSION == 1:
raise Exception('You are using an outdated GGUF, please download a new one.')

for i in range(kv_data_count):
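The surrounding code reads the fixed-size GGUF header with struct.unpack: a little-endian uint32 version followed by two uint64 counts (tensor-info entries and key/value pairs), and version-1 files are rejected. A hedged sketch of reading just those fields from the start of a .gguf file (the 4-byte magic comes first, then the header):

import struct

def read_gguf_counts(path):
    with open(path, 'rb') as file:
        file.read(4)                                          # skip the b'GGUF' magic
        version = struct.unpack("<I", file.read(4))[0]        # little-endian uint32
        ti_data_count = struct.unpack("<Q", file.read(8))[0]  # uint64: tensor-info entries
        kv_data_count = struct.unpack("<Q", file.read(8))[0]  # uint64: key/value pairs

    if version == 1:
        raise Exception('You are using an outdated GGUF, please download a new one.')

    return version, ti_data_count, kv_data_count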
1 change: 0 additions & 1 deletion modules/models.py
@@ -1,5 +1,4 @@
import gc
-import hashlib
import os
import re
import time
2 changes: 1 addition & 1 deletion modules/models_settings.py
@@ -96,7 +96,7 @@ def update_model_parameters(state, initial=False):
gpu_memories.append(value)
continue

-if initial and vars(shared.args)[element] != vars(shared.args_defaults)[element]:
+if initial and element in shared.provided_arguments:
continue

# Setting null defaults
7 changes: 6 additions & 1 deletion modules/shared.py
@@ -1,12 +1,12 @@
import argparse
+import sys
from collections import OrderedDict
from pathlib import Path

import yaml

from modules.logging_colors import logger


# Model variables
model = None
tokenizer = None
@@ -203,6 +203,11 @@ def str2bool(v):

args = parser.parse_args()
args_defaults = parser.parse_args([])
+provided_arguments = []
+for arg in sys.argv[1:]:
+arg = arg.lstrip('-').replace('-', '_')
+if hasattr(args, arg):
+provided_arguments.append(arg)

# Loader choosing
if args.autogptq:
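provided_arguments records which flags the user actually passed: each sys.argv token is normalized ('--model-dir' becomes 'model_dir') and kept only if argparse knows an attribute by that name. update_model_parameters() in models_settings.py (changed above) then skips those entries, so per-model defaults never override flags given explicitly on the command line. A minimal self-contained sketch of the same idea, with illustrative flag names:

import argparse
import sys

parser = argparse.ArgumentParser()
parser.add_argument('--model-dir', type=str, default='models/')
parser.add_argument('--load-in-8bit', action='store_true')

args = parser.parse_args()
args_defaults = parser.parse_args([])

provided_arguments = []
for arg in sys.argv[1:]:
    arg = arg.lstrip('-').replace('-', '_')  # '--model-dir' -> 'model_dir'
    if hasattr(args, arg):
        provided_arguments.append(arg)

# models_settings.py then applies a per-model default only when the flag is absent:
# if initial and element in shared.provided_arguments: continue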
4 changes: 2 additions & 2 deletions modules/text_generation.py
@@ -316,8 +316,8 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())

processor = state.get('logits_processor', LogitsProcessorList([]))
-# In case folks just pass in a processor by itself.
-if type(processor) != LogitsProcessorList:
+# In case a processor is passed by itself.
+if not isinstance(processor, LogitsProcessorList):
processor = LogitsProcessorList([processor])
apply_extensions('logits_processor', processor, input_ids)
generate_params['logits_processor'] = processor
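state['logits_processor'] may hold either a full LogitsProcessorList or a single bare processor, so the isinstance() check wraps the bare case before extensions append to the list. A small sketch of the wrapping pattern using the transformers classes; the no-op processor exists only for illustration:

from transformers import LogitsProcessor, LogitsProcessorList

class NoOpProcessor(LogitsProcessor):
    # Hypothetical processor used only to demonstrate the wrapping step
    def __call__(self, input_ids, scores):
        return scores

processor = NoOpProcessor()  # a bare processor rather than a list
if not isinstance(processor, LogitsProcessorList):
    processor = LogitsProcessorList([processor])

print(type(processor).__name__)  # LogitsProcessorList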
2 changes: 1 addition & 1 deletion modules/ui_file_saving.py
@@ -92,7 +92,7 @@ def create_event_handlers():


def load_session(file, state):
-decoded_file = file if type(file) == str else file.decode('utf-8')
+decoded_file = file if isinstance(file, str) else file.decode('utf-8')
data = json.loads(decoded_file)

if 'character_menu' in data and state.get('character_menu') != data.get('character_menu'):
1 change: 1 addition & 0 deletions server.py
@@ -208,6 +208,7 @@ def create_interface():
p = Path(shared.model_name)
if p.exists():
model_name = p.parts[-1]
+shared.model_name = model_name
else:
model_name = shared.model_name

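When shared.model_name points at an existing path, only its last component is kept as the model name, and the added line writes that back to shared.model_name so later code sees the normalized value. Path.parts splits a path into its components, so parts[-1] is the final directory or file name; a quick illustration with a hypothetical path:

from pathlib import Path

p = Path('models/llama-2-7b')  # hypothetical model path
print(p.parts)                 # ('models', 'llama-2-7b')
print(p.parts[-1])             # 'llama-2-7b'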
