Skip to content

Commit

Permalink
Merge pull request #24 from bonsai-rx/python-string-formatter
Browse files Browse the repository at this point in the history
Bug fixes in python string formatting
  • Loading branch information
ncguilbeault authored Sep 20, 2024
2 parents cdcc7c2 + 0b999d8 commit 3a3bc3b
Show file tree
Hide file tree
Showing 8 changed files with 266 additions and 165 deletions.
182 changes: 113 additions & 69 deletions src/Bonsai.ML.Data/ArrayHelper.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Text;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Newtonsoft.Json;
Expand All @@ -13,92 +14,93 @@ namespace Bonsai.ML.Data
public static class ArrayHelper
{
/// <summary>
/// Serializes the input data into a JSON string representation.
/// Parses the input string into an object of the specified type.
/// If the input is a JSON array, the method will attempt to parse it into a list or array of the specified type.
/// </summary>
/// <param name="data">The data to serialize.</param>
/// <returns>A JSON string representation of the input data.</returns>
public static string SerializeToJson(object data)
/// <param name="input">The string to parse.</param>
/// <param name="dtype">The data type of the object.</param>
/// <returns>An object of the specified type containing the parsed data.</returns>
public static object ParseString(string input, Type dtype)
{
if (data is Array array)
if (!IsValidJson(input))
{
return SerializeArrayToJson(array);
throw new ArgumentException($"Parameter: {nameof(input)} is not valid JSON.");
}
else

var token = JsonConvert.DeserializeObject<JToken>(input);

if (token is JValue value)
{
return JsonConvert.SerializeObject(data);
return Convert.ChangeType(value, dtype);
}
}

/// <summary>
/// Serializes the input array into a JSON string representation.
/// </summary>
/// <param name="array">The array to serialize.</param>
/// <returns>A JSON string representation of the input array.</returns>
public static string SerializeArrayToJson(Array array)
{
StringBuilder sb = new StringBuilder();
SerializeArrayRecursive(array, sb, [0]);
return sb.ToString();
}
var output = ParseToken(token, dtype);

private static void SerializeArrayRecursive(Array array, StringBuilder sb, int[] indices)
{
if (indices.Length < array.Rank)
{
sb.Append("[");
int length = array.GetLength(indices.Length);
for (int i = 0; i < length; i++)
{
int[] newIndices = new int[indices.Length + 1];
indices.CopyTo(newIndices, 0);
newIndices[indices.Length] = i;
SerializeArrayRecursive(array, sb, newIndices);
if (i < length - 1)
{
sb.Append(", ");
}
}
sb.Append("]");
}
else
{
object value = array.GetValue(indices);
sb.Append(value.ToString());
}
return output;
}

private static bool IsValidJson(string input)
{
int squareBrackets = 0;
foreach (char c in input)
try
{
JToken.Parse(input);
return true;
}
catch
{
if (c == '[') squareBrackets++;
else if (c == ']') squareBrackets--;
return false;
}
return squareBrackets == 0;
}

/// <summary>
/// Parses the input JSON string into an object of the specified type. If the input is a JSON array, the method will attempt to parse it into an array of the specified type.
/// Parses the input token into an object of the specified type.
/// If the input is a JSON array, the method will attempt to parse it into a list or array of the specified type.
/// </summary>
/// <param name="input">The JSON string to parse.</param>
/// <param name="token">The token to parse.</param>
/// <param name="dtype">The data type of the object.</param>
/// <returns>An object of the specified type containing the parsed JSON data.</returns>
public static object ParseString(string input, Type dtype = null)
/// <returns>An object of the specified type containing the parsed data.</returns>
public static object ParseToken(JToken token, Type dtype)
{
if (!IsValidJson(input))
if (token is JValue value)
{
throw new ArgumentException($"Parameter: {nameof(input)} is not valid JSON.");
return Convert.ChangeType(value, dtype);
}
var obj = JsonConvert.DeserializeObject<JToken>(input);
int depth = ParseDepth(obj);
if (depth == 0)
else if (token is JArray)
{
return Convert.ChangeType(input, dtype);
if (token[0] is JValue)
{
if (token.All(item => item is JValue))
{
int depth = ParseDepth(token);
return ParseArray(token, dtype, depth);
}
return CreateList(token, dtype);
}
else
{
var subArrayDimensions = token.Cast<JArray>().Select(value => {
var depth = ParseDepth(value);
return ParseDimensions(value, depth);
}).ToList();

if (subArrayDimensions.All(s => s.SequenceEqual(subArrayDimensions[0])))
{
return ParseArray(token, dtype, subArrayDimensions[0].Count());
}
return CreateList(token, dtype);
}
}
else
{
throw new ArgumentException($"Error parsing parameter: {nameof(token)}. JSON input is not supported.");
}
int[] dimensions = ParseDimensions(obj, depth);
}

private static object ParseArray(JToken token, Type dtype, int depth)
{
int[] dimensions = ParseDimensions(token, depth);
var resultArray = Array.CreateInstance(dtype, dimensions);
PopulateArray(obj, resultArray, [0], dtype);
PopulateArray(token, resultArray, [], dtype);
return resultArray;
}

Expand All @@ -113,34 +115,32 @@ private static int ParseDepth(JToken token, int currentDepth = 0)

private static int[] ParseDimensions(JToken token, int depth, int currentLevel = 0)
{
if (depth == 0 || !(token is JArray))
if (depth == 0 || token is not JArray)
{
return [0];
}

List<int> dimensions = new List<int>();
JToken current = token;
List<int> dimensions = [];
var current = token;

while (current != null && current is JArray)
while (current != null && current is JArray currentArray)
{
JArray currentArray = current as JArray;
dimensions.Add(currentArray.Count);
if (currentArray.Count > 0)
{
if (currentArray.Any(item => !(item is JArray)) && currentArray.Any(item => item is JArray) || currentArray.All(item => item is JArray) && currentArray.Any(item => ((JArray)item).Count != ((JArray)currentArray.First()).Count))
if (currentArray.Any(item => item is not JArray) && currentArray.Any(item => item is JArray) || currentArray.All(item => item is JArray) && currentArray.Any(item => ((JArray)item).Count != ((JArray)currentArray.First()).Count))
{
throw new ArgumentException($"Error parsing parameter: {nameof(token)}. Array dimensions are inconsistent.");
}

if (!(currentArray.First() is JArray))
if (currentArray.First() is not JArray)
{
if (!currentArray.All(item => double.TryParse(item.ToString(), out _)) && !currentArray.All(item => bool.TryParse(item.ToString(), out _)))
{
throw new ArgumentException($"Error parsing parameter: {nameof(token)}. All values in the array must be of the same type. Only numeric or boolean types are supported.");
}
}
}

current = currentArray.Count > 0 ? currentArray[0] : null;
}

Expand Down Expand Up @@ -180,5 +180,49 @@ private static void PopulateArray(JToken token, Array array, int[] indices, Type
array.SetValue(values, indices);
}
}

/// <summary>
/// Parses every child of a JSON array token and collects the results into a generic list.
/// </summary>
/// <remarks>
/// The element type is first predicted from the token structure by <c>DetermineListType</c>.
/// Because each child is parsed independently by <c>ParseToken</c>, a child may come back as
/// something other than the predicted type (for example an array of a different rank, or a
/// nested list). In that case the list falls back to <c>List&lt;object&gt;</c> instead of
/// failing when the item is added.
/// </remarks>
/// <param name="token">The JSON array token whose children are parsed.</param>
/// <param name="dtype">The data type used for scalar values.</param>
/// <returns>A <c>List&lt;T&gt;</c> instance holding the parsed children, in order.</returns>
private static object CreateList(JToken token, Type dtype)
{
    var elementType = DetermineListType(token, dtype);

