Improve test coverage

This commit is contained in:
2026-03-05 18:07:33 +00:00
parent 1e0bb0be8c
commit ccb8267813
8 changed files with 7550 additions and 49 deletions

View File

@@ -132,48 +132,53 @@ func (p *InstrumentedProvider) GenerateStream(ctx context.Context, messages []ap
defer close(outChan)
defer close(outErrChan)
// Helper function to record final metrics
recordMetrics := func() {
duration := time.Since(start).Seconds()
status := "success"
if streamErr != nil {
status = "error"
if p.tracer != nil {
span := trace.SpanFromContext(ctx)
span.RecordError(streamErr)
span.SetStatus(codes.Error, streamErr.Error())
}
} else {
if p.tracer != nil {
span := trace.SpanFromContext(ctx)
span.SetAttributes(
attribute.Int64("provider.input_tokens", totalInputTokens),
attribute.Int64("provider.output_tokens", totalOutputTokens),
attribute.Int64("provider.chunk_count", chunkCount),
attribute.Float64("provider.ttfb_seconds", ttfb.Seconds()),
)
span.SetStatus(codes.Ok, "")
}
// Record token metrics
if p.registry != nil && (totalInputTokens > 0 || totalOutputTokens > 0) {
providerTokensTotal.WithLabelValues(p.base.Name(), req.Model, "input").Add(float64(totalInputTokens))
providerTokensTotal.WithLabelValues(p.base.Name(), req.Model, "output").Add(float64(totalOutputTokens))
}
}
// Record stream metrics
if p.registry != nil {
providerRequestsTotal.WithLabelValues(p.base.Name(), req.Model, "generate_stream", status).Inc()
providerStreamDuration.WithLabelValues(p.base.Name(), req.Model).Observe(duration)
providerStreamChunks.WithLabelValues(p.base.Name(), req.Model).Add(float64(chunkCount))
if ttfb > 0 {
providerStreamTTFB.WithLabelValues(p.base.Name(), req.Model).Observe(ttfb.Seconds())
}
}
}
for {
select {
case delta, ok := <-baseChan:
if !ok {
// Stream finished - record final metrics
duration := time.Since(start).Seconds()
status := "success"
if streamErr != nil {
status = "error"
if p.tracer != nil {
span := trace.SpanFromContext(ctx)
span.RecordError(streamErr)
span.SetStatus(codes.Error, streamErr.Error())
}
} else {
if p.tracer != nil {
span := trace.SpanFromContext(ctx)
span.SetAttributes(
attribute.Int64("provider.input_tokens", totalInputTokens),
attribute.Int64("provider.output_tokens", totalOutputTokens),
attribute.Int64("provider.chunk_count", chunkCount),
attribute.Float64("provider.ttfb_seconds", ttfb.Seconds()),
)
span.SetStatus(codes.Ok, "")
}
// Record token metrics
if p.registry != nil && (totalInputTokens > 0 || totalOutputTokens > 0) {
providerTokensTotal.WithLabelValues(p.base.Name(), req.Model, "input").Add(float64(totalInputTokens))
providerTokensTotal.WithLabelValues(p.base.Name(), req.Model, "output").Add(float64(totalOutputTokens))
}
}
// Record stream metrics
if p.registry != nil {
providerRequestsTotal.WithLabelValues(p.base.Name(), req.Model, "generate_stream", status).Inc()
providerStreamDuration.WithLabelValues(p.base.Name(), req.Model).Observe(duration)
providerStreamChunks.WithLabelValues(p.base.Name(), req.Model).Add(float64(chunkCount))
if ttfb > 0 {
providerStreamTTFB.WithLabelValues(p.base.Name(), req.Model).Observe(ttfb.Seconds())
}
}
recordMetrics()
return
}
@@ -198,8 +203,10 @@ func (p *InstrumentedProvider) GenerateStream(ctx context.Context, messages []ap
if ok && err != nil {
streamErr = err
outErrChan <- err
recordMetrics()
return
}
return
// If error channel closed without error, continue draining baseChan
}
}
}()