Skip to content

Commit

Permalink
Merge pull request #12 from GeraudBourdin/main
Browse files Browse the repository at this point in the history
Add vllm guided 'guided_json' option
  • Loading branch information
GeraudBourdin authored Apr 10, 2024
2 parents 581eaf2 + e3e294f commit 226ec91
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 105 deletions.
43 changes: 34 additions & 9 deletions src/MistralClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@ class MistralClient
const string TOOL_CHOICE_AUTO = 'auto';
const string TOOL_CHOICE_NONE = 'none';

const array RETRY_STATUS_CODES = [429, 500 => GenericRetryStrategy::IDEMPOTENT_METHODS, 502, 503, 504 => GenericRetryStrategy::IDEMPOTENT_METHODS];
const array RETRY_STATUS_CODES = [
429,
500 => GenericRetryStrategy::IDEMPOTENT_METHODS,
502,
503,
504 => GenericRetryStrategy::IDEMPOTENT_METHODS
];
protected const string END_OF_STREAM = "[DONE]";
const string ENDPOINT = 'https://api.mistral.ai';
public const string CHAT_ML = 'mistral';
Expand All @@ -33,7 +39,10 @@ class MistralClient

public function __construct(string $apiKey, string $url = self::ENDPOINT)
{
$this->httpClient = new RetryableHttpClient(HttpClient::create(), new GenericRetryStrategy(self::RETRY_STATUS_CODES, 500, 2));
$this->httpClient = new RetryableHttpClient(
HttpClient::create(),
new GenericRetryStrategy(self::RETRY_STATUS_CODES, 500, 2)
);
$this->apiKey = $apiKey;
$this->url = $url;
$this->mode = self::CHAT_ML;
Expand All @@ -60,10 +69,18 @@ public function listModels(): array
/**
* @throws MistralClientException
*/
protected function request(string $method, string $path, array $request = [], bool $stream = false): array|ResponseInterface
{
protected function request(
string $method,
string $path,
array $request = [],
bool $stream = false
): array|ResponseInterface {
try {
$response = $this->httpClient->request($method, $this->url . '/' . $path, ['json' => $request, 'headers' => ['Authorization' => 'Bearer ' . $this->apiKey,], 'buffer' => $stream,]);
$response = $this->httpClient->request(
$method,
$this->url . '/' . $path,
['json' => $request, 'headers' => ['Authorization' => 'Bearer ' . $this->apiKey,], 'buffer' => $stream,]
);
} catch (Throwable $e) {
throw new MistralClientException($e->getMessage(), $e->getCode(), $e);
}
Expand All @@ -73,7 +90,6 @@ protected function request(string $method, string $path, array $request = [], bo
} catch (Throwable $e) {
throw new MistralClientException($e->getMessage(), $e->getCode(), $e);
}

}

/**
Expand Down Expand Up @@ -160,12 +176,14 @@ protected function makeChatCompletionRequest(Messages $messages, array $params,
$return['n'] = $params['n'];
}

if (isset($params['presence_penalty']) && is_numeric($params['presence_penalty']) && $params['presence_penalty'] >= -2 && $params['presence_penalty'] <= 2) {
$return['presence_penalty'] = (float) $params['presence_penalty'];
if (isset($params['presence_penalty']) && is_numeric(
$params['presence_penalty']
) && $params['presence_penalty'] >= -2 && $params['presence_penalty'] <= 2) {
$return['presence_penalty'] = (float)$params['presence_penalty'];
}

if (isset($params['frequency_penalty']) && is_numeric($params['frequency_penalty'])) {
$return['frequency_penalty'] = (float) $params['frequency_penalty'];
$return['frequency_penalty'] = (float)$params['frequency_penalty'];
}

if (isset($params['best_of']) && is_int($params['best_of'])) {
Expand All @@ -184,6 +202,13 @@ protected function makeChatCompletionRequest(Messages $messages, array $params,
$return['skip_special_tokens'] = $params['skip_special_tokens'];
}

if (isset($params['guided_json']) && is_string($params['guided_json'])) {
$return['guided_json'] = $params['guided_json'];
}

if (isset($params['guided_json']) && is_object($params['guided_json'])) {
$return['guided_json'] = json_encode($params['guided_json']);
}

return $return;
}
Expand Down
199 changes: 103 additions & 96 deletions src/Response.php
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,98 @@ public function __construct()
$this->choices = new ArrayObject();
}

public static function createFromArray(array $data): self
{
$response = new self();
return self::updateFromArray($response, $data);
}

public static function updateFromArray(self $response, array $data): self
{
if (isset($data['id'])) {
$response->setId($data['id']);
}

if (isset($data['object'])) {
$response->setObject($data['object']);
}

if (isset($data['created'])) {
$response->setCreated($data['created']);
}

if (isset($data['model'])) {
$response->setModel($data['model']);
}


$message = $response->getChoices()->count() > 0 ? $response->getChoices()[$response->getChoices()->count(
) - 1] : new Message();
// Llama.cpp response
if (isset($data['content'])) {
$message->setRole('assistant');
if (isset($data['stream']) && $data['stream'] === true) {
$message->updateContent($data['content']);
$message->setChunk($data['content']);
} else {
$message->setContent($data['content']);
}
$response->addMessage($message);
}

// Mistral platform response
if (isset($data['choices'])) {
foreach ($data['choices'] as $choice) {
if (isset($choice['message']['role'])) {
$message->setRole($choice['message']['role']);
}

if (isset($choice['message']['content'])) {
$message->setContent($choice['message']['content']);
}

if (isset($choice['delta']['role'])) {
$message->setRole($choice['delta']['role']);
}

if (isset($choice['delta']['content'])) {
$message->updateContent($choice['delta']['content']);
$message->setChunk($choice['delta']['content']);
}

if (isset($choice['message']['tool_calls'])) {
$message->setToolCalls($choice['message']['tool_calls']);
}

if ($response->getChoices()->count() === 0) {
$response->addMessage($message);
}
}
}
if (isset($data['usage'])) {
$response->setUsage($data['usage']);
}

return $response;
}

public function getChoices(): ArrayObject
{
return $this->choices;
}

public function setChoices(ArrayObject $choices): self
{
$this->choices = $choices;
return $this;
}

public function addMessage(Message $message): self
{
$this->choices->append($message);

return $this;
}

public function getId(): string
{
Expand Down Expand Up @@ -76,117 +168,32 @@ public function setUsage(array $usage): self
return $this;
}


public function getChoices(): ArrayObject
public function getToolCalls(): ?array
{
return $this->choices;
return $this->choices->count() === 0 ? null : $this->getLastMessage()->getToolCalls();
}

public function setChoices(ArrayObject $choices): self
private function getLastMessage(): Message
{
$this->choices = $choices;
return $this;
return $this->choices[$this->choices->count() - 1];
}

public function addMessage(Message $message): self
public function getChunk(): ?string
{
$this->choices->append($message);

return $this;
return $this->choices->count() === 0 ? null : $this->getLastMessage()->getChunk();
}


public static function updateFromArray(self $response, array $data): self
public function getGuidedMessage(?bool $associative = null): null|object|array
{
if(isset($data['id'])) {
$response->setId($data['id']);
}

if(isset($data['object'])) {
$response->setObject($data['object']);
}

if(isset($data['created'])) {
$response->setCreated($data['created']);
if (is_string($this->getMessage()) && json_validate($this->getMessage())) {
return json_decode($this->getMessage($associative));
}

if(isset($data['model'])) {
$response->setModel($data['model']);
}


$message = $response->getChoices()->count() > 0 ? $response->getChoices()[$response->getChoices()->count() - 1] : new Message();
// Llama.cpp response
if(isset($data['content'])) {
$message->setRole('assistant');
if(isset($data['stream']) && $data['stream'] === true) {
$message->updateContent($data['content']);
$message->setChunk($data['content']);
} else {
$message->setContent($data['content']);
}
$response->addMessage($message);
}

// Mistral platform response
if(isset($data['choices'])) {
foreach ($data['choices'] as $choice) {

if(isset($choice['message']['role'])) {
$message->setRole($choice['message']['role']);
}

if(isset($choice['message']['content'])) {
$message->setContent($choice['message']['content']);
}

if(isset($choice['delta']['role'])) {
$message->setRole($choice['delta']['role']);
}

if(isset($choice['delta']['content'])) {
$message->updateContent($choice['delta']['content']);
$message->setChunk($choice['delta']['content']);
}

if(isset($choice['message']['tool_calls'])) {
$message->setToolCalls($choice['message']['tool_calls']);
}

if($response->getChoices()->count()===0) {
$response->addMessage($message);
}
}
}
if(isset($data['usage'])) {
$response->setUsage($data['usage']);
}

return $response;
}
public static function createFromArray(array $data): self
{
$response = new self();
return self::updateFromArray($response, $data);
return null;
}

public function getMessage(): ?string
{
return $this->choices->count() === 0 ? null : $this->getLastMessage()->getContent();
}

public function getToolCalls(): ?array
{
return $this->choices->count() === 0 ? null : $this->getLastMessage()->getToolCalls();
}

public function getChunk(): ?string
{
return $this->choices->count() === 0 ? null : $this->getLastMessage()->getChunk();
}

private function getLastMessage(): Message
{
return $this->choices[$this->choices->count() - 1];
return $this->choices->count() === 0 ? null : $this->getLastMessage()->getContent();
}
}

0 comments on commit 226ec91

Please sign in to comment.