    // Parse before creating the list so the predicted element type can be
    // checked against what the parser actually produced.
    var items = token.Select(item => ParseToken(item, dtype)).ToList();

    // Adding an incompatible value (or null to a value-type list) through the
    // non-generic IList interface throws ArgumentException, so widen to object
    // whenever any parsed item does not fit the predicted element type.
    if (!items.All(item => elementType.IsInstanceOfType(item)))
    {
        elementType = typeof(object);
    }

    var listType = typeof(List<>).MakeGenericType(elementType);
    var list = (IList)Activator.CreateInstance(listType);

    foreach (var item in items)
    {
        list.Add(item);
    }

    return list;
}

/// <summary>
/// Chooses the element type of a list built from the children of the given token.
/// </summary>
/// <param name="token">The token whose children will become the list items.</param>
/// <param name="type">The data type used for scalar values.</param>
/// <returns>
/// <paramref name="type"/> when every child is a value; an array type of <paramref name="type"/>
/// when every child is an array of the same depth; a list type derived from the first child when
/// the children are arrays of differing depth; otherwise <see cref="object"/>.
/// </returns>
private static Type DetermineListType(JToken token, Type type)
{
    // Only scalar children: the list holds the scalar type directly.
    if (token.All(child => child is JValue))
    {
        return type;
    }

    // Anything other than "all arrays" at this point has no common element type.
    if (!token.All(child => child is JArray))
    {
        return typeof(object);
    }

    var depths = token.Cast<JArray>().Select(child => ParseDepth(child)).ToList();
    var firstDepth = depths[0];

    // Children nest to different depths: derive the element type from the first child.
    if (depths.Any(depth => depth != firstDepth))
    {
        return typeof(List<>).MakeGenericType(DetermineListType(token[0], type));
    }

    // MakeArrayType(1) yields a rank-1 multidimensional array type rather than the
    // ordinary vector type, so the parameterless overload is used for depth 1.
    return firstDepth > 1 ? type.MakeArrayType(firstDepth) : type.MakeArrayType();
}
}
}
33 changes: 9 additions & 24 deletions src/Bonsai.ML.HiddenMarkovModels/PythonModel.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Text;
using Bonsai.ML.Data;
using Bonsai.ML.Python;
using System.Xml.Serialization;
using System.Linq;
using System.Reactive.Linq;
Expand Down Expand Up @@ -109,38 +111,21 @@ protected override string BuildString()
// StringBuilder.Clear();
StringBuilder.Append($"{ModelName}_model_type=\"{ModelType}\"");

