Skip to content

Commit

Permalink
Treat TensorFlow output as non-batched. (#5634)
Browse files Browse the repository at this point in the history
* Can now not treat output as batched.

* updated comments based on PR comments.

* Fixing saving/loading with new parameter.

* Updates based on PR comments

* Update src/Microsoft.ML.TensorFlow/TensorflowUtils.cs

Co-authored-by: Eric Erhardt <[email protected]>

* reverted accidental test changes

* fixes based on PR comments

Co-authored-by: Eric Erhardt <[email protected]>
  • Loading branch information
michaelgsharp and eerhardt authored Mar 3, 2021
1 parent f93fa09 commit 447ae1d
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 35 deletions.
9 changes: 6 additions & 3 deletions src/Microsoft.ML.TensorFlow/TensorFlowModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public sealed class TensorFlowModel : IDisposable
{
internal Session Session { get; }
internal string ModelPath { get; }
internal bool TreatOutputAsBatched { get; }

private readonly IHostEnvironment _env;

Expand All @@ -27,10 +28,12 @@ public sealed class TensorFlowModel : IDisposable
/// <param name="env">An <see cref="IHostEnvironment"/> object.</param>
/// <param name="session">TensorFlow session object.</param>
/// <param name="modelLocation">Location of the model from where <paramref name="session"/> was loaded.</param>
internal TensorFlowModel(IHostEnvironment env, Session session, string modelLocation)
/// <param name="treatOutputAsBatched">If the first dimension of the output is unknown, should it be treated as batched or not.</param>
internal TensorFlowModel(IHostEnvironment env, Session session, string modelLocation, bool treatOutputAsBatched = true)
{
Session = session;
ModelPath = modelLocation;
TreatOutputAsBatched = treatOutputAsBatched;
_env = env;
_disposed = false;
}
Expand All @@ -40,7 +43,7 @@ internal TensorFlowModel(IHostEnvironment env, Session session, string modelLoca
/// </summary>
public DataViewSchema GetModelSchema()
{
return TensorFlowUtils.GetModelSchema(_env, Session.graph);
return TensorFlowUtils.GetModelSchema(_env, Session.graph, TreatOutputAsBatched);
}

/// <summary>
Expand All @@ -49,7 +52,7 @@ public DataViewSchema GetModelSchema()
/// </summary>
public DataViewSchema GetInputSchema()
{
return TensorFlowUtils.GetModelSchema(_env, Session.graph, "Placeholder");
return TensorFlowUtils.GetModelSchema(_env, Session.graph, TreatOutputAsBatched, "Placeholder");
}

/// <summary>
Expand Down
26 changes: 26 additions & 0 deletions src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,31 @@ public static class TensorflowCatalog
/// </example>
public static TensorFlowModel LoadTensorFlowModel(this ModelOperationsCatalog catalog, string modelLocation)
=> TensorFlowUtils.LoadTensorFlowModel(CatalogUtils.GetEnvironment(catalog), modelLocation);

/// <summary>
/// Load TensorFlow model into memory. This is the convenience method that allows the model to be loaded once and subsequently use it for querying schema and creation of
/// <see cref="TensorFlowEstimator"/> using <see cref="TensorFlowModel.ScoreTensorFlowModel(string, string, bool)"/>.
/// usage of this API requires additional NuGet dependencies on TensorFlow redist, see linked document for more information.
/// <see cref="TensorFlowModel"/> also holds references to unmanaged resources that need to be freed either with an explicit
/// call to Dispose() or implicitly by declaring the variable with the "using" syntax/>
///
/// <format type="text/markdown">
/// <![CDATA[
/// [!include[io](~/../docs/samples/docs/api-reference/tensorflow-usage.md)]
/// ]]>
/// </format>
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="modelLocation">Location of the TensorFlow model.</param>
/// <param name="treatOutputAsBatched">If the first dimension of the output is unknown, should it be treated as batched or not.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[LoadTensorFlowModel](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/TextClassification.cs)]
/// ]]>
/// </format>
/// </example>
public static TensorFlowModel LoadTensorFlowModel(this ModelOperationsCatalog catalog, string modelLocation, bool treatOutputAsBatched)
=> TensorFlowUtils.LoadTensorFlowModel(CatalogUtils.GetEnvironment(catalog), modelLocation, treatOutputAsBatched);
}
}
Loading

0 comments on commit 447ae1d

Please sign in to comment.