Skip to content

Commit

Permalink
Merge pull request #23 from GeraudBourdin/main
Browse files Browse the repository at this point in the history
Fill in the middle
  • Loading branch information
GeraudBourdin authored Sep 10, 2024
2 parents 0fec491 + 8461918 commit 957526b
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 23 deletions.
69 changes: 67 additions & 2 deletions Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Api is the same as the main Mistral api :
- **Chat Completions**: Generate conversational responses and complete dialogue prompts using Mistral's language models.
- **Chat Completions Streaming**: Establish a real-time stream of chat completions, ideal for applications requiring continuous interaction.
- **Embeddings**: Obtain numerical vector representations of text, enabling semantic search, clustering, and other machine learning applications.
- **Fill in the Middle**: Automatically generate code by setting a starting prompt and an optional suffix, allowing the model to complete the code in between. Ideal for creating specific code segments within predefined boundaries.

## Getting Started

Expand Down Expand Up @@ -215,6 +216,72 @@ Array
)
```


#### Fill in the middle
```php
$prompt = "Write response in php:\n";
$prompt .= "/** Calculate date + n days. Returns \DateTime object */";
$suffix = 'return $datePlusNdays;\n}';

try {
$result = $client->fim(
prompt: $prompt,
suffix: $suffix,
params:[
'model' => $model_name,
'temperature' => 0.7,
'top_p' => 1,
'max_tokens' => 200,
'min_tokens' => 0,
'stop' => 'string',
'random_seed' => 0
]
);
} catch (MistralClientException $e) {
echo $e->getMessage();
exit(1);
}
```
Result:
```console
function datePlusNdays(\DateTime $date, $n) {
$datePlusNdays = clone $date;
$datePlusNdays->add(new \DateInterval('P'.abs($n).'D'));
```

#### Fill in the middle in stream mode
```php
try {
$result = $client->fimStream(
prompt: $prompt,
suffix: $suffix,
params:[
'model' => $model_name,
'temperature' => 0.7,
'top_p' => 1,
'max_tokens' => 200,
'min_tokens' => 0,
'stop' => 'string',
'random_seed' => 0
]
);
foreach ($result as $chunk) {
echo $chunk->getChunk();
}
} catch (MistralClientException $e) {
echo $e->getMessage();
exit(1);
}
```
Result:
```console
function datePlusNdays(\DateTime $date, $n) {
$datePlusNdays = clone $date;
$datePlusNdays->add(new \DateInterval('P'.abs($n).'D'));
```


## Llama.cpp inference
[MistralAi La plateforme](https://console.mistral.ai/) is really cheap; you should consider subscribing to it instead of running
a local Llama.cpp instance. This bundle cost us only €0.02 during our tests. If you really feel you need a local server, here is a
Expand Down Expand Up @@ -499,8 +566,6 @@ example:
)
```



## Documentation

For detailed documentation on the Mistral AI API and the available endpoints, please refer to the [Mistral AI API Documentation](https://docs.mistral.ai).
Expand Down
70 changes: 70 additions & 0 deletions examples/fill_in_the_middle.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#!/usr/bin/php
<?php
// Example: "Fill in the Middle" (FIM) code completion with the Mistral API.
// Shows the blocking fim() call first, then the streaming fimStream() variant.
require_once __DIR__ . '/../vendor/autoload.php';

use Partitech\PhpMistral\MistralClient;
use Partitech\PhpMistral\MistralClientException;
use Partitech\PhpMistral\Messages;

// export MISTRAL_API_KEY=your_api_key
$apiKey = getenv('MISTRAL_API_KEY');
// FIM endpoints expect a code-completion model such as Codestral.
$model_name = "codestral-2405";

$client = new MistralClient($apiKey);

// The model generates the code between the prompt (prefix) and the suffix.
$prompt = "Write response in php:\n";
$prompt .= "/** Calculate date + n days. Returns \DateTime object */";
// NOTE(review): single-quoted string, so \n stays a literal backslash-n — confirm intended.
$suffix = 'return $datePlusNdays;\n}';

try {
$result = $client->fim(
prompt: $prompt,
suffix: $suffix,
params:[
'model' => $model_name,
'temperature' => 0.7,
'top_p' => 1,
'max_tokens' => 200,
'min_tokens' => 0,
'stop' => 'string', // stop sequence: generation halts if this text is produced
'random_seed' => 0
]
);
} catch (MistralClientException $e) {
echo $e->getMessage();
exit(1);
}

print_r($result->getMessage());

/**
 * Expected output, roughly:
 * function datePlusNdays(\DateTime $date, $n) {
 * $datePlusNdays = clone $date;
 * $datePlusNdays->add(new \DateInterval('P'.abs($n).'D'));
 */

###############################################
##### Fill in the middle with streaming ######
###############################################

try {
$result = $client->fimStream(
prompt: $prompt,
suffix: $suffix,
params:[
'model' => $model_name,
'temperature' => 0.7,
'top_p' => 1,
'max_tokens' => 200,
'min_tokens' => 0,
'stop' => 'string', // stop sequence: generation halts if this text is produced
'random_seed' => 0
]
);
// Each yielded chunk carries the next fragment of the generated code.
foreach ($result as $chunk) {
echo $chunk->getChunk();
}
} catch (MistralClientException $e) {
echo $e->getMessage();
exit(1);
}
1 change: 1 addition & 0 deletions src/Messages.php
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public function format(string $format=MistralClient::CHAT_ML): string|array|null
return $messages;
}

/** @deprecated since v0.0.16. Will be removed in the future version. */
if(MistralClient::COMPLETION === $format) {
$messages = null;

Expand Down
138 changes: 117 additions & 21 deletions src/MistralClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

class MistralClient
{
const string DEFAULT_MODEL = 'open-mistral-7b';
const string DEFAULT_CHAT_MODEL = 'open-mistral-7b';
const string DEFAULT_FIM_MODEL = 'codestral-2405';
const string TOOL_CHOICE_ANY = 'any';
const string TOOL_CHOICE_AUTO = 'auto';
const string TOOL_CHOICE_NONE = 'none';
Expand All @@ -31,9 +32,12 @@ class MistralClient
];
protected const string END_OF_STREAM = "[DONE]";
const string ENDPOINT = 'https://api.mistral.ai';
/** @deprecated since v0.0.16. Will be removed in the future version. */
public const string CHAT_ML = 'mistral';
/** @deprecated since v0.0.16. Will be removed in the future version. */
public const string COMPLETION = 'completion';
protected string $completionEndpoint = 'v1/chat/completions';
protected string $chatCompletionEndpoint = 'v1/chat/completions';
protected string $fimCompletionEndpoint = 'v1/fim/completions';
protected string $promptKeyword = 'messages';
protected string $apiKey;
protected string $url;
Expand Down Expand Up @@ -75,9 +79,10 @@ public function listModels(): array
protected function request(
string $method,
string $path,
array $request = [],
bool $stream = false
): array|ResponseInterface {
array $request = [],
bool $stream = false
): array|ResponseInterface
{
try {
$response = $this->httpClient->request(
$method,
Expand All @@ -101,10 +106,96 @@ protected function request(
public function chat(Messages $messages, array $params = []): Response
{
$params = $this->makeChatCompletionRequest($messages, $params, false);
$result = $this->request('POST', $this->completionEndpoint, $params);
$result = $this->request('POST', $this->chatCompletionEndpoint, $params);
return Response::createFromArray($result);
}

/**
 * Run a fill-in-the-middle completion and return the complete response.
 *
 * @param string      $prompt Code prefix the model must continue.
 * @param string|null $suffix Optional code that must follow the generated text.
 * @param array       $params Optional tuning parameters (model, temperature, ...).
 *
 * @throws MistralClientException
 */
public function fim(string $prompt, ?string $suffix, array $params = []): Response
{
    $payload = $this->makeFimCompletionRequest(
        prompt: $prompt,
        suffix: $suffix,
        params: $params,
        stream: false
    );

    return Response::createFromArray(
        $this->request('POST', $this->fimCompletionEndpoint, $payload)
    );
}

/**
 * Stream a fill-in-the-middle completion chunk by chunk.
 *
 * @param string      $prompt Code prefix the model must continue.
 * @param string|null $suffix Optional code that must follow the generated text.
 * @param array       $params Optional tuning parameters (model, temperature, ...).
 *
 * @throws MistralClientException
 */
public function fimStream(string $prompt, ?string $suffix, array $params = []): Generator
{
    $payload = $this->makeFimCompletionRequest(
        prompt: $prompt,
        suffix: $suffix,
        params: $params,
        stream: true
    );

    $response = $this->request('POST', $this->fimCompletionEndpoint, $payload, true);

    return $this->getStream($response);
}

/**
 * Build the request payload for a fill-in-the-middle (FIM) completion.
 *
 * @param string      $prompt The code prefix the model must continue.
 * @param string|null $suffix Optional code that must follow the generated text;
 *                            the API expects the key, so null becomes ''.
 * @param array       $params Optional tuning parameters: model, temperature,
 *                            top_p, max_tokens, min_tokens, stop, random_seed.
 * @param bool        $stream Whether the completion should be streamed.
 */
protected function makeFimCompletionRequest(string $prompt, ?string $suffix = null, array $params = [], bool $stream = false): array
{
    $return = [
        'stream' => $stream,
        'prompt' => $prompt,
        // The endpoint expects a suffix key even when there is nothing to append.
        'suffix' => $suffix ?? '',
    ];

    // Fall back to the default FIM model when none (or a non-string) is given.
    if (isset($params['model']) && is_string($params['model'])) {
        $return['model'] = $params['model'];
    } else {
        $return['model'] = self::DEFAULT_FIM_MODEL;
    }

    // Accept int or float and normalize to float: the previous is_float() guard
    // silently dropped integer values such as 'top_p' => 1 (used in the examples).
    if (isset($params['temperature']) && is_numeric($params['temperature'])) {
        $return['temperature'] = (float)$params['temperature'];
    }

    if (isset($params['top_p']) && is_numeric($params['top_p'])) {
        $return['top_p'] = (float)$params['top_p'];
    }

    // Token bounds default to null (let the API decide) when absent or invalid.
    if (isset($params['max_tokens']) && is_int($params['max_tokens'])) {
        $return['max_tokens'] = $params['max_tokens'];
    } else {
        $return['max_tokens'] = null;
    }

    // min_tokens was previously handled twice; a single check is sufficient.
    if (isset($params['min_tokens']) && is_numeric($params['min_tokens'])) {
        $return['min_tokens'] = (int)$params['min_tokens'];
    } else {
        $return['min_tokens'] = null;
    }

    if (isset($params['stop']) && is_string($params['stop'])) {
        $return['stop'] = $params['stop'];
    }

    if (isset($params['random_seed']) && is_int($params['random_seed'])) {
        $return['random_seed'] = $params['random_seed'];
    }

    return $return;
}

/**
* @param Messages $messages
Expand All @@ -121,7 +212,7 @@ protected function makeChatCompletionRequest(Messages $messages, array $params,
if (isset($params['model']) && is_string($params['model'])) {
$return['model'] = $params['model'];
} else {
$return['model'] = self::DEFAULT_MODEL;
$return['model'] = self::DEFAULT_CHAT_MODEL;
}

if ($this->mode === self::CHAT_ML) {
Expand Down Expand Up @@ -184,8 +275,8 @@ protected function makeChatCompletionRequest(Messages $messages, array $params,
}

if (isset($params['presence_penalty']) && is_numeric(
$params['presence_penalty']
) && $params['presence_penalty'] >= -2 && $params['presence_penalty'] <= 2) {
$params['presence_penalty']
) && $params['presence_penalty'] >= -2 && $params['presence_penalty'] <= 2) {
$return['presence_penalty'] = (float)$params['presence_penalty'];
}

Expand Down Expand Up @@ -224,7 +315,7 @@ protected function makeChatCompletionRequest(Messages $messages, array $params,
$return['guided_json'] = json_encode($params['guided_json']);
}

if(isset($params['response_format']) && $params['response_format'] === self::RESPONSE_FORMAT_JSON) {
if (isset($params['response_format']) && $params['response_format'] === self::RESPONSE_FORMAT_JSON) {
$return['response_format'] = [
'type' => 'json_object'
];
Expand All @@ -240,8 +331,23 @@ protected function makeChatCompletionRequest(Messages $messages, array $params,
public function chatStream(Messages $messages, array $params = []): Generator
{
$request = $this->makeChatCompletionRequest($messages, $params, true);
$stream = $this->request('POST', $this->completionEndpoint, $request, true);
$stream = $this->request('POST', $this->chatCompletionEndpoint, $request, true);
return $this->getStream($stream);
}


/**
 * Fetch embedding vectors for the given inputs using the mistral-embed model.
 *
 * @param array $datas Input strings to embed.
 *
 * @throws MistralClientException
 */
public function embeddings(array $datas): array
{
    $payload = [
        'model' => 'mistral-embed',
        'input' => $datas,
    ];

    return $this->request('POST', 'v1/embeddings', $payload);
}


public function getStream($stream): Generator
{
$response = null;
foreach ($this->httpClient->stream($stream) as $chunk) {
try {
Expand Down Expand Up @@ -272,14 +378,4 @@ public function chatStream(Messages $messages, array $params = []): Generator
}
}
}


/**
* @throws MistralClientException
*/
public function embeddings(array $datas): array
{
$request = ['model' => 'mistral-embed', 'input' => $datas,];
return $this->request('POST', 'v1/embeddings', $request);
}
}

0 comments on commit 957526b

Please sign in to comment.