if (Params != null && Params.Length > 0 && Params.All(p => p != null))
if (Params is not null && Params.Length > 0 && Params.All(param => param is not null))
{
var paramsStringBuilder = new StringBuilder();
paramsStringBuilder.Append($",{ModelName}_params=(");
StringBuilder.Append($",{ModelName}_params=(");

foreach (var param in Params) {
if (param is null) {
paramsStringBuilder.Clear();
break;
}
var arrString = param is Array array ? ArrayHelper.SerializeArrayToJson(array) : param.ToString();
paramsStringBuilder.Append($"{arrString},");
}

if (paramsStringBuilder.Length > 0) {
paramsStringBuilder.Remove(paramsStringBuilder.Length - 1, 1);
paramsStringBuilder.Append(")");
StringBuilder.Append(paramsStringBuilder);
StringBuilder.Append(StringFormatter.FormatToPython(param));
StringBuilder.Append(",");
}
StringBuilder.Append(")");
}

if (Kwargs is not null && Kwargs.Count > 0)
{
StringBuilder.Append($",{ModelName}_kwargs={{");
foreach (var kp in Kwargs) {
StringBuilder.Append($"\"{kp.Key}\":{(kp.Value is null ? "None"
: kp.Value is Array array ? ArrayHelper.SerializeArrayToJson(array)
: kp.Value is string ? $"\"{kp.Value}\""
: kp.Value)},");
}
StringBuilder.Remove(StringBuilder.Length - 1, 1);
StringBuilder.Append("}");
StringBuilder.Append($",{ModelName}_kwargs=");
StringBuilder.Append(StringFormatter.FormatToPython(Kwargs));
}

var result = StringBuilder.ToString();
Expand Down
2 changes: 1 addition & 1 deletion src/Bonsai.ML.HiddenMarkovModels/StateParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ protected override string BuildString()

if (InitialStateDistribution != null)
{
StringBuilder.Append($"initial_state_distribution={ArrayHelper.SerializeToJson(InitialStateDistribution)},");
StringBuilder.Append($"initial_state_distribution={StringFormatter.FormatToPython(InitialStateDistribution)},");
}

if (Transitions != null)
Expand Down
18 changes: 2 additions & 16 deletions src/Bonsai.ML.HiddenMarkovModels/StateParametersJsonConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,7 @@ public override StateParameters ReadJson(JsonReader reader, Type objectType, Sta
transitionsParamsArray = new object[nParams];
for (int i = 0; i < nParams; i++)
{
try
{
transitionsParamsArray[i] = ArrayHelper.ParseString(paramsJArray[i].ToString(), typeof(double));
}
catch
{
transitionsParamsArray[i] = JsonConvert.DeserializeObject(paramsJArray[i].ToString());
}
transitionsParamsArray[i] = ArrayHelper.ParseToken(paramsJArray[i], typeof(double));
}
}

Expand Down Expand Up @@ -90,14 +83,7 @@ public override StateParameters ReadJson(JsonReader reader, Type objectType, Sta
observationsParamsArray = new object[nParams];
for (int i = 0; i < nParams; i++)
{
try
{
observationsParamsArray[i] = ArrayHelper.ParseString(paramsJArray[i].ToString(), typeof(double));
}
catch
{
observationsParamsArray[i] = JsonConvert.DeserializeObject(paramsJArray[i].ToString());
}
observationsParamsArray[i] = ArrayHelper.ParseToken(paramsJArray[i], typeof(double));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace Bonsai.ML.HiddenMarkovModels.Transitions
[JsonObject(MemberSerialization.OptIn)]
public class ConstrainedStationaryTransitions : TransitionsModel
{
private int[,] transitionMask = null;
private int[,] transitionMask = new int[,] { { 1, 1 }, { 1, 1 } };

/// <summary>
/// The mask which gets applied to the transition matrix to prohibit certain transitions.
Expand All @@ -30,7 +30,7 @@ public class ConstrainedStationaryTransitions : TransitionsModel
[Description("The mask which gets applied to the transition matrix to prohibit certain transitions. It must be written in JSON format as an int[,] with the same shape as the transition matrix (nStates x nStates). For example, the mask [[1, 0], [1, 1]] is valid and would prohibit transitions from state 0 to state 1.")]
public string TransitionMask
{
get => transitionMask != null ? ArrayHelper.SerializeToJson(transitionMask) : "";
get => transitionMask != null ? StringFormatter.FormatToPython(transitionMask) : "[[1, 1], [1, 1]]";
set => transitionMask = (int[,])ArrayHelper.ParseString(value, typeof(int));
}

Expand Down Expand Up @@ -94,7 +94,7 @@ protected override void UpdateKwargs(params object[] kwargs)
int[,] mask => mask,
long[,] mask => ConvertLongArrayToIntArray(mask),
bool[,] mask => ConvertBoolArrayToIntArray(mask),
_ => null
var mask => (int[,])ArrayHelper.ParseString(mask.ToString(), typeof(int))
};
}

Expand Down
Loading

0 comments on commit 3a3bc3b

Please sign in to comment.