Skip to content

Commit

Permalink
Merge pull request #26 from GeraudBourdin/timeout
Browse files Browse the repository at this point in the history
Timeout
  • Loading branch information
GeraudBourdin authored Oct 24, 2024
2 parents 97ec189 + dce481d commit 690aad3
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 3 deletions.
46 changes: 46 additions & 0 deletions examples/socket_timeout.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
<?php
require_once __DIR__ . '/../vendor/autoload.php';

use Partitech\PhpMistral\MistralClient;
use Partitech\PhpMistral\MistralClientException;
use Partitech\PhpMistral\Messages;
use Symfony\Component\HttpClient\Exception\TransportException;

// export MISTRAL_API_KEY=
$apiKey = getenv('MISTRAL_API_KEY');
$client = new MistralClient(apiKey: $apiKey, timeout:0.1);
$messages = new Messages();
$messages->addUserMessage('What is the best French cheese?');

$params = [
'model' => 'mistral-large-latest',
'temperature' => 0.7,
'top_p' => 1,
'max_tokens' => null,
'safe_prompt' => false,
'random_seed' => null
];

try {
foreach ($client->chatStream($messages, $params) as $chunk) {
echo $chunk->getChunk();
}
} catch (MistralClientException $e) {
echo $e->getMessage();
exit(1);
} catch (TransportException $e) {
echo 'Idle timeout reached' . PHP_EOL;

}

$client->setTimeout(null);
try {
foreach ($client->chatStream($messages, $params) as $chunk) {
echo $chunk->getChunk();
}
} catch (MistralClientException $e) {
echo $e->getMessage();
exit(1);
} catch (TransportException $e) {
echo 'Idle timeout reached' . PHP_EOL;;
}
40 changes: 37 additions & 3 deletions src/MistralClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

use Generator;
use KnpLabs\JsonSchema\ObjectSchema;
use Symfony\Component\HttpClient\Exception\TransportException;
use Symfony\Component\HttpClient\HttpClient;
use Symfony\Component\HttpClient\Retry\GenericRetryStrategy;
use Symfony\Component\HttpClient\RetryableHttpClient;
use Symfony\Contracts\HttpClient\Exception\TransportExceptionInterface;
use Symfony\Contracts\HttpClient\HttpClientInterface;
use Symfony\Contracts\HttpClient\ResponseInterface;
use Throwable;
Expand Down Expand Up @@ -43,9 +45,11 @@ class MistralClient
protected string $url;
private HttpClientInterface $httpClient;
private string $mode;
private null|int|float $timeout = null; // null = default_socket_timeout

public function __construct(string $apiKey, string $url = self::ENDPOINT)
public function __construct(string $apiKey, string $url = self::ENDPOINT, int|float $timeout = null)
{
$this->setTimeout($timeout);
$this->httpClient = new RetryableHttpClient(
HttpClient::create(),
new GenericRetryStrategy(self::RETRY_STATUS_CODES, 500, 2)
Expand Down Expand Up @@ -83,11 +87,21 @@ protected function request(
bool $stream = false
): array|ResponseInterface
{
$params = [
'json' => $request,
'headers' => ['Authorization' => 'Bearer ' . $this->apiKey,],
'buffer' => $stream,
];

if(!is_null($this->getTimeout())){
$params['timeout'] = $this->getTimeout();
}

try {
$response = $this->httpClient->request(
$method,
$this->url . '/' . $path,
['json' => $request, 'headers' => ['Authorization' => 'Bearer ' . $this->apiKey,], 'buffer' => $stream,]
$params
);
} catch (Throwable $e) {
throw new MistralClientException($e->getMessage(), $e->getCode(), $e);
Expand Down Expand Up @@ -127,7 +141,7 @@ public function fim(string $prompt, ?string $suffix, array $params = []): Respon
}

/**
* @throws MistralClientException
* @throws MistralClientException|TransportExceptionInterface
*/
public function fimStream(string $prompt, ?string $suffix, array $params = []): Generator
{
Expand Down Expand Up @@ -327,6 +341,7 @@ protected function makeChatCompletionRequest(Messages $messages, array $params,

/**
* @throws MistralClientException
* @throws TransportExceptionInterface
*/
public function chatStream(Messages $messages, array $params = []): Generator
{
Expand All @@ -346,10 +361,17 @@ public function embeddings(array $datas): array
}


/**
* @throws MistralClientException
* @throws TransportExceptionInterface
*/
public function getStream($stream): Generator
{
$response = null;
foreach ($this->httpClient->stream($stream) as $chunk) {
if ($chunk->isTimeout()) {
throw new TransportException('Stream is closed');
}
try {
$chunk = trim($chunk->getContent());
} catch (Throwable $e) {
Expand Down Expand Up @@ -378,4 +400,16 @@ public function getStream($stream): Generator
}
}
}

public function setTimeout(null|int|float $timeout): self
{
$this->timeout = $timeout;

return $this;
}

public function getTimeout(): null|int|float
{
return $this->timeout;
}
}

0 comments on commit 690aad3

Please sign in to comment.