diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index e8e3cff..58a965f 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -192,9 +192,9 @@ class ExllamaV2Container:
 
         # Turn off flash attention if CFG is on
         # Workaround until batched FA2 is fixed in exllamav2 upstream
-        self.config.no_flash_attn = (
-            True if self.use_cfg else unwrap(kwargs.get("no_flash_attention"), False)
-        )
+        # self.config.no_flash_attn = (
+        #     True if self.use_cfg else unwrap(kwargs.get("no_flash_attention"), False)
+        # )
 
         # Try to set prompt template
         self.prompt_template = self.find_prompt_template(
@@ -509,19 +509,17 @@ class ExllamaV2Container:
             if value:
                 yield value
 
-        batch_size = 2 if self.use_cfg else 1
-
         if self.cache_mode == "Q4":
             self.cache = ExLlamaV2Cache_Q4(
-                self.model, lazy=self.gpu_split_auto, batch_size=batch_size
+                self.model, lazy=self.gpu_split_auto, batch_size=1
             )
         elif self.cache_mode == "FP8":
             self.cache = ExLlamaV2Cache_8bit(
-                self.model, lazy=self.gpu_split_auto, batch_size=batch_size
+                self.model, lazy=self.gpu_split_auto, batch_size=1
             )
         else:
             self.cache = ExLlamaV2Cache(
-                self.model, lazy=self.gpu_split_auto, batch_size=batch_size
+                self.model, lazy=self.gpu_split_auto, batch_size=1
             )
 
         # Load model with autosplit
@@ -686,6 +684,8 @@ class ExllamaV2Container:
         for kwargs, check common/sampling.py
         """
 
+        prompts = [prompt]
+
         token_healing = unwrap(kwargs.get("token_healing"), False)
         generate_window = max(
             unwrap(kwargs.get("generate_window"), 512), self.config.max_seq_len // 8
@@ -748,10 +748,12 @@ class ExllamaV2Container:
                 negative_prompt = unwrap(
                     kwargs.get("negative_prompt"), self.tokenizer.bos_token
                 )
+
+                prompts.append(negative_prompt)
             else:
                 logger.warning(
                     "CFG is currently disabled. "
-                    "Please reload your model with use_cfg = True.",
+                    "If your GPU is supported, reload your model with use_cfg = True"
                 )
 
         gen_settings.token_repetition_penalty = unwrap(
@@ -873,22 +875,16 @@ class ExllamaV2Container:
         else:
             stop_conditions += eos_tokens
 
-        # Tokenized context
-        ids, offsets = self.tokenizer.encode(
-            [prompt, negative_prompt]
-            if negative_prompt and gen_settings.cfg_scale not in [None, 1.0]
-            else prompt,
-            add_bos=add_bos_token,
-            encode_special_tokens=True,
-            return_offsets=True,
-        )
-        mask = (
-            self.tokenizer.padding_mask(ids)
-            if self.use_cfg and gen_settings.cfg_scale not in [None, 1.0]
-            else None
-        )
-        context_len = len(ids[0])
+        # Encode both positive and negative prompts
+        input_ids = [
+            self.tokenizer.encode(
+                prompt, add_bos=add_bos_token, encode_special_tokens=True
+            )
+            for prompt in prompts
+        ]
+        # The first index will always be the positive prompt
+        context_len = input_ids[0].size(dim=-1)
 
         if context_len > self.config.max_seq_len:
             logger.warning(
                 f"Context length {context_len} is greater than max_seq_len "
@@ -896,12 +892,10 @@ class ExllamaV2Container:
                 "metrics may not be accurate."
             )
 
-        prompt_tokens = ids.shape[-1]
-
         # Automatically set max_tokens to fill up the context
         # This should be an OK default, but may be changed in the future
         max_tokens = unwrap(
-            kwargs.get("max_tokens"), self.config.max_seq_len - prompt_tokens
+            kwargs.get("max_tokens"), self.config.max_seq_len - context_len
        )
 
         # Set min_tokens to generate while keeping EOS banned
@@ -941,7 +935,7 @@ class ExllamaV2Container:
         job_id = uuid.uuid4().hex
         job = ExLlamaV2DynamicJobAsync(
             self.generator,
-            input_ids=ids,
+            input_ids=input_ids,
             max_new_tokens=max_tokens,
             gen_settings=gen_settings,
             stop_conditions=stop_conditions,
@@ -976,7 +970,7 @@ class ExllamaV2Container:
 
         generation = {
             "text": chunk,
-            "prompt_tokens": prompt_tokens,
+            "prompt_tokens": context_len,
             "generated_tokens": generated_tokens,
             "offset": len(full_response),
         }
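For readers following the CFG changes above, here is a minimal sketch of the prompt-handling pattern this patch moves to: the positive prompt is always encoded first, a negative prompt is appended only when CFG is requested, and the resulting list of token tensors is passed to `ExLlamaV2DynamicJobAsync` as `input_ids`. The helper name, the example `cfg_scale` value, and the pre-built `generator`/`tokenizer` objects are assumptions for illustration, not part of the patch.

```python
# Sketch only; assumes an already-loaded ExLlamaV2DynamicGeneratorAsync
# (`generator`) and ExLlamaV2Tokenizer (`tokenizer`).
from exllamav2.generator import ExLlamaV2Sampler
from exllamav2.generator.dynamic_async import ExLlamaV2DynamicJobAsync


async def stream_with_optional_cfg(
    generator, tokenizer, prompt, negative_prompt=None, max_new_tokens=256
):
    gen_settings = ExLlamaV2Sampler.Settings()

    # The positive prompt always sits at index 0; a negative prompt is
    # appended only when CFG should be applied.
    prompts = [prompt]
    if negative_prompt is not None:
        gen_settings.cfg_scale = 1.5  # hypothetical example value
        prompts.append(negative_prompt)

    # Encode each prompt separately rather than batch-encoding with padding.
    input_ids = [
        tokenizer.encode(p, add_bos=True, encode_special_tokens=True)
        for p in prompts
    ]
    context_len = input_ids[0].size(dim=-1)  # length of the positive prompt

    job = ExLlamaV2DynamicJobAsync(
        generator,
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        gen_settings=gen_settings,
    )

    # Stream text chunks as the dynamic generator produces them.
    async for result in job:
        chunk = result.get("text", "")
        if chunk:
            yield chunk, context_len
```

Because each prompt is encoded on its own, the padded batch encode and the `padding_mask` call removed above are no longer needed; as the patch shows, the dynamic generator accepts the per-prompt tensors directly.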