Model + API: Migrate to use BaseSamplerParams
kwargs is pretty ugly when figuring out which arguments to use. The base requests falls back to defaults anyways, so pass in the params object as is. However, since Python's typing isn't like TypeScript where types can be transformed, the type hinting has a possiblity of None showing up despite there always being a value for some params. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
dcb36e9ab2
commit
3084ef9fa1
5 changed files with 113 additions and 121 deletions
|
|
@ -1,17 +1,13 @@
|
|||
"""The model container class for ExLlamaV2 models."""
|
||||
|
||||
from functools import partial
|
||||
import aiofiles
|
||||
import asyncio
|
||||
import gc
|
||||
import math
|
||||
import pathlib
|
||||
import traceback
|
||||
from backends.exllamav2.vision import clear_image_embedding_cache
|
||||
from common.multimodal import MultimodalEmbeddingWrapper
|
||||
import torch
|
||||
import uuid
|
||||
from copy import deepcopy
|
||||
from exllamav2 import (
|
||||
ExLlamaV2,
|
||||
ExLlamaV2Config,
|
||||
|
|
@ -32,7 +28,7 @@ from exllamav2.generator import (
|
|||
)
|
||||
from itertools import zip_longest
|
||||
from loguru import logger
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
|
|
@ -47,6 +43,7 @@ from backends.exllamav2.utils import (
|
|||
hardware_supports_flash_attn,
|
||||
supports_paged_attn,
|
||||
)
|
||||
from backends.exllamav2.vision import clear_image_embedding_cache
|
||||
from common.concurrency import iterate_in_threadpool
|
||||
from common.gen_logging import (
|
||||
log_generation_params,
|
||||
|
|
@ -54,6 +51,8 @@ from common.gen_logging import (
|
|||
log_prompt,
|
||||
log_response,
|
||||
)
|
||||
from common.multimodal import MultimodalEmbeddingWrapper
|
||||
from common.sampling import BaseSamplerRequest
|
||||
from common.templating import (
|
||||
PromptTemplate,
|
||||
TemplateLoadError,
|
||||
|
|
@ -976,15 +975,20 @@ class ExllamaV2Container:
|
|||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
request_id: str,
|
||||
abort_event: asyncio.Event = None,
|
||||
**kwargs,
|
||||
prompt: str,
|
||||
params: BaseSamplerRequest,
|
||||
abort_event: Optional[asyncio.Event] = None,
|
||||
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
|
||||
):
|
||||
"""Generate a response to a prompt."""
|
||||
generations = []
|
||||
async for generation in self.generate_gen(
|
||||
prompt, request_id, abort_event, **kwargs
|
||||
request_id,
|
||||
prompt,
|
||||
params,
|
||||
abort_event,
|
||||
mm_embeddings,
|
||||
):
|
||||
generations.append(generation)
|
||||
|
||||
|
|
@ -1031,21 +1035,22 @@ class ExllamaV2Container:
|
|||
|
||||
return joined_generation
|
||||
|
||||
def check_unsupported_settings(self, **kwargs):
|
||||
def check_unsupported_settings(self, params: BaseSamplerRequest):
|
||||
"""
|
||||
Check and warn the user if a sampler is unsupported.
|
||||
|
||||
Meant for dev wheels!
|
||||
"""
|
||||
|
||||
return kwargs
|
||||
return params
|
||||
|
||||
async def generate_gen(
|
||||
self,
|
||||
prompt: str,
|
||||
request_id: str,
|
||||
prompt: str,
|
||||
params: BaseSamplerRequest,
|
||||
abort_event: Optional[asyncio.Event] = None,
|
||||
**kwargs,
|
||||
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
|
||||
):
|
||||
"""
|
||||
Create generator function for prompt completion.
|
||||
|
|
@ -1059,46 +1064,43 @@ class ExllamaV2Container:
|
|||
|
||||
prompts = [prompt]
|
||||
|
||||
token_healing = kwargs.get("token_healing")
|
||||
generate_window = max(
|
||||
kwargs.get("generate_window"), self.config.max_seq_len // 8
|
||||
)
|
||||
# TODO: Not used for some reason?
|
||||
generate_window = max(params.generate_window, self.config.max_seq_len // 8)
|
||||
|
||||
# Sampler settings
|
||||
gen_settings = ExLlamaV2Sampler.Settings()
|
||||
|
||||
# Check unsupported settings for dev wheels
|
||||
kwargs = self.check_unsupported_settings(**kwargs)
|
||||
params = self.check_unsupported_settings(params)
|
||||
|
||||
# Apply settings
|
||||
gen_settings.temperature = kwargs.get("temperature")
|
||||
gen_settings.temperature_last = kwargs.get("temperature_last")
|
||||
gen_settings.smoothing_factor = kwargs.get("smoothing_factor")
|
||||
gen_settings.top_k = kwargs.get("top_k")
|
||||
gen_settings.top_p = kwargs.get("top_p")
|
||||
gen_settings.top_a = kwargs.get("top_a")
|
||||
gen_settings.min_p = kwargs.get("min_p")
|
||||
gen_settings.tfs = kwargs.get("tfs")
|
||||
gen_settings.typical = kwargs.get("typical")
|
||||
gen_settings.mirostat = kwargs.get("mirostat")
|
||||
gen_settings.skew = kwargs.get("skew")
|
||||
gen_settings.temperature = params.temperature
|
||||
gen_settings.temperature_last = params.temperature_last
|
||||
gen_settings.smoothing_factor = params.smoothing_factor
|
||||
gen_settings.top_k = params.top_k
|
||||
gen_settings.top_p = params.top_p
|
||||
gen_settings.top_a = params.top_a
|
||||
gen_settings.min_p = params.min_p
|
||||
gen_settings.tfs = params.tfs
|
||||
gen_settings.typical = params.typical
|
||||
gen_settings.mirostat = params.mirostat
|
||||
gen_settings.skew = params.skew
|
||||
|
||||
# XTC
|
||||
xtc_probability = kwargs.get("xtc_probability")
|
||||
if xtc_probability > 0.0:
|
||||
gen_settings.xtc_probability = xtc_probability
|
||||
if params.xtc_probability > 0.0:
|
||||
gen_settings.xtc_probability = params.xtc_probability
|
||||
|
||||
# 0.1 is the default for this value
|
||||
gen_settings.xtc_threshold = kwargs.get("xtc_threshold")
|
||||
gen_settings.xtc_threshold = params.xtc_threshold
|
||||
|
||||
# DynaTemp settings
|
||||
max_temp = kwargs.get("max_temp")
|
||||
min_temp = kwargs.get("min_temp")
|
||||
max_temp = params.max_temp
|
||||
min_temp = params.min_temp
|
||||
|
||||
if max_temp > min_temp:
|
||||
if params.max_temp > params.min_temp:
|
||||
gen_settings.max_temp = max_temp
|
||||
gen_settings.min_temp = min_temp
|
||||
gen_settings.temp_exponent = kwargs.get("temp_exponent")
|
||||
gen_settings.temp_exponent = params.temp_exponent
|
||||
else:
|
||||
# Force to default values
|
||||
gen_settings.max_temp = 1.0
|
||||
|
|
@ -1115,11 +1117,11 @@ class ExllamaV2Container:
|
|||
)
|
||||
|
||||
# Default tau and eta fallbacks don't matter if mirostat is off
|
||||
gen_settings.mirostat_tau = kwargs.get("mirostat_tau")
|
||||
gen_settings.mirostat_eta = kwargs.get("mirostat_eta")
|
||||
gen_settings.mirostat_tau = params.mirostat_tau
|
||||
gen_settings.mirostat_eta = params.mirostat_eta
|
||||
|
||||
# Set CFG scale and negative prompt
|
||||
cfg_scale = kwargs.get("cfg_scale")
|
||||
cfg_scale = params.cfg_scale
|
||||
negative_prompt = None
|
||||
if cfg_scale not in [None, 1.0]:
|
||||
if self.paged:
|
||||
|
|
@ -1127,7 +1129,7 @@ class ExllamaV2Container:
|
|||
|
||||
# If the negative prompt is empty, use the BOS token
|
||||
negative_prompt = unwrap(
|
||||
kwargs.get("negative_prompt"), self.tokenizer.bos_token
|
||||
params.negative_prompt, self.tokenizer.bos_token
|
||||
)
|
||||
|
||||
prompts.append(negative_prompt)
|
||||
|
|
@ -1138,15 +1140,16 @@ class ExllamaV2Container:
|
|||
)
|
||||
|
||||
# Penalties
|
||||
gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty")
|
||||
gen_settings.token_frequency_penalty = kwargs.get("frequency_penalty")
|
||||
gen_settings.token_presence_penalty = kwargs.get("presence_penalty")
|
||||
gen_settings.token_repetition_penalty = params.repetition_penalty
|
||||
gen_settings.token_frequency_penalty = params.frequency_penalty
|
||||
gen_settings.token_presence_penalty = params.presence_penalty
|
||||
|
||||
# Applies for all penalties despite being called token_repetition_range
|
||||
gen_settings.token_repetition_range = unwrap(
|
||||
kwargs.get("penalty_range"), self.config.max_seq_len
|
||||
params.penalty_range, self.config.max_seq_len
|
||||
)
|
||||
|
||||
# TODO: Not used for some reason?
|
||||
# Dynamically scale penalty range to output tokens
|
||||
# Only do this if freq/pres pen is enabled
|
||||
# and the repetition range is -1
|
||||
|
|
@ -1164,54 +1167,51 @@ class ExllamaV2Container:
|
|||
else:
|
||||
fallback_decay = gen_settings.token_repetition_range
|
||||
gen_settings.token_repetition_decay = coalesce(
|
||||
kwargs.get("repetition_decay"), fallback_decay, 0
|
||||
params.repetition_decay, fallback_decay, 0
|
||||
)
|
||||
|
||||
# DRY options
|
||||
dry_multiplier = kwargs.get("dry_multiplier")
|
||||
dry_multiplier = params.dry_multiplier
|
||||
|
||||
# < 0 = disabled
|
||||
if dry_multiplier > 0:
|
||||
gen_settings.dry_multiplier = dry_multiplier
|
||||
|
||||
gen_settings.dry_allowed_length = kwargs.get("dry_allowed_length")
|
||||
gen_settings.dry_base = kwargs.get("dry_base")
|
||||
gen_settings.dry_allowed_length = params.dry_allowed_length
|
||||
gen_settings.dry_base = params.dry_base
|
||||
|
||||
# Exl2 has dry_range as 0 for unlimited unlike -1 for penalty_range
|
||||
# Use max_seq_len as the fallback to stay consistent
|
||||
gen_settings.dry_range = unwrap(
|
||||
kwargs.get("dry_range"), self.config.max_seq_len
|
||||
)
|
||||
gen_settings.dry_range = unwrap(params.dry_range, self.config.max_seq_len)
|
||||
|
||||
# Tokenize sequence breakers
|
||||
dry_sequence_breakers_json = kwargs.get("dry_sequence_breakers")
|
||||
if dry_sequence_breakers_json:
|
||||
if params.dry_sequence_breakers:
|
||||
gen_settings.dry_sequence_breakers = {
|
||||
self.encode_tokens(s)[-1] for s in dry_sequence_breakers_json
|
||||
self.encode_tokens(s)[-1] for s in params.dry_sequence_breakers
|
||||
}
|
||||
|
||||
# Initialize grammar handler
|
||||
grammar_handler = ExLlamaV2Grammar()
|
||||
|
||||
# Add JSON schema filter if it exists
|
||||
json_schema = kwargs.get("json_schema")
|
||||
if json_schema:
|
||||
if params.json_schema:
|
||||
grammar_handler.add_json_schema_filter(
|
||||
json_schema, self.model, self.tokenizer
|
||||
params.json_schema, self.model, self.tokenizer
|
||||
)
|
||||
|
||||
# Add regex filter if it exists
|
||||
regex_pattern = kwargs.get("regex_pattern")
|
||||
if regex_pattern:
|
||||
grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer)
|
||||
if params.regex_pattern:
|
||||
grammar_handler.add_regex_filter(
|
||||
params.regex_pattern, self.model, self.tokenizer
|
||||
)
|
||||
|
||||
# Add EBNF filter if it exists
|
||||
grammar_string = kwargs.get("grammar_string")
|
||||
if grammar_string:
|
||||
grammar_handler.add_kbnf_filter(grammar_string, self.model, self.tokenizer)
|
||||
if params.grammar_string:
|
||||
grammar_handler.add_kbnf_filter(
|
||||
params.grammar_string, self.model, self.tokenizer
|
||||
)
|
||||
|
||||
# Set banned strings
|
||||
banned_strings = kwargs.get("banned_strings")
|
||||
banned_strings = params.banned_strings
|
||||
if banned_strings and len(grammar_handler.filters) > 0:
|
||||
logger.warning(
|
||||
"Disabling banned_strings because "
|
||||
|
|
@ -1220,16 +1220,12 @@ class ExllamaV2Container:
|
|||
|
||||
banned_strings = []
|
||||
|
||||
stop_conditions = kwargs.get("stop")
|
||||
add_bos_token = kwargs.get("add_bos_token"), True
|
||||
ban_eos_token = kwargs.get("ban_eos_token"), False
|
||||
logit_bias = kwargs.get("logit_bias")
|
||||
|
||||
# Logprobs
|
||||
request_logprobs = kwargs.get("logprobs")
|
||||
stop_conditions = params.stop
|
||||
add_bos_token = params.add_bos_token
|
||||
ban_eos_token = params.ban_eos_token
|
||||
|
||||
# Speculative Ngram
|
||||
self.generator.speculative_ngram = kwargs.get("speculative_ngram")
|
||||
self.generator.speculative_ngram = params.speculative_ngram
|
||||
|
||||
# Override sampler settings for temp = 0
|
||||
if gen_settings.temperature == 0:
|
||||
|
|
@ -1244,17 +1240,15 @@ class ExllamaV2Container:
|
|||
)
|
||||
|
||||
# Set banned tokens
|
||||
banned_tokens = kwargs.get("banned_tokens")
|
||||
if banned_tokens:
|
||||
gen_settings.disallow_tokens(self.tokenizer, banned_tokens)
|
||||
if params.banned_tokens:
|
||||
gen_settings.disallow_tokens(self.tokenizer, params.banned_tokens)
|
||||
|
||||
# Set allowed tokens
|
||||
allowed_tokens = kwargs.get("allowed_tokens")
|
||||
if allowed_tokens:
|
||||
gen_settings.allow_tokens(self.tokenizer, allowed_tokens)
|
||||
if params.allowed_tokens:
|
||||
gen_settings.allow_tokens(self.tokenizer, params.allowed_tokens)
|
||||
|
||||
# Set logit bias
|
||||
if logit_bias:
|
||||
if params.logit_bias:
|
||||
# Create a vocab tensor if it doesn't exist for token biasing
|
||||
if gen_settings.token_bias is None:
|
||||
padding = -self.tokenizer.config.vocab_size % 32
|
||||
|
|
@ -1264,7 +1258,7 @@ class ExllamaV2Container:
|
|||
)
|
||||
|
||||
# Map logits to the tensor with their biases
|
||||
for token_id, bias in logit_bias.items():
|
||||
for token_id, bias in params.logit_bias.items():
|
||||
if 0 <= token_id < len(self.tokenizer.get_id_to_piece_list(True)):
|
||||
gen_settings.token_bias[token_id] = bias
|
||||
else:
|
||||
|
|
@ -1289,7 +1283,7 @@ class ExllamaV2Container:
|
|||
stop_conditions += eos_tokens
|
||||
|
||||
# Get multimodal embeddings if present
|
||||
mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
|
||||
# TODO: Remove kwargs and pass this as optional
|
||||
mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
|
||||
|
||||
# Encode both positive and negative prompts
|
||||
|
|
@ -1312,7 +1306,7 @@ class ExllamaV2Container:
|
|||
# Automatically set max_tokens to fill up the context
|
||||
# This should be an OK default, but may be changed in the future
|
||||
max_tokens = unwrap(
|
||||
kwargs.get("max_tokens"),
|
||||
params.max_tokens,
|
||||
self.config.max_seq_len - max(context_len, negative_context_len),
|
||||
)
|
||||
if max_tokens < 1:
|
||||
|
|
@ -1349,12 +1343,6 @@ class ExllamaV2Container:
|
|||
f"is greater than cache_size {self.cache_size}"
|
||||
)
|
||||
|
||||
# Set min_tokens to generate while keeping EOS banned
|
||||
min_tokens = kwargs.get("min_tokens")
|
||||
|
||||
# This is an inverse of skip_special_tokens
|
||||
decode_special_tokens = not kwargs.get("skip_special_tokens")
|
||||
|
||||
# Log prompt to console. Add the BOS token if specified
|
||||
log_prompt(
|
||||
f"{self.tokenizer.bos_token if add_bos_token else ''}{prompt}",
|
||||
|
|
@ -1369,17 +1357,17 @@ class ExllamaV2Container:
|
|||
self.generator,
|
||||
input_ids=input_ids,
|
||||
max_new_tokens=max_tokens,
|
||||
min_new_tokens=min_tokens,
|
||||
min_new_tokens=params.min_tokens,
|
||||
gen_settings=gen_settings,
|
||||
stop_conditions=stop_conditions,
|
||||
decode_special_tokens=decode_special_tokens,
|
||||
decode_special_tokens=not params.skip_special_tokens,
|
||||
filters=grammar_handler.filters,
|
||||
filter_prefer_eos=bool(grammar_handler.filters),
|
||||
return_probs=request_logprobs > 0,
|
||||
return_top_tokens=request_logprobs,
|
||||
return_logits=request_logprobs > 0,
|
||||
return_probs=params.logprobs > 0,
|
||||
return_top_tokens=params.logprobs,
|
||||
return_logits=params.logprobs > 0,
|
||||
banned_strings=banned_strings,
|
||||
token_healing=token_healing,
|
||||
token_healing=params.token_healing,
|
||||
identifier=job_id,
|
||||
embeddings=mm_embeddings_content,
|
||||
)
|
||||
|
|
@ -1418,7 +1406,7 @@ class ExllamaV2Container:
|
|||
"offset": len(full_response),
|
||||
}
|
||||
|
||||
if request_logprobs > 0:
|
||||
if params.logprobs > 0:
|
||||
# Get top tokens and probs
|
||||
top_tokens = unwrap(
|
||||
result.get("top_k_tokens"),
|
||||
|
|
@ -1494,8 +1482,7 @@ class ExllamaV2Container:
|
|||
request_id=request_id,
|
||||
bos_token_id=self.tokenizer.bos_token_id,
|
||||
eos_token_id=eos_tokens,
|
||||
**kwargs,
|
||||
generate_window=generate_window,
|
||||
**params.model_dump(),
|
||||
auto_scale_penalty_range=auto_scale_penalty_range,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -282,6 +282,11 @@ class BaseSamplerRequest(BaseModel):
|
|||
ge=0,
|
||||
)
|
||||
|
||||
logprobs: Optional[int] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("logprobs", 0),
|
||||
ge=0,
|
||||
)
|
||||
|
||||
@field_validator("top_k", mode="before")
|
||||
def convert_top_k(cls, v):
|
||||
"""Fixes instance if Top-K is -1."""
|
||||
|
|
|
|||
|
|
@ -32,10 +32,6 @@ class CommonCompletionRequest(BaseSamplerRequest):
|
|||
# Generation info (remainder is in BaseSamplerRequest superclass)
|
||||
stream: Optional[bool] = False
|
||||
stream_options: Optional[ChatCompletionStreamOptions] = None
|
||||
logprobs: Optional[int] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("logprobs", 0),
|
||||
ge=0,
|
||||
)
|
||||
response_format: Optional[CompletionResponseFormat] = Field(
|
||||
default_factory=CompletionResponseFormat
|
||||
)
|
||||
|
|
|
|||
|
|
@ -333,11 +333,11 @@ async def stream_generate_chat_completion(
|
|||
_stream_collector(
|
||||
n,
|
||||
gen_queue,
|
||||
prompt,
|
||||
request.state.id,
|
||||
prompt,
|
||||
task_gen_params,
|
||||
abort_event,
|
||||
embeddings=embeddings,
|
||||
**task_gen_params.model_dump(exclude={"prompt"}),
|
||||
mm_embeddings=embeddings,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -422,10 +422,10 @@ async def generate_chat_completion(
|
|||
gen_tasks.append(
|
||||
asyncio.create_task(
|
||||
model.container.generate(
|
||||
prompt,
|
||||
request.state.id,
|
||||
embeddings=embeddings,
|
||||
**data.model_dump(exclude={"prompt"}),
|
||||
prompt,
|
||||
data,
|
||||
mm_embeddings=embeddings,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
@ -465,7 +465,6 @@ async def generate_tool_calls(
|
|||
# FIXME: May not be necessary depending on how the codebase evolves
|
||||
tool_data = data.model_copy(deep=True)
|
||||
tool_data.json_schema = tool_data.tool_call_schema
|
||||
gen_params = tool_data.model_dump()
|
||||
|
||||
for idx, gen in enumerate(generations):
|
||||
if gen["stop_str"] in tool_data.tool_call_start:
|
||||
|
|
@ -488,10 +487,10 @@ async def generate_tool_calls(
|
|||
gen_tasks.append(
|
||||
asyncio.create_task(
|
||||
model.container.generate(
|
||||
pre_tool_prompt,
|
||||
request.state.id,
|
||||
pre_tool_prompt,
|
||||
tool_data,
|
||||
embeddings=mm_embeddings,
|
||||
**gen_params,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,12 +8,12 @@ import asyncio
|
|||
import pathlib
|
||||
from asyncio import CancelledError
|
||||
from fastapi import HTTPException, Request
|
||||
from typing import List, Union
|
||||
|
||||
from loguru import logger
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from common import model
|
||||
from common.auth import get_key_permission
|
||||
from common.multimodal import MultimodalEmbeddingWrapper
|
||||
from common.networking import (
|
||||
get_generator_error,
|
||||
handle_request_disconnect,
|
||||
|
|
@ -86,16 +86,21 @@ def _create_response(
|
|||
async def _stream_collector(
|
||||
task_idx: int,
|
||||
gen_queue: asyncio.Queue,
|
||||
prompt: str,
|
||||
request_id: str,
|
||||
prompt: str,
|
||||
params: CompletionRequest,
|
||||
abort_event: asyncio.Event,
|
||||
**kwargs,
|
||||
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
|
||||
):
|
||||
"""Collects a stream and places results in a common queue"""
|
||||
|
||||
try:
|
||||
new_generation = model.container.generate_gen(
|
||||
prompt, request_id, abort_event, **kwargs
|
||||
request_id,
|
||||
prompt,
|
||||
params,
|
||||
abort_event,
|
||||
mm_embeddings,
|
||||
)
|
||||
async for generation in new_generation:
|
||||
generation["index"] = task_idx
|
||||
|
|
@ -195,10 +200,10 @@ async def stream_generate_completion(
|
|||
_stream_collector(
|
||||
n,
|
||||
gen_queue,
|
||||
data.prompt,
|
||||
request.state.id,
|
||||
data.prompt,
|
||||
task_gen_params,
|
||||
abort_event,
|
||||
**task_gen_params.model_dump(exclude={"prompt"}),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -256,9 +261,9 @@ async def generate_completion(
|
|||
gen_tasks.append(
|
||||
asyncio.create_task(
|
||||
model.container.generate(
|
||||
data.prompt,
|
||||
request.state.id,
|
||||
**task_gen_params.model_dump(exclude={"prompt"}),
|
||||
data.prompt,
|
||||
task_gen_params,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue