From bbb17372d8911bda0f034a8ee7fbde73463efb33 Mon Sep 17 00:00:00 2001
From: geraud
Date: Tue, 10 Sep 2024 10:47:29 +0200
Subject: [PATCH 1/2] Add fill in the middle.

---
 Readme.md                       |  67 ++++++++++++++++
 examples/fill_in_the_middle.php |  70 ++++++++++++++++
 src/Messages.php                |   1 +
 src/MistralClient.php           | 138 +++++++++++++++++++++++++++-----
 4 files changed, 255 insertions(+), 21 deletions(-)
 create mode 100755 examples/fill_in_the_middle.php

diff --git a/Readme.md b/Readme.md
index 2a6db12..330f692 100755
--- a/Readme.md
+++ b/Readme.md
@@ -11,6 +11,7 @@ Api is the same as the main Mistral api :
 - **Chat Completions**: Generate conversational responses and complete dialogue prompts using Mistral's language models.
 - **Chat Completions Streaming**: Establish a real-time stream of chat completions, ideal for applications requiring continuous interaction.
 - **Embeddings**: Obtain numerical vector representations of text, enabling semantic search, clustering, and other machine learning applications.
+- **Fill in the Middle**: Generate code between a starting prompt and an optional suffix, letting the model complete the code in between. Ideal for producing specific code segments within predefined boundaries.
 
 ## Getting Started
 
@@ -215,6 +216,72 @@
 )
 ```
 
+
+#### Fill in the middle
+```php
+$prompt = "Write response in php:\n";
+$prompt .= "/** Calculate date + n days. Returns \DateTime object */";
+$suffix = 'return $datePlusNdays;' . "\n" . '}';
+
+try {
+    $result = $client->fim(
+        prompt: $prompt,
+        suffix: $suffix,
+        params: [
+            'model'       => $model_name,
+            'temperature' => 0.7,
+            'top_p'       => 1,
+            'max_tokens'  => 200,
+            'min_tokens'  => 0,
+            'stop'        => 'string',
+            'random_seed' => 0
+        ]
+    );
+} catch (MistralClientException $e) {
+    echo $e->getMessage();
+    exit(1);
+}
+```
+Result:
+```console
+function datePlusNdays(\DateTime $date, $n) {
+    $datePlusNdays = clone $date;
+    $datePlusNdays->add(new \DateInterval('P'.abs($n).'D'));
+```
+
+#### Fill in the middle in stream mode
+```php
+try {
+    $result = $client->fimStream(
+        prompt: $prompt,
+        suffix: $suffix,
+        params: [
+            'model'       => $model_name,
+            'temperature' => 0.7,
+            'top_p'       => 1,
+            'max_tokens'  => 200,
+            'min_tokens'  => 0,
+            'stop'        => 'string',
+            'random_seed' => 0
+        ]
+    );
+    foreach ($result as $chunk) {
+        echo $chunk->getChunk();
+    }
+} catch (MistralClientException $e) {
+    echo $e->getMessage();
+    exit(1);
+}
+```
+Result:
+```console
+function datePlusNdays(\DateTime $date, $n) {
+    $datePlusNdays = clone $date;
+    $datePlusNdays->add(new \DateInterval('P'.abs($n).'D'));
+```
+
+
 ## Lama.cpp inference
 [MistralAi La plateforme](https://console.mistral.ai/) is really cheap you should consider subscribing to it instead of running a local Lama.cpp instance. This bundle cost us only 0.02€ during our tests.
 If you really feel you need a local server, here is a
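Worth noting for the Readme examples above: `fim()` returns only the span between `prompt` and `suffix`, so the caller reassembles the final snippet. A minimal sketch of that assembly, with illustrative values standing in for the model's answer (`$result->getMessage()`); nothing here is part of the patch itself:

```php
<?php
// Illustrative values: $middle stands in for $result->getMessage(),
// which is only the middle span. The caller stitches the three pieces
// back together in prompt + middle + suffix order.
$prompt = "function datePlusNdays(\\DateTime \$date, \$n) {\n";
$middle = "    \$datePlusNdays = clone \$date;\n"
        . "    \$datePlusNdays->add(new \\DateInterval('P'.abs(\$n).'D'));\n";
$suffix = "    return \$datePlusNdays;\n}";

echo $prompt . $middle . $suffix, PHP_EOL;
```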
diff --git a/examples/fill_in_the_middle.php b/examples/fill_in_the_middle.php
new file mode 100755
index 0000000..0f4046f
--- /dev/null
+++ b/examples/fill_in_the_middle.php
@@ -0,0 +1,70 @@
+#!/usr/bin/php
+<?php
+// Adjust the autoload path, API key, and model name to your installation.
+require_once __DIR__ . '/../vendor/autoload.php';
+
+use Partitech\PhpMistral\MistralClient;
+use Partitech\PhpMistral\MistralClientException;
+
+$apiKey     = getenv('MISTRAL_API_KEY');
+$model_name = MistralClient::DEFAULT_FIM_MODEL;
+$client     = new MistralClient($apiKey);
+
+$prompt  = "Write response in php:\n";
+$prompt .= "/** Calculate date + n days. Returns \DateTime object */";
+$suffix  = 'return $datePlusNdays;' . "\n" . '}';
+
+try {
+    $result = $client->fim(
+        prompt: $prompt,
+        suffix: $suffix,
+        params: [
+            'model'       => $model_name,
+            'temperature' => 0.7,
+            'top_p'       => 1,
+            'max_tokens'  => 200,
+            'min_tokens'  => 0,
+            'stop'        => 'string',
+            'random_seed' => 0
+        ]
+    );
+} catch (MistralClientException $e) {
+    echo $e->getMessage();
+    exit(1);
+}
+
+print_r($result->getMessage());
+
+/**
+ * function datePlusNdays(\DateTime $date, $n) {
+ *     $datePlusNdays = clone $date;
+ *     $datePlusNdays->add(new \DateInterval('P'.abs($n).'D'));
+ */
+
+###############################################
+##### Fill in the middle with streaming ######
+###############################################
+
+try {
+    $result = $client->fimStream(
+        prompt: $prompt,
+        suffix: $suffix,
+        params: [
+            'model'       => $model_name,
+            'temperature' => 0.7,
+            'top_p'       => 1,
+            'max_tokens'  => 200,
+            'min_tokens'  => 0,
+            'stop'        => 'string',
+            'random_seed' => 0
+        ]
+    );
+    foreach ($result as $chunk) {
+        echo $chunk->getChunk();
+    }
+} catch (MistralClientException $e) {
+    echo $e->getMessage();
+    exit(1);
+}
\ No newline at end of file
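A subtlety in the example's parameters, visible in the request builder added below: `temperature` and `top_p` are gated behind `is_float()`, so the integer `1` passed for `top_p` is silently omitted from the request body. A sketch of the body the builder actually produces for these parameters (derived by reading `makeFimCompletionRequest()`, not captured API output):

```php
<?php
// Effective body built by makeFimCompletionRequest() for the example's
// params. 'top_p' => 1 is an int, so the is_float() gate drops it; pass
// 1.0 (and a float temperature) to keep those fields in the payload.
$body = [
    'stream'      => false,
    'prompt'      => "Write response in php:\n/** Calculate date + n days. Returns \DateTime object */",
    'suffix'      => 'return $datePlusNdays;' . "\n" . '}',
    'model'       => 'codestral-2405',
    'temperature' => 0.7,
    'max_tokens'  => 200,
    'min_tokens'  => 0,
    'stop'        => 'string',
    'random_seed' => 0,
];
echo json_encode($body, JSON_PRETTY_PRINT), PHP_EOL;
```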
diff --git a/src/Messages.php b/src/Messages.php
index 365314d..2e09a7d 100644
--- a/src/Messages.php
+++ b/src/Messages.php
@@ -33,6 +33,7 @@ public function format(string $format=MistralClient::CHAT_ML): string|array|null
         return $messages;
     }
 
+    /** @deprecated since v0.0.16. Will be removed in a future version. */
     if(MistralClient::COMPLETION === $format) {
         $messages = null;
 
diff --git a/src/MistralClient.php b/src/MistralClient.php
index 044b07a..140ab54 100644
--- a/src/MistralClient.php
+++ b/src/MistralClient.php
@@ -15,7 +15,8 @@ class MistralClient
 {
-    const string DEFAULT_MODEL = 'open-mistral-7b';
+    const string DEFAULT_CHAT_MODEL = 'open-mistral-7b';
+    const string DEFAULT_FIM_MODEL = 'codestral-2405';
     const string TOOL_CHOICE_ANY = 'any';
     const string TOOL_CHOICE_AUTO = 'auto';
     const string TOOL_CHOICE_NONE = 'none';
@@ -31,9 +32,12 @@ class MistralClient
     ];
     protected const string END_OF_STREAM = "[DONE]";
     const string ENDPOINT = 'https://api.mistral.ai';
+    /** @deprecated since v0.0.16. Will be removed in a future version. */
     public const string CHAT_ML = 'mistral';
+    /** @deprecated since v0.0.16. Will be removed in a future version. */
     public const string COMPLETION = 'completion';
-    protected string $completionEndpoint = 'v1/chat/completions';
+    protected string $chatCompletionEndpoint = 'v1/chat/completions';
+    protected string $fimCompletionEndpoint = 'v1/fim/completions';
     protected string $promptKeyword = 'messages';
     protected string $apiKey;
     protected string $url;
@@ -75,9 +79,10 @@ public function listModels(): array
     protected function request(
         string $method,
         string $path,
-        array $request = [],
-        bool $stream = false
-    ): array|ResponseInterface {
+        array  $request = [],
+        bool   $stream = false
+    ): array|ResponseInterface
+    {
         try {
             $response = $this->httpClient->request(
                 $method,
@@ -101,10 +106,96 @@ protected function request(
     public function chat(Messages $messages, array $params = []): Response
     {
         $params = $this->makeChatCompletionRequest($messages, $params, false);
-        $result = $this->request('POST', $this->completionEndpoint, $params);
+        $result = $this->request('POST', $this->chatCompletionEndpoint, $params);
         return Response::createFromArray($result);
     }
 
+    /**
+     * @throws MistralClientException
+     */
+    public function fim(string $prompt, ?string $suffix, array $params = []): Response
+    {
+        $request = $this->makeFimCompletionRequest(
+            prompt: $prompt,
+            suffix: $suffix,
+            params: $params,
+            stream: false
+        );
+
+        $result = $this->request('POST', $this->fimCompletionEndpoint, $request);
+        return Response::createFromArray($result);
+    }
+
+    /**
+     * @throws MistralClientException
+     */
+    public function fimStream(string $prompt, ?string $suffix, array $params = []): Generator
+    {
+        $request = $this->makeFimCompletionRequest(
+            prompt: $prompt,
+            suffix: $suffix,
+            params: $params,
+            stream: true
+        );
+
+        $stream = $this->request('POST', $this->fimCompletionEndpoint, $request, true);
+        return $this->getStream($stream);
+    }
+
+    protected function makeFimCompletionRequest(string $prompt, ?string $suffix = null, array $params = [], bool $stream = false): array
+    {
+        $return = [];
+
+        $return['stream'] = $stream;
+        $return['prompt'] = $prompt;
+        $return['suffix'] = $suffix ?? '';
+
+        if (isset($params['model']) && is_string($params['model'])) {
+            $return['model'] = $params['model'];
+        } else {
+            $return['model'] = self::DEFAULT_FIM_MODEL;
+        }
+
+        if (isset($params['temperature']) && is_float($params['temperature'])) {
+            $return['temperature'] = $params['temperature'];
+        }
+
+        if (isset($params['top_p']) && is_float($params['top_p'])) {
+            $return['top_p'] = $params['top_p'];
+        }
+
+        if (isset($params['max_tokens']) && is_int($params['max_tokens'])) {
+            $return['max_tokens'] = $params['max_tokens'];
+        } else {
+            $return['max_tokens'] = null;
+        }
+
+        if (isset($params['min_tokens']) && is_numeric($params['min_tokens'])) {
+            $return['min_tokens'] = (int)$params['min_tokens'];
+        } else {
+            $return['min_tokens'] = null;
+        }
+
+        if (isset($params['stop']) && is_string($params['stop'])) {
+            $return['stop'] = $params['stop'];
+        }
+
+        if (isset($params['random_seed']) && is_int($params['random_seed'])) {
+            $return['random_seed'] = $params['random_seed'];
+        }
+
+        return $return;
+    }
 
     /**
      * @param Messages $messages
@@ -121,7 +212,7 @@ protected function makeChatCompletionRequest(Messages $messages, array $params,
         if (isset($params['model']) && is_string($params['model'])) {
             $return['model'] = $params['model'];
         } else {
-            $return['model'] = self::DEFAULT_MODEL;
+            $return['model'] = self::DEFAULT_CHAT_MODEL;
         }
 
         if ($this->mode === self::CHAT_ML) {
@@ -184,8 +275,8 @@ protected function makeChatCompletionRequest(Messages $messages, array $params,
         }
 
         if (isset($params['presence_penalty']) && is_numeric(
-            $params['presence_penalty']
-        ) && $params['presence_penalty'] >= -2 && $params['presence_penalty'] <= 2) {
+                $params['presence_penalty']
+            ) && $params['presence_penalty'] >= -2 && $params['presence_penalty'] <= 2) {
             $return['presence_penalty'] = (float)$params['presence_penalty'];
         }
 
@@ -224,7 +315,7 @@ protected function makeChatCompletionRequest(Messages $messages, array $params,
             $return['guided_json'] = json_encode($params['guided_json']);
         }
 
-        if(isset($params['response_format']) && $params['response_format'] === self::RESPONSE_FORMAT_JSON) {
+        if (isset($params['response_format']) && $params['response_format'] === self::RESPONSE_FORMAT_JSON) {
             $return['response_format'] = [
                 'type' => 'json_object'
             ];
@@ -240,8 +331,23 @@ protected function makeChatCompletionRequest(Messages $messages, array $params,
     public function chatStream(Messages $messages, array $params = []): Generator
     {
         $request = $this->makeChatCompletionRequest($messages, $params, true);
-        $stream = $this->request('POST', $this->completionEndpoint, $request, true);
+        $stream = $this->request('POST', $this->chatCompletionEndpoint, $request, true);
+        return $this->getStream($stream);
+    }
+
+    /**
+     * @throws MistralClientException
+     */
+    public function embeddings(array $datas): array
+    {
+        $request = ['model' => 'mistral-embed', 'input' => $datas,];
+        return $this->request('POST', 'v1/embeddings', $request);
+    }
+
+    public function getStream($stream): Generator
+    {
         $response = null;
         foreach ($this->httpClient->stream($stream) as $chunk) {
             try {
@@ -272,14 +378,4 @@ public function chatStream(Messages $messages, array $params = []): Generator
             }
         }
     }
-
-
-    /**
-     * @throws MistralClientException
-     */
-    public function embeddings(array $datas): array
-    {
-        $request = ['model' => 'mistral-embed', 'input' => $datas,];
-        return $this->request('POST', 'v1/embeddings', $request);
-    }
 }
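With no `params` at all, the builder above still emits a complete request: a null `suffix` collapses to an empty string, the model falls back to `DEFAULT_FIM_MODEL`, and `max_tokens`/`min_tokens` are sent as explicit nulls. A sketch of that default payload, derived by reading `makeFimCompletionRequest()`; the prompt value is illustrative:

```php
<?php
// Payload produced by makeFimCompletionRequest('...', null, []):
// everything except stream/prompt/suffix/model is either defaulted
// to null or omitted entirely.
$default = [
    'stream'     => false,
    'prompt'     => 'function add(int $a, int $b): int {',
    'suffix'     => '',               // null suffix collapses to ''
    'model'      => 'codestral-2405', // self::DEFAULT_FIM_MODEL
    'max_tokens' => null,             // sent as an explicit null
    'min_tokens' => null,             // sent as an explicit null
];
echo json_encode($default), PHP_EOL;
```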
From 84619184b770435d6c87fca05daa43701429d850 Mon Sep 17 00:00:00 2001
From: geraud
Date: Tue, 10 Sep 2024 10:48:34 +0200
Subject: [PATCH 2/2] Add fill in the middle.

---
 Readme.md | 2 --
 1 file changed, 2 deletions(-)

diff --git a/Readme.md b/Readme.md
index 330f692..7fae274 100755
--- a/Readme.md
+++ b/Readme.md
@@ -566,8 +566,6 @@ example:
 )
 ```
 
-
-
 ## Documentation
 
 For detailed documentation on the Mistral AI API and the available endpoints, please refer to the [Mistral AI API Documentation](https://docs.mistral.ai).
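For reference, a dependency-free sketch of the raw HTTP call that `MistralClient::fim()` wraps, using PHP's curl extension. The endpoint URL and body fields come from the patch above; the Bearer authorization header is assumed from the public Mistral API convention rather than from this diff:

```php
<?php
// Raw POST to the FIM endpoint added in this patch:
// ENDPOINT ('https://api.mistral.ai') + $fimCompletionEndpoint
// ('v1/fim/completions'). Auth header assumed (standard Bearer token).
$payload = json_encode([
    'model'  => 'codestral-2405',
    'prompt' => 'function datePlusNdays(\DateTime $date, $n) {',
    'suffix' => 'return $datePlusNdays;}',
    'stream' => false,
]);

$ch = curl_init('https://api.mistral.ai/v1/fim/completions');
curl_setopt_array($ch, [
    CURLOPT_RETURNTRANSFER => true,
    CURLOPT_POST           => true,
    CURLOPT_HTTPHEADER     => [
        'Content-Type: application/json',
        'Authorization: Bearer ' . getenv('MISTRAL_API_KEY'),
    ],
    CURLOPT_POSTFIELDS     => $payload,
]);

$response = curl_exec($ch);
curl_close($ch);
echo $response, PHP_EOL;
```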