diff --git a/cpp/actions.hpp b/cpp/actions.hpp index ed79eff..079488c 100644 --- a/cpp/actions.hpp +++ b/cpp/actions.hpp @@ -173,6 +173,12 @@ glue_msg_load_res action_load(app_t &app, const char *req_raw) cparams.offload_kqv = req.offload_kqv.value; if (req.n_batch.not_null()) cparams.n_batch = req.n_batch.value; + if (req.n_ubatch.not_null()) + cparams.n_ubatch = req.n_ubatch.value; + if (req.flash_attn.not_null()) + cparams.flash_attn = req.flash_attn.value; + if (req.n_threads_decoding.not_null()) + cparams.n_threads = req.n_threads_decoding.value; if (req.n_seq_max.not_null()) cparams.n_seq_max = req.n_seq_max.value; if (req.pooling_type.not_null()) diff --git a/cpp/glue.hpp b/cpp/glue.hpp index 2cdc81b..e0f4106 100644 --- a/cpp/glue.hpp +++ b/cpp/glue.hpp @@ -513,6 +513,9 @@ struct glue_msg_load_req GLUE_FIELD_NULLABLE(int, yarn_orig_ctx) GLUE_FIELD_NULLABLE(str, cache_type_k) GLUE_FIELD_NULLABLE(str, cache_type_v) + GLUE_FIELD_NULLABLE(int, n_ubatch) + GLUE_FIELD_NULLABLE(bool, flash_attn) + GLUE_FIELD_NULLABLE(int, n_threads_decoding) }; struct glue_msg_load_res diff --git a/scripts/docker-compose.yml b/scripts/docker-compose.yml index 27e8794..bab138a 100644 --- a/scripts/docker-compose.yml +++ b/scripts/docker-compose.yml @@ -26,7 +26,7 @@ services: # emcc --clear-cache - emcmake cmake ../.. + emcmake cmake -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DGGML_LLAMAFILE=OFF -DGGML_CPU_AARCH64=OFF ../.. export EMCC_CFLAGS="$$SHARED_EMCC_CFLAGS" emmake make wllama -j @@ -34,7 +34,7 @@ services: mkdir -p wasm/multi-thread cd wasm/multi-thread export EMCC_CFLAGS="" # temporary clear it - emcmake cmake ../.. + emcmake cmake -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DGGML_LLAMAFILE=OFF -DGGML_CPU_AARCH64=OFF ../.. export EMCC_CFLAGS="$$SHARED_EMCC_CFLAGS -pthread -sUSE_PTHREADS=1 -sPTHREAD_POOL_SIZE=Module[\\\"pthreadPoolSize\\\"]" emmake make wllama -j diff --git a/src/utils.ts b/src/utils.ts index e0f217f..ff17a8c 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -114,7 +114,7 @@ export const sortFileByShard = (blobs: Blob[]): void => { export const delay = (ms: number) => new Promise((r) => setTimeout(r, ms)); export const absoluteUrl = (relativePath: string) => - new URL(relativePath, document.baseURI).href; + relativePath; export const padDigits = (number: number, digits: number) => { return ( diff --git a/src/wllama.ts b/src/wllama.ts index c2996f5..1b6cb9e 100644 --- a/src/wllama.ts +++ b/src/wllama.ts @@ -565,6 +565,7 @@ export class Wllama { yarn_orig_ctx: config.yarn_orig_ctx, cache_type_k: config.cache_type_k as string, cache_type_v: config.cache_type_v as string, + ...config, }); const loadedCtxInfo: LoadedContextInfo = { ...loadResult, @@ -673,7 +674,7 @@ export class Wllama { await this.samplingInit(this.samplingConfig); const stopTokens = new Set(options.stopTokens ?? []); // process prompt - let tokens = await this.tokenize(prompt, true); + let tokens = Array.isArray(prompt) ? prompt : await this.tokenize(prompt, true); if (this.addBosToken && tokens[0] !== this.bosToken) { tokens.unshift(this.bosToken); }