Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ class PersistingQueryEventConsumer(
query: Query?,
clientQueryId: String,
message: String,
anonymousTypes: Set<Type>
anonymousTypes: Set<Type>,
username: String?
) {
handleEvent(
QueryStartEvent(
Expand All @@ -102,7 +103,8 @@ class PersistingQueryEventConsumer(
persistRemoteCallResponses = config.persistRemoteCallResponses,
persistRemoteCallMetadata = config.persistRemoteCallMetadata,
persistTraceEvents = config.persistTraceEvents,
persistErrors = config.persistErrors
persistErrors = config.persistErrors,
username = username
)
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import java.time.Instant
class QueryLifecycleEventObserver(
private val consumer: QueryEventConsumer,
private val activeQueryMonitor: ActiveQueryMonitor?,
private val username: String? = null
) {
companion object {
private val logger = KotlinLogging.logger {}
Expand Down Expand Up @@ -64,7 +65,8 @@ class QueryLifecycleEventObserver(
queryId = queryResult.queryId,
clientQueryId = queryResult.clientQueryId ?: queryResult.queryId,
timestamp = queryStartTime,
anonymousTypes = queryResult.anonymousTypes
anonymousTypes = queryResult.anonymousTypes,
username = username
)

return queryResult.copy(
Expand Down Expand Up @@ -149,7 +151,8 @@ class QueryLifecycleEventObserver(
queryId = queryResult.queryId,
clientQueryId = queryResult.clientQueryId ?: queryResult.queryId,
timestamp = queryStartTime,
anonymousTypes = queryResult.anonymousTypes
anonymousTypes = queryResult.anonymousTypes,
username = username
)

return queryResult.copy(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,8 @@ class QueryService(
query = null,
clientQueryId = clientQueryId ?: "",
message = "",
anonymousTypes = emptySet()
anonymousTypes = emptySet(),
username = vyneUser?.username
)
val failedSearchResponse = FailedSearchResponse(
message = e.message!!, // Message contains the error messages from the compiler
Expand All @@ -610,7 +611,7 @@ class QueryService(
} catch (e: Exception) {
FailedSearchResponse(e.message!!, null, queryId = queryId)
}
QueryLifecycleEventObserver(historyWriterEventConsumer, activeQueryMonitor)
QueryLifecycleEventObserver(historyWriterEventConsumer, activeQueryMonitor, username = vyneUser?.username)
.responseWithQueryHistoryListener(query, response) to queryOptions
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,20 @@ import com.orbitalhq.models.json.right
import com.orbitalhq.query.QueryEvent
import com.orbitalhq.query.QueryEventConsumer
import com.orbitalhq.query.QueryResult
import com.orbitalhq.query.QueryStartEvent
import com.orbitalhq.query.StreamingQueryCancelledEvent
import com.orbitalhq.query.TaxiQlQueryResultEvent
import com.orbitalhq.testVyne
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.test.runTest
import mu.KotlinLogging
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.Test
import reactor.core.publisher.Flux
import reactor.core.publisher.Sinks
Expand Down Expand Up @@ -51,6 +55,61 @@ class QueryLifecycleEventObserverTest {
operation streamClients():Stream<Client>
}
""".trimIndent()
@Test
fun `username is included in QueryStartEvent when observer is constructed with a username`() = runTest {
val capturedEvents = mutableListOf<QueryEvent>()
val capturingConsumer = object : QueryEventConsumer {
override fun handleEvent(event: QueryEvent) {
capturedEvents.add(event)
}
override fun recordResult(operation: OperationResult, queryId: String) {}
}

val (vyne, stub) = testVyne(taxiDef)
val typedInstance = TypedInstance.from(
vyne.type("Client"),
mapOf("clientId" to "123", "clientName" to "Marty"),
vyne.schema
)
stub.addResponseFlow("streamClients") { _, _ -> flowOf(typedInstance.right()) }

val queryResult = vyne.query("stream { Client }")

// captureQueryStart is called eagerly by responseWithQueryHistoryListener before the flow is consumed
QueryLifecycleEventObserver(capturingConsumer, null, username = "marty.mcfly")
.responseWithQueryHistoryListener("stream { Client }", queryResult)

val startEvent = capturedEvents.filterIsInstance<QueryStartEvent>().firstOrNull()
assertEquals("marty.mcfly", startEvent?.username)
}

@Test
fun `username is null in QueryStartEvent when observer has no authenticated user`() = runTest {
val capturedEvents = mutableListOf<QueryEvent>()
val capturingConsumer = object : QueryEventConsumer {
override fun handleEvent(event: QueryEvent) {
capturedEvents.add(event)
}
override fun recordResult(operation: OperationResult, queryId: String) {}
}

val (vyne, stub) = testVyne(taxiDef)
val typedInstance = TypedInstance.from(
vyne.type("Client"),
mapOf("clientId" to "123", "clientName" to "Marty"),
vyne.schema
)
stub.addResponseFlow("streamClients") { _, _ -> flowOf(typedInstance.right()) }

val queryResult = vyne.query("stream { Client }")

QueryLifecycleEventObserver(capturingConsumer, null, username = null)
.responseWithQueryHistoryListener("stream { Client }", queryResult)

val startEvent = capturedEvents.filterIsInstance<QueryStartEvent>().firstOrNull()
assertNull(startEvent?.username)
}

@OptIn(ExperimentalCoroutinesApi::class)
@Test
fun `When streaming query is cancelled StreamingQueryCancelledEvent event is published`() = runTest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ class ShutdownDecorator(private val delegate: QueryEventConsumer, val onShutdown
query: Query?,
clientQueryId: String,
message: String,
anonymousTypes: Set<Type>
) = delegate.captureQueryStart(queryId, timestamp, taxiQuery, query, clientQueryId, message, anonymousTypes)
anonymousTypes: Set<Type>,
username: String?
) = delegate.captureQueryStart(queryId, timestamp, taxiQuery, query, clientQueryId, message, anonymousTypes, username)

override fun recordResult(operation: OperationResult, queryId: String) = delegate.recordResult(operation, queryId)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ interface QueryEventConsumer : RemoteCallOperationResultHandler {
query: Query?,
clientQueryId: String,
message: String,
anonymousTypes: Set<Type>
anonymousTypes: Set<Type>,
username: String? = null
) {
handleEvent(
QueryStartEvent(
Expand All @@ -26,7 +27,8 @@ interface QueryEventConsumer : RemoteCallOperationResultHandler {
query = query,
clientQueryId = clientQueryId,
message = message,
anonymousTypes = anonymousTypes
anonymousTypes = anonymousTypes,
username = username
)
)
}
Expand Down Expand Up @@ -130,6 +132,7 @@ data class QueryStartEvent(
val persistRemoteCallResponses: Boolean? = null,
val persistRemoteCallMetadata: Boolean? = null,
val persistTraceEvents: Boolean? = null,
val persistErrors: Boolean? = null
val persistErrors: Boolean? = null,
val username: String? = null
) : QueryEvent(isTerminalEvent = false)

Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ data class QuerySummary(
@Column(name = "persist_trace_events")
val persistTraceEvents: Boolean? = null,
@Column(name = "persist_errors")
val persistErrors: Boolean? = null
val persistErrors: Boolean? = null,
@Column(name = "username")
val username: String? = null
) : VyneHistoryRecord() {
@Transient
var durationMs = endTime?.let { Duration.between(startTime, endTime).toMillis() }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ object QueryResultEventMapper {
persistRemoteCallResponses = event.persistRemoteCallResponses,
persistRemoteCallMetadata = event.persistRemoteCallMetadata,
persistTraceEvents = event.persistTraceEvents,
persistErrors = event.persistErrors
persistErrors = event.persistErrors,
username = event.username
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE query_summary ADD COLUMN username VARCHAR(255);
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ import com.orbitalhq.schemas.fqn
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import org.jose4j.jwk.RsaJwkGenerator
import org.jose4j.jws.AlgorithmIdentifiers
import org.jose4j.jws.JsonWebSignature
import org.jose4j.jwt.JwtClaims
import org.junit.Test
import org.junit.runner.RunWith
import org.springframework.beans.factory.annotation.Autowired
Expand All @@ -34,6 +38,9 @@ import org.springframework.test.context.bean.override.mockito.MockitoBean
import org.springframework.boot.testcontainers.service.connection.ServiceConnection
import org.springframework.context.annotation.Import
import org.springframework.http.MediaType
import org.springframework.security.core.Authentication
import org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder
import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken
import org.springframework.test.context.ActiveProfiles
import org.springframework.test.context.junit4.SpringRunner
import org.testcontainers.containers.PostgreSQLContainer
Expand Down Expand Up @@ -110,6 +117,84 @@ class QuerySummaryOnlyPersistenceTest : BaseQueryServiceTest() {
lateinit var schemaEditorService: SchemaEditorService


private fun authenticationWithPreferredUsername(preferredUsername: String): Authentication {
val rsaJsonWebKey = RsaJwkGenerator.generateJwk(2048)
rsaJsonWebKey.apply {
keyId = UUID.randomUUID().toString()
algorithm = AlgorithmIdentifiers.RSA_USING_SHA256
use = "sig"
}
val claims = JwtClaims().apply {
jwtId = UUID.randomUUID().toString()
issuer = "https://test.example.com"
subject = UUID.randomUUID().toString()
setExpirationTimeMinutesInTheFuture(10F)
setIssuedAtToNow()
setClaim("preferred_username", preferredUsername)
}
val jwt = JsonWebSignature().apply {
payload = claims.toJson()
key = rsaJsonWebKey.privateKey
algorithmHeaderValue = rsaJsonWebKey.algorithm
keyIdHeaderValue = rsaJsonWebKey.keyId
setHeader("typ", "JWT")
}.compactSerialization
return JwtAuthenticationToken(
NimbusReactiveJwtDecoder.withPublicKey(rsaJsonWebKey.getRsaPublicKey()).build().decode(jwt).block()
)
}

@Test
fun `username is persisted in query summary when query is submitted with authentication`() {
setupTestService(historyDbWriter)
val id = UUID.randomUUID().toString()
val auth = authenticationWithPreferredUsername("marty.mcfly")

runTest {
val turbine = queryService.submitVyneQlQueryStreamingResponse(
"find { Order[] } as Report[]",
auth = auth,
clientQueryId = id
).testIn(this)

val first = turbine.awaitItem()
first.should.not.be.`null`
turbine.awaitComplete()
}

Awaitility.await().atMost(com.jayway.awaitility.Duration.TEN_SECONDS).until {
queryHistoryRecordRepository.findByClientQueryId(id)?.endTime != null
}

val historyRecord = queryHistoryRecordRepository.findByClientQueryId(id)!!
historyRecord.username.should.equal("marty.mcfly")
}

@Test
fun `username is null in query summary when no authentication is provided`() {
setupTestService(historyDbWriter)
val id = UUID.randomUUID().toString()

runTest {
val turbine = queryService.submitVyneQlQueryStreamingResponse(
"find { Order[] } as Report[]",
auth = null,
clientQueryId = id
).testIn(this)

val first = turbine.awaitItem()
first.should.not.be.`null`
turbine.awaitComplete()
}

Awaitility.await().atMost(com.jayway.awaitility.Duration.TEN_SECONDS).until {
queryHistoryRecordRepository.findByClientQueryId(id)?.endTime != null
}

val historyRecord = queryHistoryRecordRepository.findByClientQueryId(id)!!
historyRecord.username.should.be.`null`
}

@Test
fun `Only Query Summary is persisted when vyne history persistResults is false for a taxiQl query`() {
setupTestService(historyDbWriter)
Expand Down