From f7d61a5faeb1ec4200b30985cce76b9dd3c30d83 Mon Sep 17 00:00:00 2001 From: Francisco Castillo Date: Mon, 28 Nov 2022 17:01:34 +0100 Subject: [PATCH 1/4] REFACTOR - flags --- cmd/config.go | 9 +-- cmd/get.go | 10 +-- cmd/transcribe.go | 202 +++++++++++++++++++++------------------------- schemas/types.go | 9 ++- 4 files changed, 102 insertions(+), 128 deletions(-) diff --git a/cmd/config.go b/cmd/config.go index 76f0b4f..1218b87 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -16,17 +16,14 @@ var configCmd = &cobra.Command{ Short: "Authenticate the CLI", Long: `This command will validate your account and store your token safely, later to be used when transcribing files.`, Run: func(cmd *cobra.Command, args []string) { - - argsArray := cmd.Flags().Args() - - if len(argsArray) == 0 { + if len(args) == 0 { fmt.Println("Please provide a token. If you don't have one, create an account at https://app.assemblyai.com") return - } else if len(argsArray) > 1 { + } else if len(args) > 1 { fmt.Println("Too many arguments. Please provide a single token.") return } - U.Token = argsArray[0] + U.Token = args[0] checkToken := U.CheckIfTokenValid() if !checkToken { diff --git a/cmd/get.go b/cmd/get.go index 0dbbdae..9769f68 100644 --- a/cmd/get.go +++ b/cmd/get.go @@ -11,15 +11,12 @@ import ( "github.com/spf13/cobra" ) -// get represents the getTranscription command var getCmd = &cobra.Command{ Use: "get [transcription_id]", Short: "Get a transcription", Long: `After submitting a file for transcription, you can fetch it by passing its ID.`, Args: cobra.MinimumNArgs(1), Run: func(cmd *cobra.Command, args []string) { - var flags S.TranscribeFlags - args = cmd.Flags().Args() if len(args) == 0 { printErrorProps := S.PrintErrorProps{ Error: errors.New("No transcription ID provided."), @@ -29,8 +26,6 @@ var getCmd = &cobra.Command{ return } id := args[0] - flags.Poll, _ = cmd.Flags().GetBool("poll") - flags.Json, _ = cmd.Flags().GetBool("json") U.Token = U.GetStoredToken() if U.Token == "" { @@ -48,6 +43,7 @@ var getCmd = &cobra.Command{ func init() { rootCmd.AddCommand(getCmd) - getCmd.Flags().BoolP("json", "j", false, "If true, the CLI will output the JSON.") - getCmd.Flags().BoolP("poll", "p", true, "The CLI will poll the transcription until it's complete.") + getCmd.PersistentFlags().BoolVarP(&flags.Poll, "poll", "p", true, "The CLI will poll the transcription until it's complete.") + getCmd.PersistentFlags().BoolVarP(&flags.Json, "json", "j", false, "If true, the CLI will output the JSON.") + getCmd.PersistentFlags().StringVar(&flags.Csv, "csv", "", "Specify the filename to save the transcript result onto a .CSV file extension") } diff --git a/cmd/transcribe.go b/cmd/transcribe.go index 9304f1c..df7c237 100644 --- a/cmd/transcribe.go +++ b/cmd/transcribe.go @@ -9,13 +9,15 @@ import ( "fmt" "io/ioutil" "os" - "strings" S "github.com/AssemblyAI/assemblyai-cli/schemas" U "github.com/AssemblyAI/assemblyai-cli/utils" "github.com/spf13/cobra" ) +var flags S.TranscribeFlags +var params S.TranscribeParams + var transcribeCmd = &cobra.Command{ Use: "transcribe ", Short: "Transcribe and understand audio with a single AI-powered API", @@ -24,10 +26,6 @@ var transcribeCmd = &cobra.Command{ Powered by cutting-edge AI models.`, Args: cobra.MinimumNArgs(1), Run: func(cmd *cobra.Command, args []string) { - var params S.TranscribeParams - var flags S.TranscribeFlags - - args = cmd.Flags().Args() if len(args) == 0 { printErrorProps := S.PrintErrorProps{ Error: errors.New("Please provide a URL, path, or YouTube URL"), @@ -38,39 +36,28 @@ var transcribeCmd = &cobra.Command{ } params.AudioURL = args[0] - flags.Json, _ = cmd.Flags().GetBool("json") - flags.Poll, _ = cmd.Flags().GetBool("poll") - params.AutoChapters, _ = cmd.Flags().GetBool("auto_chapters") - params.AutoHighlights, _ = cmd.Flags().GetBool("auto_highlights") - params.ContentModeration, _ = cmd.Flags().GetBool("content_moderation") - params.DualChannel, _ = cmd.Flags().GetBool("dual_channel") - params.EntityDetection, _ = cmd.Flags().GetBool("entity_detection") - params.FormatText, _ = cmd.Flags().GetBool("format_text") - params.Punctuate, _ = cmd.Flags().GetBool("punctuate") - params.RedactPii, _ = cmd.Flags().GetBool("redact_pii") - params.SentimentAnalysis, _ = cmd.Flags().GetBool("sentiment_analysis") - params.SpeakerLabels, _ = cmd.Flags().GetBool("speaker_labels") - params.TopicDetection, _ = cmd.Flags().GetBool("topic_detection") - params.Summarization, _ = cmd.Flags().GetBool("summarization") - wordBoost, _ := cmd.Flags().GetString("word_boost") - if wordBoost != "" { - params.WordBoost = strings.Split(wordBoost, ",") - boostParam, _ := cmd.Flags().GetString("boost_param") - if boostParam != "" && boostParam != "low" && boostParam != "default" && boostParam != "high" { - printErrorProps := S.PrintErrorProps{ - Error: errors.New("Invalid boost_param"), - Message: "Please provide a valid boost_param. Valid values are low, default, or high.", - } - U.PrintError(printErrorProps) - return + if params.WordBoost == nil && params.BoostParam != "" { + printErrorProps := S.PrintErrorProps{ + Error: errors.New("Please provide a valid word boost"), + Message: "To boost a word, please provide a valid list of words to boost. For example: --word_boost \"word1,word2,word3\" --boost_param high", } - params.BoostParam = &boostParam + U.PrintError(printErrorProps) + return + } else if params.BoostParam != "" && params.BoostParam != "low" && params.BoostParam != "default" && params.BoostParam != "high" { + printErrorProps := S.PrintErrorProps{ + Error: errors.New("Invalid boost_param"), + Message: "Please provide a valid boost_param. Valid values are low, default, or high.", + } + U.PrintError(printErrorProps) + return } - if params.Summarization { + + if !params.Summarization { + params.SummaryType = "" + params.SummaryModel = "" + } else { params.Punctuate = true params.FormatText = true - - params.SummaryType, _ = cmd.Flags().GetString("summary_type") if _, ok := S.SummarizationTypeMapReverse[params.SummaryType]; !ok { printErrorProps := S.PrintErrorProps{ Error: errors.New("Invalid summary type"), @@ -79,41 +66,35 @@ var transcribeCmd = &cobra.Command{ U.PrintError(printErrorProps) return } - summaryModel, _ := cmd.Flags().GetString("summary_model") - if summaryModel != "" { - if _, ok := S.SummarizationModelMap[summaryModel]; !ok { - printErrorProps := S.PrintErrorProps{ - Error: errors.New("Invalid summary model"), - Message: "Invalid summary model. To know more about Summarization, head over to https://assemblyai.com/docs/audio-intelligence#summarization", - } - U.PrintError(printErrorProps) - return + if _, ok := S.SummarizationModelMap[params.SummaryModel]; !ok { + printErrorProps := S.PrintErrorProps{ + Error: errors.New("Invalid summary model"), + Message: "Invalid summary model. To know more about Summarization, head over to https://assemblyai.com/docs/audio-intelligence#summarization", } - if !U.Contains(S.SummarizationModelMap[summaryModel], params.SummaryType) { - printErrorProps := S.PrintErrorProps{ - Error: errors.New("Invalid summary model"), - Message: "Cant use summary model " + summaryModel + " with summary type " + params.SummaryType + ". To know more about Summarization, head over to https://assemblyai.com/docs/audio-intelligence#summarization", - } - U.PrintError(printErrorProps) - return + U.PrintError(printErrorProps) + return + } + if !U.Contains(S.SummarizationModelMap[params.SummaryModel], params.SummaryType) { + printErrorProps := S.PrintErrorProps{ + Error: errors.New("Invalid summary model"), + Message: "Cant use summary model " + params.SummaryModel + " with summary type " + params.SummaryType + ". To know more about Summarization, head over to https://assemblyai.com/docs/audio-intelligence#summarization", } - if summaryModel == "conversational" && !params.SpeakerLabels { - printErrorProps := S.PrintErrorProps{ - Error: errors.New("Speaker labels required for conversational summary model"), - Message: "Speaker labels are required for conversational summarization. To know more about Summarization, head over to https://assemblyai.com/docs/audio-intelligence#summarization", - } - U.PrintError(printErrorProps) - return + U.PrintError(printErrorProps) + return + } + if params.SummaryModel == "conversational" && !params.SpeakerLabels { + printErrorProps := S.PrintErrorProps{ + Error: errors.New("Speaker labels required for conversational summary model"), + Message: "Speaker labels are required for conversational summarization. To know more about Summarization, head over to https://assemblyai.com/docs/audio-intelligence#summarization", } - params.SummaryModel = summaryModel + U.PrintError(printErrorProps) + return } } - - if params.RedactPii { - policies, _ := cmd.Flags().GetString("redact_pii_policies") - policiesArray := strings.Split(policies, ",") - - for _, policy := range policiesArray { + if !params.RedactPii { + params.RedactPiiPolicies = nil + } else { + for _, policy := range params.RedactPiiPolicies { if _, ok := S.PIIRedactionPolicyMap[policy]; !ok { printErrorProps := S.PrintErrorProps{ Error: errors.New("Invalid redaction policy"), @@ -123,24 +104,17 @@ var transcribeCmd = &cobra.Command{ return } } - - params.RedactPiiPolicies = policiesArray } - webhook := cmd.Flags().Lookup("webhook_url").Value.String() - if webhook != "" { - params.WebhookURL = webhook - webhookHeaderName := cmd.Flags().Lookup("webhook_auth_header_name").Value.String() - webhookHeaderValue := cmd.Flags().Lookup("webhook_auth_header_value").Value.String() - if webhookHeaderName != "" { - params.WebhookAuthHeaderName = webhookHeaderName - } - if webhookHeaderValue != "" { - params.WebhookAuthHeaderValue = webhookHeaderValue + + if params.LanguageDetection && params.LanguageCode != "" { + printErrorProps := S.PrintErrorProps{ + Error: errors.New("Language detection and language code cannot be used together"), + Message: "Language detection and language code cannot be used together.", } + U.PrintError(printErrorProps) + return } - languageDetection, _ := cmd.Flags().GetBool("language_detection") - languageCode, _ := cmd.Flags().GetString("language_code") - if (languageCode != "" || languageDetection) && params.SpeakerLabels { + if (params.LanguageCode != "" || params.LanguageDetection) && params.SpeakerLabels { if cmd.Flags().Lookup("speaker_labels").Changed { printErrorProps := S.PrintErrorProps{ Error: errors.New("Speaker labels are not supported for languages other than English"), @@ -152,11 +126,8 @@ var transcribeCmd = &cobra.Command{ params.SpeakerLabels = false } } - if languageDetection && languageCode == "" { - params.LanguageDetection = true - } - if languageCode != "" { - if _, ok := S.LanguageMap[languageCode]; !ok { + if params.LanguageCode != "" { + if _, ok := S.LanguageMap[params.LanguageCode]; !ok { printErrorProps := S.PrintErrorProps{ Error: errors.New("Invalid language code"), Message: "Invalid language code. See https://www.assemblyai.com/docs#supported-languages for supported languages.", @@ -164,8 +135,6 @@ var transcribeCmd = &cobra.Command{ U.PrintError(printErrorProps) return } - params.LanguageCode = &languageCode - params.LanguageDetection = false } customSpelling, _ := cmd.Flags().GetString("custom_spelling") @@ -228,36 +197,47 @@ var transcribeCmd = &cobra.Command{ params.CustomSpelling = parsedCustomSpelling } + if flags.Csv != "" && !flags.Poll { + printErrorProps := S.PrintErrorProps{ + Error: errors.New("CSV output is only supported with polling"), + Message: "CSV output is only supported with polling.", + } + U.PrintError(printErrorProps) + return + } + U.Transcribe(params, flags) }, } func init() { - transcribeCmd.PersistentFlags().BoolP("auto_chapters", "s", false, "A \"summary over time\" for the audio file transcribed.") - transcribeCmd.PersistentFlags().BoolP("auto_highlights", "a", false, "Automatically detect important phrases and words in the text.") - transcribeCmd.PersistentFlags().BoolP("content_moderation", "c", false, "Detect if sensitive content is spoken in the file.") - transcribeCmd.PersistentFlags().BoolP("dual_channel", "d", false, "Enable dual channel") - transcribeCmd.PersistentFlags().BoolP("entity_detection", "e", false, "Identify a wide range of entities that are spoken in the audio file.") - transcribeCmd.PersistentFlags().BoolP("format_text", "f", true, "Enable text formatting") - transcribeCmd.PersistentFlags().BoolP("json", "j", false, "If true, the CLI will output the JSON.") - transcribeCmd.PersistentFlags().BoolP("language_detection", "n", false, "Identify the dominant language that’s spoken in an audio file.") - transcribeCmd.PersistentFlags().BoolP("poll", "p", true, "The CLI will poll the transcription until it's complete.") - transcribeCmd.PersistentFlags().BoolP("punctuate", "u", true, "Enable automatic punctuation.") - transcribeCmd.PersistentFlags().BoolP("redact_pii", "r", false, "Remove personally identifiable information from the transcription.") - transcribeCmd.PersistentFlags().BoolP("sentiment_analysis", "x", false, "Detect the sentiment of each sentence of speech spoken in the file.") - transcribeCmd.PersistentFlags().BoolP("speaker_labels", "l", true, "Automatically detect the number of speakers in your audio file, and each word in the transcription text can be associated with its speaker.") - transcribeCmd.PersistentFlags().BoolP("summarization", "m", false, "Generate a single abstractive summary of the entire audio.") - transcribeCmd.PersistentFlags().BoolP("topic_detection", "t", false, "Label the topics that are spoken in the file.") - transcribeCmd.PersistentFlags().StringP("boost_param", "z", "", "Control how much weight should be applied to your boosted keywords/phrases. This value can be either low, default, or high.") - transcribeCmd.PersistentFlags().StringP("custom_spelling", "", "", "Specify how words are spelled or formatted in the transcript text.") - transcribeCmd.PersistentFlags().StringP("language_code", "g", "", "Specify the language of the speech in your audio file.") - transcribeCmd.PersistentFlags().StringP("redact_pii_policies", "i", "drug,number_sequence,person_name", "The list of PII policies to redact, comma-separated without space in-between. Required if the redact_pii flag is true.") - transcribeCmd.PersistentFlags().StringP("summary_type", "y", "bullets", "Type of summary generated.") - transcribeCmd.PersistentFlags().StringP("webhook_auth_header_name", "b", "", "Containing the header's name which will be inserted into the webhook request") - transcribeCmd.PersistentFlags().StringP("webhook_auth_header_value", "o", "", "The value of the header that will be inserted into the webhook request.") - transcribeCmd.PersistentFlags().StringP("webhook_url", "w", "", "Receive a webhook once your transcript is complete.") - transcribeCmd.PersistentFlags().StringP("word_boost", "k", "", "The value of this flag MUST be used surrounded by quotes. Any term included will have its likelihood of being transcribed boosted.") - transcribeCmd.PersistentFlags().StringP("summary_model", "q", "informative", "The model used to generate the summary.") + transcribeCmd.PersistentFlags().BoolVarP(&flags.Poll, "poll", "p", true, "The CLI will poll the transcription until it's complete.") + transcribeCmd.PersistentFlags().BoolVarP(&flags.Json, "json", "j", false, "If true, the CLI will output the JSON.") + transcribeCmd.PersistentFlags().StringVar(&flags.Csv, "csv", "", "Specify the filename to save the transcript result onto a .CSV file extension") + transcribeCmd.PersistentFlags().BoolVarP(¶ms.AutoChapters, "auto_chapters", "s", false, "A \"summary over time\" for the audio file transcribed.") + transcribeCmd.PersistentFlags().BoolVarP(¶ms.AutoHighlights, "auto_highlights", "a", false, "Automatically detect important phrases and words in the text.") + transcribeCmd.PersistentFlags().BoolVarP(¶ms.ContentModeration, "content_moderation", "c", false, "Detect if sensitive content is spoken in the file.") + transcribeCmd.PersistentFlags().BoolVarP(¶ms.DualChannel, "dual_channel", "d", false, "Enable dual channel") + transcribeCmd.PersistentFlags().BoolVarP(¶ms.EntityDetection, "entity_detection", "e", false, "Identify a wide range of entities that are spoken in the audio file.") + transcribeCmd.PersistentFlags().BoolVarP(¶ms.FormatText, "format_text", "f", true, "Enable text formatting") + transcribeCmd.PersistentFlags().BoolVarP(¶ms.LanguageDetection, "language_detection", "n", false, "Identify the dominant language that’s spoken in an audio file.") + transcribeCmd.PersistentFlags().BoolVarP(¶ms.Punctuate, "punctuate", "u", true, "Enable automatic punctuation.") + transcribeCmd.PersistentFlags().BoolVarP(¶ms.RedactPii, "redact_pii", "r", false, "Remove personally identifiable information from the transcription.") + transcribeCmd.PersistentFlags().BoolVarP(¶ms.SentimentAnalysis, "sentiment_analysis", "x", false, "Detect the sentiment of each sentence of speech spoken in the file.") + transcribeCmd.PersistentFlags().BoolVarP(¶ms.SpeakerLabels, "speaker_labels", "l", true, "Automatically detect the number of speakers in your audio file, and each word in the transcription text can be associated with its speaker.") + transcribeCmd.PersistentFlags().BoolVarP(¶ms.Summarization, "summarization", "m", false, "Generate a single abstractive summary of the entire audio.") + transcribeCmd.PersistentFlags().BoolVarP(¶ms.TopicDetection, "topic_detection", "t", false, "Label the topics that are spoken in the file.") + transcribeCmd.PersistentFlags().StringSliceVarP(¶ms.RedactPiiPolicies, "redact_pii_policies", "i", []string{"drug", "number_sequence", "person_name"}, "The list of PII policies to redact, comma-separated without space in-between. Required if the redact_pii flag is true.") + transcribeCmd.PersistentFlags().StringSliceVarP(¶ms.WordBoost, "word_boost", "k", nil, "The value of this flag MUST be used surrounded by quotes. Any term included will have its likelihood of being transcribed boosted.") + transcribeCmd.PersistentFlags().StringVarP(¶ms.BoostParam, "boost_param", "z", "", "Control how much weight should be applied to your boosted keywords/phrases. This value can be either low, default, or high.") + transcribeCmd.PersistentFlags().StringVarP(¶ms.LanguageCode, "language_code", "g", "", "Specify the language of the speech in your audio file.") + transcribeCmd.PersistentFlags().StringVarP(¶ms.SummaryModel, "summary_model", "q", "informative", "The model used to generate the summary.") + transcribeCmd.PersistentFlags().StringVarP(¶ms.SummaryType, "summary_type", "y", "bullets", "Type of summary generated.") + transcribeCmd.PersistentFlags().StringVarP(¶ms.WebhookAuthHeaderName, "webhook_auth_header_name", "b", "", "Containing the header's name which will be inserted into the webhook request") + transcribeCmd.PersistentFlags().StringVarP(¶ms.WebhookAuthHeaderValue, "webhook_auth_header_value", "o", "", "The value of the header that will be inserted into the webhook request.") + transcribeCmd.PersistentFlags().StringVarP(¶ms.WebhookURL, "webhook_url", "w", "", "Receive a webhook once your transcript is complete.") + + transcribeCmd.PersistentFlags().String("custom_spelling", "", "Specify how words are spelled or formatted in the transcript text.") rootCmd.AddCommand(transcribeCmd) } diff --git a/schemas/types.go b/schemas/types.go index 1ad5e36..fa5888d 100644 --- a/schemas/types.go +++ b/schemas/types.go @@ -177,21 +177,22 @@ type UploadResponse struct { } type TranscribeFlags struct { - Poll bool `json:"poll"` - Json bool `json:"json"` + Poll bool `json:"poll"` + Json bool `json:"json"` + Csv string `json:"csv"` } type TranscribeParams struct { AudioURL string `json:"audio_url"` AutoChapters bool `json:"auto_chapters"` AutoHighlights bool `json:"auto_highlights"` - BoostParam *string `json:"boost_param,omitempty"` + BoostParam string `json:"boost_param,omitempty"` ContentModeration bool `json:"content_safety"` CustomSpelling []CustomSpelling `json:"custom_spelling,omitempty"` DualChannel bool `json:"dual_channel"` EntityDetection bool `json:"entity_detection"` FormatText bool `json:"format_text"` - LanguageCode *string `json:"language_code,omitempty"` + LanguageCode string `json:"language_code,omitempty"` LanguageDetection bool `json:"language_detection"` Punctuate bool `json:"punctuate"` RedactPii bool `json:"redact_pii"` From b8d11a343600dc80d3e1ae4fe0c3defbb56733d8 Mon Sep 17 00:00:00 2001 From: Francisco Castillo Date: Mon, 28 Nov 2022 18:54:41 +0100 Subject: [PATCH 2/4] FEAT - add csv flag --- cmd/get.go | 9 +++++++++ utils/transcribe.go | 17 ++++++++++++++--- utils/utils.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 3 deletions(-) diff --git a/cmd/get.go b/cmd/get.go index 9769f68..225f9c9 100644 --- a/cmd/get.go +++ b/cmd/get.go @@ -37,6 +37,15 @@ var getCmd = &cobra.Command{ return } + if flags.Csv != "" && !flags.Poll { + printErrorProps := S.PrintErrorProps{ + Error: errors.New("CSV output is only supported with polling"), + Message: "CSV output is only supported with polling.", + } + U.PrintError(printErrorProps) + return + } + U.PollTranscription(id, flags) }, } diff --git a/utils/transcribe.go b/utils/transcribe.go index 1c9f03f..134d7de 100644 --- a/utils/transcribe.go +++ b/utils/transcribe.go @@ -22,6 +22,7 @@ import ( ) var width int +var Flags S.TranscribeFlags func Transcribe(params S.TranscribeParams, flags S.TranscribeFlags) { Token = GetStoredToken() @@ -223,8 +224,8 @@ func UploadFile(path string) string { } func PollTranscription(id string, flags S.TranscribeFlags) { + Flags = flags fmt.Fprintln(os.Stdin, "Transcribing file with id "+id) - s := CallSpinner(" Processing time is usually 20% of the file's duration.") for { @@ -277,14 +278,24 @@ func PollTranscription(id string, flags S.TranscribeFlags) { fmt.Println(string(print)) return } - getFormattedOutput(transcript, flags) + if flags.Csv != "" { + if filepath.Ext(flags.Csv) == "" { + flags.Csv = flags.Csv + ".csv" + } + + row := [][]string{} + row = append(row, []string{"\"" + *transcript.Text + "\""}) + GenerateCsv(flags.Csv, []string{"text"}, row) + } + + getFormattedOutput(transcript) return } time.Sleep(3 * time.Second) } } -func getFormattedOutput(transcript S.TranscriptResponse, flags S.TranscribeFlags) { +func getFormattedOutput(transcript S.TranscriptResponse) { getWidth, _, err := term.GetSize(0) if err != nil { width = 512 diff --git a/utils/utils.go b/utils/utils.go index 0685083..4e53bcf 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -468,3 +468,47 @@ func Contains(s []string, e string) bool { } return false } + +func GenerateCsv(filename string, headers []string, data [][]string) { + file, err := os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + printErrorProps := S.PrintErrorProps{ + Error: err, + Message: "Error opening file", + } + PrintError(printErrorProps) + } + defer file.Close() + + // if file is empty, write headers + fileInfo, err := file.Stat() + if err != nil { + printErrorProps := S.PrintErrorProps{ + Error: err, + Message: "Error getting file info", + } + PrintError(printErrorProps) + } + if fileInfo.Size() == 0 { + _, err := file.WriteString(strings.Join(headers, ",") + "\n") + if err != nil { + printErrorProps := S.PrintErrorProps{ + Error: err, + Message: "Error writing headers", + } + PrintError(printErrorProps) + } + } + + for _, value := range data { + _, err := file.WriteString(strings.Join(value, ",") + "\n") + if err != nil { + printErrorProps := S.PrintErrorProps{ + Error: err, + Message: "Error writing to file", + } + PrintError(printErrorProps) + } + + } +} From e77ca933bc6f91aedae72c46499c83b6eedcb645 Mon Sep 17 00:00:00 2001 From: Francisco Castillo Date: Tue, 29 Nov 2022 11:50:52 +0100 Subject: [PATCH 3/4] refactor --- cmd/transcribe.go | 175 +----------------------------- utils/models.go | 232 ++++++++++++++++++++++++++++++++++++++++ utils/transcribe.go | 251 -------------------------------------------- utils/validation.go | 227 +++++++++++++++++++++++++++++++++++++++ utils/youtube.go | 1 - 5 files changed, 461 insertions(+), 425 deletions(-) create mode 100644 utils/models.go create mode 100644 utils/validation.go diff --git a/cmd/transcribe.go b/cmd/transcribe.go index df7c237..de9c050 100644 --- a/cmd/transcribe.go +++ b/cmd/transcribe.go @@ -4,11 +4,7 @@ Copyright © 2022 AssemblyAI support@assemblyai.com package cmd import ( - "encoding/json" "errors" - "fmt" - "io/ioutil" - "os" S "github.com/AssemblyAI/assemblyai-cli/schemas" U "github.com/AssemblyAI/assemblyai-cli/utils" @@ -36,175 +32,8 @@ var transcribeCmd = &cobra.Command{ } params.AudioURL = args[0] - if params.WordBoost == nil && params.BoostParam != "" { - printErrorProps := S.PrintErrorProps{ - Error: errors.New("Please provide a valid word boost"), - Message: "To boost a word, please provide a valid list of words to boost. For example: --word_boost \"word1,word2,word3\" --boost_param high", - } - U.PrintError(printErrorProps) - return - } else if params.BoostParam != "" && params.BoostParam != "low" && params.BoostParam != "default" && params.BoostParam != "high" { - printErrorProps := S.PrintErrorProps{ - Error: errors.New("Invalid boost_param"), - Message: "Please provide a valid boost_param. Valid values are low, default, or high.", - } - U.PrintError(printErrorProps) - return - } - - if !params.Summarization { - params.SummaryType = "" - params.SummaryModel = "" - } else { - params.Punctuate = true - params.FormatText = true - if _, ok := S.SummarizationTypeMapReverse[params.SummaryType]; !ok { - printErrorProps := S.PrintErrorProps{ - Error: errors.New("Invalid summary type"), - Message: "Invalid summary type. To know more about Summarization, head over to https://assemblyai.com/docs/audio-intelligence#summarization", - } - U.PrintError(printErrorProps) - return - } - if _, ok := S.SummarizationModelMap[params.SummaryModel]; !ok { - printErrorProps := S.PrintErrorProps{ - Error: errors.New("Invalid summary model"), - Message: "Invalid summary model. To know more about Summarization, head over to https://assemblyai.com/docs/audio-intelligence#summarization", - } - U.PrintError(printErrorProps) - return - } - if !U.Contains(S.SummarizationModelMap[params.SummaryModel], params.SummaryType) { - printErrorProps := S.PrintErrorProps{ - Error: errors.New("Invalid summary model"), - Message: "Cant use summary model " + params.SummaryModel + " with summary type " + params.SummaryType + ". To know more about Summarization, head over to https://assemblyai.com/docs/audio-intelligence#summarization", - } - U.PrintError(printErrorProps) - return - } - if params.SummaryModel == "conversational" && !params.SpeakerLabels { - printErrorProps := S.PrintErrorProps{ - Error: errors.New("Speaker labels required for conversational summary model"), - Message: "Speaker labels are required for conversational summarization. To know more about Summarization, head over to https://assemblyai.com/docs/audio-intelligence#summarization", - } - U.PrintError(printErrorProps) - return - } - } - if !params.RedactPii { - params.RedactPiiPolicies = nil - } else { - for _, policy := range params.RedactPiiPolicies { - if _, ok := S.PIIRedactionPolicyMap[policy]; !ok { - printErrorProps := S.PrintErrorProps{ - Error: errors.New("Invalid redaction policy"), - Message: fmt.Sprintf("%s is not a valid policy. See https://www.assemblyai.com/docs/audio-intelligence#pii-redaction for the complete list of supported policies.", policy), - } - U.PrintError(printErrorProps) - return - } - } - } - - if params.LanguageDetection && params.LanguageCode != "" { - printErrorProps := S.PrintErrorProps{ - Error: errors.New("Language detection and language code cannot be used together"), - Message: "Language detection and language code cannot be used together.", - } - U.PrintError(printErrorProps) - return - } - if (params.LanguageCode != "" || params.LanguageDetection) && params.SpeakerLabels { - if cmd.Flags().Lookup("speaker_labels").Changed { - printErrorProps := S.PrintErrorProps{ - Error: errors.New("Speaker labels are not supported for languages other than English"), - Message: "Speaker labels are not supported for languages other than English.", - } - U.PrintError(printErrorProps) - return - } else { - params.SpeakerLabels = false - } - } - if params.LanguageCode != "" { - if _, ok := S.LanguageMap[params.LanguageCode]; !ok { - printErrorProps := S.PrintErrorProps{ - Error: errors.New("Invalid language code"), - Message: "Invalid language code. See https://www.assemblyai.com/docs#supported-languages for supported languages.", - } - U.PrintError(printErrorProps) - return - } - } - - customSpelling, _ := cmd.Flags().GetString("custom_spelling") - if customSpelling != "" { - parsedCustomSpelling := []S.CustomSpelling{} - - _, err := os.Stat(customSpelling) - - if !os.IsNotExist(err) { - file, err := os.Open(customSpelling) - if err != nil { - printErrorProps := S.PrintErrorProps{ - Error: err, - Message: "Error opening custom spelling file", - } - U.PrintError(printErrorProps) - return - } - defer file.Close() - byteCustomSpelling, err := ioutil.ReadAll(file) - if err != nil { - printErrorProps := S.PrintErrorProps{ - Error: err, - Message: "Error reading custom spelling file", - } - U.PrintError(printErrorProps) - return - } - - err = json.Unmarshal(byteCustomSpelling, &parsedCustomSpelling) - if err != nil { - printErrorProps := S.PrintErrorProps{ - Error: err, - Message: "Error parsing custom spelling file", - } - U.PrintError(printErrorProps) - return - } - } else { - err = json.Unmarshal([]byte(customSpelling), &parsedCustomSpelling) - if err != nil { - printErrorProps := S.PrintErrorProps{ - Error: err, - Message: "Invalid custom spelling. Please provide a valid custom spelling JSON.", - } - U.PrintError(printErrorProps) - return - } - } - - err = U.ValidateCustomSpelling(parsedCustomSpelling) - if err != nil { - printErrorProps := S.PrintErrorProps{ - Error: err, - Message: "Invalid custom spelling. Please provide a valid custom spelling JSON.", - } - U.PrintError(printErrorProps) - return - } - params.CustomSpelling = parsedCustomSpelling - } - - if flags.Csv != "" && !flags.Poll { - printErrorProps := S.PrintErrorProps{ - Error: errors.New("CSV output is only supported with polling"), - Message: "CSV output is only supported with polling.", - } - U.PrintError(printErrorProps) - return - } + U.ValidateParams(params, cmd.Flags()) + U.ValidateFlags(flags) U.Transcribe(params, flags) }, diff --git a/utils/models.go b/utils/models.go new file mode 100644 index 0000000..3368495 --- /dev/null +++ b/utils/models.go @@ -0,0 +1,232 @@ +package utils + +import ( + "fmt" + "sort" + "strconv" + "strings" + + S "github.com/AssemblyAI/assemblyai-cli/schemas" + "github.com/gosuri/uitable" +) + +func textPrintFormatted(text string, words []S.SentimentAnalysisResult) { + table := uitable.New() + table.Wrap = true + table.MaxColWidth = uint(width - 10) + sentences := SplitSentences(text, true) + timestamps := GetSentenceTimestamps(sentences, words) + for index, sentence := range sentences { + if sentence != "" { + stamp := "" + if len(timestamps) > index { + stamp = timestamps[index] + } + table.AddRow(stamp, sentence) + } + } + fmt.Println(table) + fmt.Println() +} + +func dualChannelPrintFormatted(utterances []S.SentimentAnalysisResult) { + table := uitable.New() + table.Wrap = true + table.MaxColWidth = uint(width - 21) + for _, utterance := range utterances { + start := TransformMsToTimestamp(*utterance.Start) + speaker := fmt.Sprintf("(Channel %s)", utterance.Channel) + + sentences := SplitSentences(utterance.Text, false) + for _, sentence := range sentences { + table.AddRow(start, speaker, sentence) + start = "" + speaker = "" + } + } + fmt.Println(table) + fmt.Println() +} + +func speakerLabelsPrintFormatted(utterances []S.SentimentAnalysisResult) { + table := uitable.New() + table.Wrap = true + table.MaxColWidth = uint(width - 27) + + for _, utterance := range utterances { + sentences := SplitSentences(utterance.Text, false) + timestamps := GetSentenceTimestampsAndSpeaker(sentences, utterance.Words) + for index, sentence := range sentences { + if sentence != "" { + info := []string{"", ""} + if len(timestamps) > index { + info = timestamps[index] + } + table.AddRow(info[0], info[1], sentence) + } + } + } + fmt.Println(table) + fmt.Println() +} + +func highlightsPrintFormatted(highlights S.AutoHighlightsResult) { + if *highlights.Status != "success" { + fmt.Println("Could not retrieve highlights") + return + } + + table := uitable.New() + table.Wrap = true + table.Separator = " |\t" + table.AddRow("| count", "text") + sort.SliceStable(highlights.Results, func(i, j int) bool { + return int(*highlights.Results[i].Count) > int(*highlights.Results[j].Count) + }) + for _, highlight := range highlights.Results { + table.AddRow("| "+strconv.FormatInt(*highlight.Count, 10), highlight.Text) + } + fmt.Println(table) + fmt.Println() +} + +func contentSafetyPrintFormatted(labels S.ContentSafetyLabels) { + if *labels.Status != "success" { + fmt.Println("Could not retrieve content safety labels") + return + } + table := uitable.New() + table.Wrap = true + table.MaxColWidth = uint(width - 24) + table.Separator = " |\t" + table.AddRow("| label", "text") + for _, label := range labels.Results { + var labelString string + for _, innerLabel := range label.Labels { + labelString = innerLabel.Label + " " + labelString + } + table.AddRow("| "+labelString, label.Text) + } + fmt.Println(table) + fmt.Println() +} + +func topicDetectionPrintFormatted(categories S.IabCategoriesResult) { + if *categories.Status != "success" { + fmt.Println("Could not retrieve topic detection") + return + } + + table := uitable.New() + table.Wrap = true + table.MaxColWidth = uint(width - 20) + table.Separator = " |\t" + table.AddRow("| rank", "topic") + var ArrayCategoriesSorted []ArrayCategories + for category, i := range categories.Summary { + add := ArrayCategories{ + Category: category, + Score: i, + } + ArrayCategoriesSorted = append(ArrayCategoriesSorted, add) + } + sort.SliceStable(ArrayCategoriesSorted, func(i, j int) bool { + return ArrayCategoriesSorted[i].Score > ArrayCategoriesSorted[j].Score + }) + + for i, category := range ArrayCategoriesSorted { + table.AddRow(fmt.Sprintf("| %o", i+1), category.Category) + } + fmt.Println(table) + fmt.Println() +} + +func sentimentAnalysisPrintFormatted(sentiments []S.SentimentAnalysisResult) { + if len(sentiments) == 0 { + fmt.Println("Could not retrieve sentiment analysis") + return + } + + table := uitable.New() + table.Wrap = true + table.MaxColWidth = uint(width - 20) + table.Separator = " |\t" + table.AddRow("| sentiment", "text") + for _, sentiment := range sentiments { + sentimentStatus := sentiment.Sentiment + table.AddRow("| "+sentimentStatus, sentiment.Text) + } + fmt.Println(table) + fmt.Println() +} + +func chaptersPrintFormatted(chapters []S.Chapter) { + if len(chapters) == 0 { + fmt.Println("Could not retrieve chapters") + return + } + + table := uitable.New() + table.Wrap = true + table.MaxColWidth = uint(width - 19) + table.Separator = " |\t" + for _, chapter := range chapters { + start := TransformMsToTimestamp(*chapter.Start) + end := TransformMsToTimestamp(*chapter.End) + table.AddRow("| timestamp", fmt.Sprintf("%s-%s", start, end)) + table.AddRow("| Gist", chapter.Gist) + table.AddRow("| Headline", chapter.Headline) + table.AddRow("| Summary", chapter.Summary) + table.AddRow("", "") + } + fmt.Println(table) + fmt.Println() +} + +func entityDetectionPrintFormatted(entities []S.Entity) { + if len(entities) == 0 { + fmt.Println("Could not retrieve entity detection") + return + } + + table := uitable.New() + table.Wrap = true + table.MaxColWidth = uint(width - 25) + table.Separator = " |\t" + table.AddRow("| type", "text") + entityMap := make(map[string][]string) + for _, entity := range entities { + isAlreadyInMap := false + for _, text := range entityMap[entity.EntityType] { + if text == entity.Text { + isAlreadyInMap = true + break + } + } + if !isAlreadyInMap { + entityMap[entity.EntityType] = append(entityMap[entity.EntityType], entity.Text) + } + } + for entityType, entityTexts := range entityMap { + table.AddRow("| "+entityType, strings.Join(entityTexts, ", ")) + } + fmt.Println(table) + fmt.Println() +} + +func summaryPrintFormatted(summary *string) { + if summary == nil { + fmt.Println("Could not retrieve summary") + return + } + + table := uitable.New() + table.Wrap = true + table.MaxColWidth = uint(width - 20) + table.Separator = " |\t" + + table.AddRow(*summary) + + fmt.Println(table) + fmt.Println() +} diff --git a/utils/transcribe.go b/utils/transcribe.go index 134d7de..818438e 100644 --- a/utils/transcribe.go +++ b/utils/transcribe.go @@ -9,14 +9,10 @@ import ( "net/url" "os" "path/filepath" - "regexp" - "sort" - "strconv" "strings" "time" S "github.com/AssemblyAI/assemblyai-cli/schemas" - "github.com/gosuri/uitable" "golang.org/x/term" "gopkg.in/cheggaaa/pb.v1" ) @@ -145,32 +141,6 @@ func Transcribe(params S.TranscribeParams, flags S.TranscribeFlags) { PollTranscription(*id, flags) } -func isUrl(str string) bool { - u, err := url.Parse(str) - return err == nil && u.Scheme != "" && u.Host != "" -} - -func isShortenedYoutubeLink(url string) bool { - regex := regexp.MustCompile(`^(https?\:\/\/)?(youtu\.?be)\/.+$`) - return regex.MatchString(url) -} - -func isFullLengthYoutubeLink(url string) bool { - regex := regexp.MustCompile(`^(https?\:\/\/)?(www\.youtube\.com)\/.+$`) - return regex.MatchString(url) -} - -func isYoutubeShortLink(url string) bool { - regex := regexp.MustCompile(`^(https?\:\/\/)?(www\.youtube\.com)\/shorts\/.+$`) - regexShare := regexp.MustCompile(`^(https?\:\/\/)?(youtube\.com)\/shorts\/.+$`) - - return regex.MatchString(url) || regexShare.MatchString(url) -} - -func isYoutubeLink(url string) bool { - return isFullLengthYoutubeLink(url) || isShortenedYoutubeLink(url) || isYoutubeShortLink(url) -} - func checkAAICDN(url string) bool { return strings.HasPrefix(url, "https://cdn.assemblyai.com/") } @@ -342,227 +312,6 @@ func getFormattedOutput(transcript S.TranscriptResponse) { } } -func textPrintFormatted(text string, words []S.SentimentAnalysisResult) { - table := uitable.New() - table.Wrap = true - table.MaxColWidth = uint(width - 10) - sentences := SplitSentences(text, true) - timestamps := GetSentenceTimestamps(sentences, words) - for index, sentence := range sentences { - if sentence != "" { - stamp := "" - if len(timestamps) > index { - stamp = timestamps[index] - } - table.AddRow(stamp, sentence) - } - } - fmt.Println(table) - fmt.Println() -} - -func dualChannelPrintFormatted(utterances []S.SentimentAnalysisResult) { - table := uitable.New() - table.Wrap = true - table.MaxColWidth = uint(width - 21) - for _, utterance := range utterances { - start := TransformMsToTimestamp(*utterance.Start) - speaker := fmt.Sprintf("(Channel %s)", utterance.Channel) - - sentences := SplitSentences(utterance.Text, false) - for _, sentence := range sentences { - table.AddRow(start, speaker, sentence) - start = "" - speaker = "" - } - } - fmt.Println(table) - fmt.Println() -} - -func speakerLabelsPrintFormatted(utterances []S.SentimentAnalysisResult) { - table := uitable.New() - table.Wrap = true - table.MaxColWidth = uint(width - 27) - - for _, utterance := range utterances { - sentences := SplitSentences(utterance.Text, false) - timestamps := GetSentenceTimestampsAndSpeaker(sentences, utterance.Words) - for index, sentence := range sentences { - if sentence != "" { - info := []string{"", ""} - if len(timestamps) > index { - info = timestamps[index] - } - table.AddRow(info[0], info[1], sentence) - } - } - } - fmt.Println(table) - fmt.Println() -} - -func highlightsPrintFormatted(highlights S.AutoHighlightsResult) { - if *highlights.Status != "success" { - fmt.Println("Could not retrieve highlights") - return - } - - table := uitable.New() - table.Wrap = true - table.Separator = " |\t" - table.AddRow("| count", "text") - sort.SliceStable(highlights.Results, func(i, j int) bool { - return int(*highlights.Results[i].Count) > int(*highlights.Results[j].Count) - }) - for _, highlight := range highlights.Results { - table.AddRow("| "+strconv.FormatInt(*highlight.Count, 10), highlight.Text) - } - fmt.Println(table) - fmt.Println() -} - -func contentSafetyPrintFormatted(labels S.ContentSafetyLabels) { - if *labels.Status != "success" { - fmt.Println("Could not retrieve content safety labels") - return - } - table := uitable.New() - table.Wrap = true - table.MaxColWidth = uint(width - 24) - table.Separator = " |\t" - table.AddRow("| label", "text") - for _, label := range labels.Results { - var labelString string - for _, innerLabel := range label.Labels { - labelString = innerLabel.Label + " " + labelString - } - table.AddRow("| "+labelString, label.Text) - } - fmt.Println(table) - fmt.Println() -} - -func topicDetectionPrintFormatted(categories S.IabCategoriesResult) { - if *categories.Status != "success" { - fmt.Println("Could not retrieve topic detection") - return - } - - table := uitable.New() - table.Wrap = true - table.MaxColWidth = uint(width - 20) - table.Separator = " |\t" - table.AddRow("| rank", "topic") - var ArrayCategoriesSorted []ArrayCategories - for category, i := range categories.Summary { - add := ArrayCategories{ - Category: category, - Score: i, - } - ArrayCategoriesSorted = append(ArrayCategoriesSorted, add) - } - sort.SliceStable(ArrayCategoriesSorted, func(i, j int) bool { - return ArrayCategoriesSorted[i].Score > ArrayCategoriesSorted[j].Score - }) - - for i, category := range ArrayCategoriesSorted { - table.AddRow(fmt.Sprintf("| %o", i+1), category.Category) - } - fmt.Println(table) - fmt.Println() -} - -func sentimentAnalysisPrintFormatted(sentiments []S.SentimentAnalysisResult) { - if len(sentiments) == 0 { - fmt.Println("Could not retrieve sentiment analysis") - return - } - - table := uitable.New() - table.Wrap = true - table.MaxColWidth = uint(width - 20) - table.Separator = " |\t" - table.AddRow("| sentiment", "text") - for _, sentiment := range sentiments { - sentimentStatus := sentiment.Sentiment - table.AddRow("| "+sentimentStatus, sentiment.Text) - } - fmt.Println(table) - fmt.Println() -} - -func chaptersPrintFormatted(chapters []S.Chapter) { - if len(chapters) == 0 { - fmt.Println("Could not retrieve chapters") - return - } - - table := uitable.New() - table.Wrap = true - table.MaxColWidth = uint(width - 19) - table.Separator = " |\t" - for _, chapter := range chapters { - start := TransformMsToTimestamp(*chapter.Start) - end := TransformMsToTimestamp(*chapter.End) - table.AddRow("| timestamp", fmt.Sprintf("%s-%s", start, end)) - table.AddRow("| Gist", chapter.Gist) - table.AddRow("| Headline", chapter.Headline) - table.AddRow("| Summary", chapter.Summary) - table.AddRow("", "") - } - fmt.Println(table) - fmt.Println() -} - -func entityDetectionPrintFormatted(entities []S.Entity) { - if len(entities) == 0 { - fmt.Println("Could not retrieve entity detection") - return - } - - table := uitable.New() - table.Wrap = true - table.MaxColWidth = uint(width - 25) - table.Separator = " |\t" - table.AddRow("| type", "text") - entityMap := make(map[string][]string) - for _, entity := range entities { - isAlreadyInMap := false - for _, text := range entityMap[entity.EntityType] { - if text == entity.Text { - isAlreadyInMap = true - break - } - } - if !isAlreadyInMap { - entityMap[entity.EntityType] = append(entityMap[entity.EntityType], entity.Text) - } - } - for entityType, entityTexts := range entityMap { - table.AddRow("| "+entityType, strings.Join(entityTexts, ", ")) - } - fmt.Println(table) - fmt.Println() -} - -func summaryPrintFormatted(summary *string) { - if summary == nil { - fmt.Println("Could not retrieve summary") - return - } - - table := uitable.New() - table.Wrap = true - table.MaxColWidth = uint(width - 20) - table.Separator = " |\t" - - table.AddRow(*summary) - - fmt.Println(table) - fmt.Println() -} - type ArrayCategories struct { Score float64 `json:"score"` Category string `json:"category"` diff --git a/utils/validation.go b/utils/validation.go new file mode 100644 index 0000000..78896b7 --- /dev/null +++ b/utils/validation.go @@ -0,0 +1,227 @@ +package utils + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/url" + "os" + "regexp" + + S "github.com/AssemblyAI/assemblyai-cli/schemas" + "github.com/spf13/pflag" +) + +func ValidateParams(params S.TranscribeParams, flagSet *pflag.FlagSet) { + if params.WordBoost == nil && params.BoostParam != "" { + printErrorProps := S.PrintErrorProps{ + Error: errors.New("Please provide a valid word boost"), + Message: "To boost a word, please provide a valid list of words to boost. For example: --word_boost \"word1,word2,word3\" --boost_param high", + } + PrintError(printErrorProps) + return + } else if params.BoostParam != "" && params.BoostParam != "low" && params.BoostParam != "default" && params.BoostParam != "high" { + printErrorProps := S.PrintErrorProps{ + Error: errors.New("Invalid boost_param"), + Message: "Please provide a valid boost_param. Valid values are low, default, or high.", + } + PrintError(printErrorProps) + return + } + + if !params.Summarization { + params.SummaryType = "" + params.SummaryModel = "" + } else { + params.Punctuate = true + params.FormatText = true + if _, ok := S.SummarizationTypeMapReverse[params.SummaryType]; !ok { + printErrorProps := S.PrintErrorProps{ + Error: errors.New("Invalid summary type"), + Message: "Invalid summary type. To know more about Summarization, head over to https://assemblyai.com/docs/audio-intelligence#summarization", + } + PrintError(printErrorProps) + return + } + if _, ok := S.SummarizationModelMap[params.SummaryModel]; !ok { + printErrorProps := S.PrintErrorProps{ + Error: errors.New("Invalid summary model"), + Message: "Invalid summary model. To know more about Summarization, head over to https://assemblyai.com/docs/audio-intelligence#summarization", + } + PrintError(printErrorProps) + return + } + if !Contains(S.SummarizationModelMap[params.SummaryModel], params.SummaryType) { + printErrorProps := S.PrintErrorProps{ + Error: errors.New("Invalid summary model"), + Message: "Cant use summary model " + params.SummaryModel + " with summary type " + params.SummaryType + ". To know more about Summarization, head over to https://assemblyai.com/docs/audio-intelligence#summarization", + } + PrintError(printErrorProps) + return + } + if params.SummaryModel == "conversational" && !params.SpeakerLabels { + printErrorProps := S.PrintErrorProps{ + Error: errors.New("Speaker labels required for conversational summary model"), + Message: "Speaker labels are required for conversational summarization. To know more about Summarization, head over to https://assemblyai.com/docs/audio-intelligence#summarization", + } + PrintError(printErrorProps) + return + } + } + + if !params.RedactPii { + params.RedactPiiPolicies = nil + } else { + for _, policy := range params.RedactPiiPolicies { + if _, ok := S.PIIRedactionPolicyMap[policy]; !ok { + printErrorProps := S.PrintErrorProps{ + Error: errors.New("Invalid redaction policy"), + Message: fmt.Sprintf("%s is not a valid policy. See https://www.assemblyai.com/docs/audio-intelligence#pii-redaction for the complete list of supported policies.", policy), + } + PrintError(printErrorProps) + return + } + } + } + + if params.LanguageCode != "" { + if params.LanguageDetection { + printErrorProps := S.PrintErrorProps{ + Error: errors.New("Language detection and language code cannot be used together"), + Message: "Language detection and language code cannot be used together.", + } + PrintError(printErrorProps) + return + } + if params.SpeakerLabels { + if flagSet.Lookup("speaker_labels").Changed { + printErrorProps := S.PrintErrorProps{ + Error: errors.New("Speaker labels are not supported for languages other than English"), + Message: "Speaker labels are not supported for languages other than English.", + } + PrintError(printErrorProps) + return + } else { + params.SpeakerLabels = false + } + } + if _, ok := S.LanguageMap[params.LanguageCode]; !ok { + printErrorProps := S.PrintErrorProps{ + Error: errors.New("Invalid language code"), + Message: "Invalid language code. See https://www.assemblyai.com/docs#supported-languages for supported languages.", + } + PrintError(printErrorProps) + return + } + } + if params.LanguageDetection && params.SpeakerLabels { + if flagSet.Lookup("speaker_labels").Changed { + printErrorProps := S.PrintErrorProps{ + Error: errors.New("Speaker labels are not supported for languages other than English"), + Message: "Speaker labels are not supported for languages other than English.", + } + PrintError(printErrorProps) + return + } else { + params.SpeakerLabels = false + } + } + + customSpelling, _ := flagSet.GetString("custom_spelling") + if customSpelling != "" { + parsedCustomSpelling := []S.CustomSpelling{} + + _, err := os.Stat(customSpelling) + + if !os.IsNotExist(err) { + file, err := os.Open(customSpelling) + if err != nil { + printErrorProps := S.PrintErrorProps{ + Error: err, + Message: "Error opening custom spelling file", + } + PrintError(printErrorProps) + return + } + defer file.Close() + byteCustomSpelling, err := ioutil.ReadAll(file) + if err != nil { + printErrorProps := S.PrintErrorProps{ + Error: err, + Message: "Error reading custom spelling file", + } + PrintError(printErrorProps) + return + } + + err = json.Unmarshal(byteCustomSpelling, &parsedCustomSpelling) + if err != nil { + printErrorProps := S.PrintErrorProps{ + Error: err, + Message: "Error parsing custom spelling file", + } + PrintError(printErrorProps) + return + } + } else { + err = json.Unmarshal([]byte(customSpelling), &parsedCustomSpelling) + if err != nil { + printErrorProps := S.PrintErrorProps{ + Error: err, + Message: "Invalid custom spelling. Please provide a valid custom spelling JSON.", + } + PrintError(printErrorProps) + return + } + } + + err = ValidateCustomSpelling(parsedCustomSpelling) + if err != nil { + printErrorProps := S.PrintErrorProps{ + Error: err, + Message: "Invalid custom spelling. Please provide a valid custom spelling JSON.", + } + PrintError(printErrorProps) + return + } + params.CustomSpelling = parsedCustomSpelling + } +} + +func ValidateFlags(flags S.TranscribeFlags) { + if flags.Csv != "" && !flags.Poll { + printErrorProps := S.PrintErrorProps{ + Error: errors.New("CSV output is only supported with polling"), + Message: "CSV output is only supported with polling.", + } + PrintError(printErrorProps) + return + } +} + +func isUrl(str string) bool { + u, err := url.Parse(str) + return err == nil && u.Scheme != "" && u.Host != "" +} + +func isShortenedYoutubeLink(url string) bool { + regex := regexp.MustCompile(`^(https?\:\/\/)?(youtu\.?be)\/.+$`) + return regex.MatchString(url) +} + +func isFullLengthYoutubeLink(url string) bool { + regex := regexp.MustCompile(`^(https?\:\/\/)?(www\.youtube\.com)\/.+$`) + return regex.MatchString(url) +} + +func isYoutubeShortLink(url string) bool { + regex := regexp.MustCompile(`^(https?\:\/\/)?(www\.youtube\.com)\/shorts\/.+$`) + regexShare := regexp.MustCompile(`^(https?\:\/\/)?(youtube\.com)\/shorts\/.+$`) + + return regex.MatchString(url) || regexShare.MatchString(url) +} + +func isYoutubeLink(url string) bool { + return isFullLengthYoutubeLink(url) || isShortenedYoutubeLink(url) || isYoutubeShortLink(url) +} diff --git a/utils/youtube.go b/utils/youtube.go index 9caeeb0..8ed2d0d 100644 --- a/utils/youtube.go +++ b/utils/youtube.go @@ -266,7 +266,6 @@ func DownloadVideo(url string) { PrintError(printErrorProps) } } - } func (pWc *writeCounter) Write(b []byte) (n int, err error) { From 5a786a8c4520f0eb1f1ac18ebeb49cf7accfadfc Mon Sep 17 00:00:00 2001 From: Francisco Castillo Date: Tue, 29 Nov 2022 12:04:08 +0100 Subject: [PATCH 4/4] refactor --- schemas/types.go | 5 +++++ utils/models.go | 4 ++-- utils/transcribe.go | 19 ------------------- utils/validation.go | 14 +++++++++++++- 4 files changed, 20 insertions(+), 22 deletions(-) diff --git a/schemas/types.go b/schemas/types.go index fa5888d..e793c24 100644 --- a/schemas/types.go +++ b/schemas/types.go @@ -358,3 +358,8 @@ type Release struct { ZipballURL *string `json:"zipball_url,omitempty"` Body *string `json:"body,omitempty"` } + +type ArrayCategories struct { + Score float64 `json:"score"` + Category string `json:"category"` +} diff --git a/utils/models.go b/utils/models.go index 3368495..e0bfb79 100644 --- a/utils/models.go +++ b/utils/models.go @@ -122,9 +122,9 @@ func topicDetectionPrintFormatted(categories S.IabCategoriesResult) { table.MaxColWidth = uint(width - 20) table.Separator = " |\t" table.AddRow("| rank", "topic") - var ArrayCategoriesSorted []ArrayCategories + var ArrayCategoriesSorted []S.ArrayCategories for category, i := range categories.Summary { - add := ArrayCategories{ + add := S.ArrayCategories{ Category: category, Score: i, } diff --git a/utils/transcribe.go b/utils/transcribe.go index 818438e..6e268a0 100644 --- a/utils/transcribe.go +++ b/utils/transcribe.go @@ -18,7 +18,6 @@ import ( ) var width int -var Flags S.TranscribeFlags func Transcribe(params S.TranscribeParams, flags S.TranscribeFlags) { Token = GetStoredToken() @@ -194,7 +193,6 @@ func UploadFile(path string) string { } func PollTranscription(id string, flags S.TranscribeFlags) { - Flags = flags fmt.Fprintln(os.Stdin, "Transcribing file with id "+id) s := CallSpinner(" Processing time is usually 20% of the file's duration.") @@ -311,20 +309,3 @@ func getFormattedOutput(transcript S.TranscriptResponse) { summaryPrintFormatted(transcript.Summary) } } - -type ArrayCategories struct { - Score float64 `json:"score"` - Category string `json:"category"` -} - -func ValidateCustomSpelling(customSpelling []S.CustomSpelling) error { - for _, spelling := range customSpelling { - if len(spelling.From) == 0 { - return fmt.Errorf("from cannot be empty") - } - if spelling.To == "" { - return fmt.Errorf("to cannot be empty") - } - } - return nil -} diff --git a/utils/validation.go b/utils/validation.go index 78896b7..9d5763e 100644 --- a/utils/validation.go +++ b/utils/validation.go @@ -176,7 +176,7 @@ func ValidateParams(params S.TranscribeParams, flagSet *pflag.FlagSet) { } } - err = ValidateCustomSpelling(parsedCustomSpelling) + err = validateCustomSpelling(parsedCustomSpelling) if err != nil { printErrorProps := S.PrintErrorProps{ Error: err, @@ -225,3 +225,15 @@ func isYoutubeShortLink(url string) bool { func isYoutubeLink(url string) bool { return isFullLengthYoutubeLink(url) || isShortenedYoutubeLink(url) || isYoutubeShortLink(url) } + +func validateCustomSpelling(customSpelling []S.CustomSpelling) error { + for _, spelling := range customSpelling { + if len(spelling.From) == 0 { + return fmt.Errorf("from cannot be empty") + } + if spelling.To == "" { + return fmt.Errorf("to cannot be empty") + } + } + return nil +}